From 83ea485fe1959ea3097e3ddbff99798ef049faa9 Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 5 Oct 2021 17:13:47 +0300 Subject: [PATCH] feat(tracing): enable implicit broadcasting for binary operations --- concrete/common/data_types/dtypes_helpers.py | 48 ++++++++++++--- .../common/data_types/test_dtypes_helpers.py | 61 +++++++++++++++++++ tests/numpy/test_tracing.py | 46 ++++++++++++++ 3 files changed, 148 insertions(+), 7 deletions(-) diff --git a/concrete/common/data_types/dtypes_helpers.py b/concrete/common/data_types/dtypes_helpers.py index 00482d79e..87cf75988 100644 --- a/concrete/common/data_types/dtypes_helpers.py +++ b/concrete/common/data_types/dtypes_helpers.py @@ -2,7 +2,7 @@ from copy import deepcopy from functools import partial -from typing import Callable, Union, cast +from typing import Callable, Optional, Tuple, Union, cast from ..debugging.custom_assert import custom_assert from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue @@ -212,21 +212,21 @@ def mix_tensor_values_determine_holding_dtype( isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue" ) + resulting_shape = broadcast_shapes(value1.shape, value2.shape) custom_assert( - value1.shape == value2.shape, + resulting_shape is not None, ( - f"Tensors have different shapes which is not supported.\n" + f"Tensors have incompatible shapes which is not supported.\n" f"value1: {value1.shape}, value2: {value2.shape}" ), ) + assert resulting_shape is not None # this is to make mypy happy holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype) - shape = value1.shape - if value1.is_encrypted or value2.is_encrypted: - mixed_value = EncryptedTensor(dtype=holding_type, shape=shape) + mixed_value = EncryptedTensor(dtype=holding_type, shape=resulting_shape) else: - mixed_value = ClearTensor(dtype=holding_type, shape=shape) + mixed_value = ClearTensor(dtype=holding_type, shape=resulting_shape) return mixed_value @@ -344,3 +344,37 @@ def is_data_type_compatible_with( combination = find_type_to_hold_both_lossy(dtype, other) return other == combination + + +def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]: + """Broadcast two shapes into a single shape. + + We are mimicing the exact semantics of broadcasting in numpy. + You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html + + Args: + shape1 (Tuple[int, ...]): first shape to broadcast + shape2 (Tuple[int, ...]): second shape to broadcast + + Returns: + Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape + """ + + result = [] + for size1, size2 in zip(shape1[::-1], shape2[::-1]): + if size1 != size2 and size1 != 1 and size2 != 1 and size1 != 0 and size2 != 0: + return None + + if size1 == 0 or size2 == 0: + result.append(0) + else: + result.append(max(size1, size2)) + + if len(result) < len(shape1): + for i in reversed(range(len(shape1) - len(result))): + result.append(shape1[i]) + elif len(result) < len(shape2): + for i in reversed(range(len(shape2) - len(result))): + result.append(shape2[i]) + + return tuple(reversed(result)) diff --git a/tests/common/data_types/test_dtypes_helpers.py b/tests/common/data_types/test_dtypes_helpers.py index 5a0d52c9a..d853d2b26 100644 --- a/tests/common/data_types/test_dtypes_helpers.py +++ b/tests/common/data_types/test_dtypes_helpers.py @@ -3,6 +3,7 @@ import pytest from concrete.common.data_types.base import BaseDataType from concrete.common.data_types.dtypes_helpers import ( + broadcast_shapes, find_type_to_hold_both_lossy, mix_values_determine_holding_dtype, value_is_encrypted_scalar_integer, @@ -236,3 +237,63 @@ def test_fail_mix_values_determine_holding_dtype(): DummyValue(Integer(32, True), True), DummyValue(Integer(32, True), True), ) + + +@pytest.mark.parametrize( + "shape1,shape2,expected_shape", + [ + pytest.param((), (), ()), + pytest.param((3,), (), (3,)), + pytest.param((3,), (1,), (3,)), + pytest.param((3,), (2,), None), + pytest.param((3,), (3,), (3,)), + pytest.param((2, 3), (), (2, 3)), + pytest.param((2, 3), (1,), (2, 3)), + pytest.param((2, 3), (2,), None), + pytest.param((2, 3), (3,), (2, 3)), + pytest.param((2, 3), (1, 1), (2, 3)), + pytest.param((2, 3), (2, 1), (2, 3)), + pytest.param((2, 3), (3, 1), None), + pytest.param((2, 3), (1, 2), None), + pytest.param((2, 3), (2, 2), None), + pytest.param((2, 3), (3, 2), None), + pytest.param((2, 3), (1, 3), (2, 3)), + pytest.param((2, 3), (2, 3), (2, 3)), + pytest.param((2, 3), (3, 3), None), + pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)), + pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)), + pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)), + # Tests cases taken from `numpy` + # https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351 + pytest.param((1, 2), (2,), (1, 2)), + pytest.param((1, 1), (3, 4), (3, 4)), + pytest.param((1, 3), (3, 1), (3, 3)), + pytest.param((1, 0), (0, 0), (0, 0)), + pytest.param((0, 1), (0, 0), (0, 0)), + pytest.param((1, 0), (0, 1), (0, 0)), + pytest.param((1, 1), (0, 0), (0, 0)), + pytest.param((1, 1), (1, 0), (1, 0)), + pytest.param((1, 1), (0, 1), (0, 1)), + pytest.param((), (0,), (0,)), + pytest.param((0,), (0, 0), (0, 0)), + pytest.param((0,), (0, 1), (0, 0)), + pytest.param((1,), (0, 0), (0, 0)), + pytest.param((2,), (0, 0), (0, 0)), + pytest.param((), (0, 0), (0, 0)), + pytest.param((1, 1), (0,), (1, 0)), + pytest.param((1,), (0, 1), (0, 1)), + pytest.param((1,), (1, 0), (1, 0)), + pytest.param((), (1, 0), (1, 0)), + pytest.param((), (0, 1), (0, 1)), + pytest.param((1,), (3,), (3,)), + pytest.param((2,), (3, 2), (3, 2)), + pytest.param((3,), (4,), None), + pytest.param((2, 3), (2,), None), + pytest.param((1, 3, 4), (2, 3, 3), None), + pytest.param((2,), (2, 3), None), + ], +) +def test_broadcast_shapes(shape1, shape2, expected_shape): + """Test function for `broadcast_shapes` helper""" + assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape + assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index f3df7e072..9560622df 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -6,6 +6,7 @@ import networkx as nx import numpy import pytest +from concrete.common.data_types.dtypes_helpers import broadcast_shapes from concrete.common.data_types.floats import Float from concrete.common.data_types.integers import Integer from concrete.common.debugging import get_printable_graph @@ -244,6 +245,51 @@ return(%12) assert get_printable_graph(op_graph, show_data_types=True) == expected +@pytest.mark.parametrize( + "x_shape,y_shape", + [ + pytest.param((), ()), + pytest.param((3,), ()), + pytest.param((3,), (1,)), + pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((3,), (3,)), + pytest.param((2, 3), ()), + pytest.param((2, 3), (1,)), + pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 3), (3,)), + pytest.param((2, 3), (1, 1)), + pytest.param((2, 3), (2, 1)), + pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 3), (1, 3)), + pytest.param((2, 3), (2, 3)), + pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)), + pytest.param((2, 1, 3), (1, 1, 1)), + pytest.param((2, 1, 3), (1, 4, 1)), + pytest.param((2, 1, 3), (2, 4, 3)), + ], +) +def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape): + """Test numpy tracing broadcasted tensors""" + + def f(x, y): + return x + y + + op_graph = tracing.trace_numpy_function( + f, + { + "x": EncryptedTensor(Integer(3, True), shape=x_shape), + "y": EncryptedTensor(Integer(3, True), shape=y_shape), + }, + ) + + assert op_graph.input_nodes[0].outputs[0].shape == x_shape + assert op_graph.input_nodes[1].outputs[0].shape == y_shape + assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape) + + @pytest.mark.parametrize( "function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples", [