feat(tracing): enable implicit broadcasting for binary operations

This commit is contained in:
Umut
2021-10-05 17:13:47 +03:00
parent ceb23f93d5
commit 83ea485fe1
3 changed files with 148 additions and 7 deletions

View File

@@ -3,6 +3,7 @@ import pytest
from concrete.common.data_types.base import BaseDataType
from concrete.common.data_types.dtypes_helpers import (
broadcast_shapes,
find_type_to_hold_both_lossy,
mix_values_determine_holding_dtype,
value_is_encrypted_scalar_integer,
@@ -236,3 +237,63 @@ def test_fail_mix_values_determine_holding_dtype():
DummyValue(Integer(32, True), True),
DummyValue(Integer(32, True), True),
)
@pytest.mark.parametrize(
"shape1,shape2,expected_shape",
[
pytest.param((), (), ()),
pytest.param((3,), (), (3,)),
pytest.param((3,), (1,), (3,)),
pytest.param((3,), (2,), None),
pytest.param((3,), (3,), (3,)),
pytest.param((2, 3), (), (2, 3)),
pytest.param((2, 3), (1,), (2, 3)),
pytest.param((2, 3), (2,), None),
pytest.param((2, 3), (3,), (2, 3)),
pytest.param((2, 3), (1, 1), (2, 3)),
pytest.param((2, 3), (2, 1), (2, 3)),
pytest.param((2, 3), (3, 1), None),
pytest.param((2, 3), (1, 2), None),
pytest.param((2, 3), (2, 2), None),
pytest.param((2, 3), (3, 2), None),
pytest.param((2, 3), (1, 3), (2, 3)),
pytest.param((2, 3), (2, 3), (2, 3)),
pytest.param((2, 3), (3, 3), None),
pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
# Tests cases taken from `numpy`
# https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
pytest.param((1, 2), (2,), (1, 2)),
pytest.param((1, 1), (3, 4), (3, 4)),
pytest.param((1, 3), (3, 1), (3, 3)),
pytest.param((1, 0), (0, 0), (0, 0)),
pytest.param((0, 1), (0, 0), (0, 0)),
pytest.param((1, 0), (0, 1), (0, 0)),
pytest.param((1, 1), (0, 0), (0, 0)),
pytest.param((1, 1), (1, 0), (1, 0)),
pytest.param((1, 1), (0, 1), (0, 1)),
pytest.param((), (0,), (0,)),
pytest.param((0,), (0, 0), (0, 0)),
pytest.param((0,), (0, 1), (0, 0)),
pytest.param((1,), (0, 0), (0, 0)),
pytest.param((2,), (0, 0), (0, 0)),
pytest.param((), (0, 0), (0, 0)),
pytest.param((1, 1), (0,), (1, 0)),
pytest.param((1,), (0, 1), (0, 1)),
pytest.param((1,), (1, 0), (1, 0)),
pytest.param((), (1, 0), (1, 0)),
pytest.param((), (0, 1), (0, 1)),
pytest.param((1,), (3,), (3,)),
pytest.param((2,), (3, 2), (3, 2)),
pytest.param((3,), (4,), None),
pytest.param((2, 3), (2,), None),
pytest.param((1, 3, 4), (2, 3, 3), None),
pytest.param((2,), (2, 3), None),
],
)
def test_broadcast_shapes(shape1, shape2, expected_shape):
"""Test function for `broadcast_shapes` helper"""
assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape

View File

@@ -6,6 +6,7 @@ import networkx as nx
import numpy
import pytest
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import get_printable_graph
@@ -244,6 +245,51 @@ return(%12)
assert get_printable_graph(op_graph, show_data_types=True) == expected
@pytest.mark.parametrize(
"x_shape,y_shape",
[
pytest.param((), ()),
pytest.param((3,), ()),
pytest.param((3,), (1,)),
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((3,), (3,)),
pytest.param((2, 3), ()),
pytest.param((2, 3), (1,)),
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3,)),
pytest.param((2, 3), (1, 1)),
pytest.param((2, 3), (2, 1)),
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 3), (1, 3)),
pytest.param((2, 3), (2, 3)),
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
pytest.param((2, 1, 3), (1, 1, 1)),
pytest.param((2, 1, 3), (1, 4, 1)),
pytest.param((2, 1, 3), (2, 4, 3)),
],
)
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
"""Test numpy tracing broadcasted tensors"""
def f(x, y):
return x + y
op_graph = tracing.trace_numpy_function(
f,
{
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
},
)
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
@pytest.mark.parametrize(
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
[