feat(tracing): enable implicit broadcasting for binary operations

This commit is contained in:
Umut
2021-10-05 17:13:47 +03:00
parent ceb23f93d5
commit 83ea485fe1
3 changed files with 148 additions and 7 deletions

View File

@@ -2,7 +2,7 @@
from copy import deepcopy
from functools import partial
from typing import Callable, Union, cast
from typing import Callable, Optional, Tuple, Union, cast
from ..debugging.custom_assert import custom_assert
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
@@ -212,21 +212,21 @@ def mix_tensor_values_determine_holding_dtype(
isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
)
resulting_shape = broadcast_shapes(value1.shape, value2.shape)
custom_assert(
value1.shape == value2.shape,
resulting_shape is not None,
(
f"Tensors have different shapes which is not supported.\n"
f"Tensors have incompatible shapes which is not supported.\n"
f"value1: {value1.shape}, value2: {value2.shape}"
),
)
assert resulting_shape is not None # this is to make mypy happy
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
shape = value1.shape
if value1.is_encrypted or value2.is_encrypted:
mixed_value = EncryptedTensor(dtype=holding_type, shape=shape)
mixed_value = EncryptedTensor(dtype=holding_type, shape=resulting_shape)
else:
mixed_value = ClearTensor(dtype=holding_type, shape=shape)
mixed_value = ClearTensor(dtype=holding_type, shape=resulting_shape)
return mixed_value
@@ -344,3 +344,37 @@ def is_data_type_compatible_with(
combination = find_type_to_hold_both_lossy(dtype, other)
return other == combination
def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
"""Broadcast two shapes into a single shape.
We are mimicing the exact semantics of broadcasting in numpy.
You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html
Args:
shape1 (Tuple[int, ...]): first shape to broadcast
shape2 (Tuple[int, ...]): second shape to broadcast
Returns:
Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape
"""
result = []
for size1, size2 in zip(shape1[::-1], shape2[::-1]):
if size1 != size2 and size1 != 1 and size2 != 1 and size1 != 0 and size2 != 0:
return None
if size1 == 0 or size2 == 0:
result.append(0)
else:
result.append(max(size1, size2))
if len(result) < len(shape1):
for i in reversed(range(len(shape1) - len(result))):
result.append(shape1[i])
elif len(result) < len(shape2):
for i in reversed(range(len(shape2) - len(result))):
result.append(shape2[i])
return tuple(reversed(result))

View File

@@ -3,6 +3,7 @@ import pytest
from concrete.common.data_types.base import BaseDataType
from concrete.common.data_types.dtypes_helpers import (
broadcast_shapes,
find_type_to_hold_both_lossy,
mix_values_determine_holding_dtype,
value_is_encrypted_scalar_integer,
@@ -236,3 +237,63 @@ def test_fail_mix_values_determine_holding_dtype():
DummyValue(Integer(32, True), True),
DummyValue(Integer(32, True), True),
)
@pytest.mark.parametrize(
"shape1,shape2,expected_shape",
[
pytest.param((), (), ()),
pytest.param((3,), (), (3,)),
pytest.param((3,), (1,), (3,)),
pytest.param((3,), (2,), None),
pytest.param((3,), (3,), (3,)),
pytest.param((2, 3), (), (2, 3)),
pytest.param((2, 3), (1,), (2, 3)),
pytest.param((2, 3), (2,), None),
pytest.param((2, 3), (3,), (2, 3)),
pytest.param((2, 3), (1, 1), (2, 3)),
pytest.param((2, 3), (2, 1), (2, 3)),
pytest.param((2, 3), (3, 1), None),
pytest.param((2, 3), (1, 2), None),
pytest.param((2, 3), (2, 2), None),
pytest.param((2, 3), (3, 2), None),
pytest.param((2, 3), (1, 3), (2, 3)),
pytest.param((2, 3), (2, 3), (2, 3)),
pytest.param((2, 3), (3, 3), None),
pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
# Tests cases taken from `numpy`
# https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
pytest.param((1, 2), (2,), (1, 2)),
pytest.param((1, 1), (3, 4), (3, 4)),
pytest.param((1, 3), (3, 1), (3, 3)),
pytest.param((1, 0), (0, 0), (0, 0)),
pytest.param((0, 1), (0, 0), (0, 0)),
pytest.param((1, 0), (0, 1), (0, 0)),
pytest.param((1, 1), (0, 0), (0, 0)),
pytest.param((1, 1), (1, 0), (1, 0)),
pytest.param((1, 1), (0, 1), (0, 1)),
pytest.param((), (0,), (0,)),
pytest.param((0,), (0, 0), (0, 0)),
pytest.param((0,), (0, 1), (0, 0)),
pytest.param((1,), (0, 0), (0, 0)),
pytest.param((2,), (0, 0), (0, 0)),
pytest.param((), (0, 0), (0, 0)),
pytest.param((1, 1), (0,), (1, 0)),
pytest.param((1,), (0, 1), (0, 1)),
pytest.param((1,), (1, 0), (1, 0)),
pytest.param((), (1, 0), (1, 0)),
pytest.param((), (0, 1), (0, 1)),
pytest.param((1,), (3,), (3,)),
pytest.param((2,), (3, 2), (3, 2)),
pytest.param((3,), (4,), None),
pytest.param((2, 3), (2,), None),
pytest.param((1, 3, 4), (2, 3, 3), None),
pytest.param((2,), (2, 3), None),
],
)
def test_broadcast_shapes(shape1, shape2, expected_shape):
"""Test function for `broadcast_shapes` helper"""
assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape

View File

@@ -6,6 +6,7 @@ import networkx as nx
import numpy
import pytest
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import get_printable_graph
@@ -244,6 +245,51 @@ return(%12)
assert get_printable_graph(op_graph, show_data_types=True) == expected
@pytest.mark.parametrize(
"x_shape,y_shape",
[
pytest.param((), ()),
pytest.param((3,), ()),
pytest.param((3,), (1,)),
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((3,), (3,)),
pytest.param((2, 3), ()),
pytest.param((2, 3), (1,)),
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3,)),
pytest.param((2, 3), (1, 1)),
pytest.param((2, 3), (2, 1)),
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 3)),
pytest.param((2, 3), (2, 3)),
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 1, 3), (1, 1, 1)),
pytest.param((2, 1, 3), (1, 4, 1)),
pytest.param((2, 1, 3), (2, 4, 3)),
],
)
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
"""Test numpy tracing broadcasted tensors"""
def f(x, y):
return x + y
op_graph = tracing.trace_numpy_function(
f,
{
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
},
)
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
@pytest.mark.parametrize(
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
[