mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
feat(tracing): enable implicit broadcasting for binary operations
This commit is contained in:
@@ -3,6 +3,7 @@ import pytest
|
||||
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.data_types.dtypes_helpers import (
|
||||
broadcast_shapes,
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_values_determine_holding_dtype,
|
||||
value_is_encrypted_scalar_integer,
|
||||
@@ -236,3 +237,63 @@ def test_fail_mix_values_determine_holding_dtype():
|
||||
DummyValue(Integer(32, True), True),
|
||||
DummyValue(Integer(32, True), True),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape1,shape2,expected_shape",
|
||||
[
|
||||
pytest.param((), (), ()),
|
||||
pytest.param((3,), (), (3,)),
|
||||
pytest.param((3,), (1,), (3,)),
|
||||
pytest.param((3,), (2,), None),
|
||||
pytest.param((3,), (3,), (3,)),
|
||||
pytest.param((2, 3), (), (2, 3)),
|
||||
pytest.param((2, 3), (1,), (2, 3)),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((2, 3), (3,), (2, 3)),
|
||||
pytest.param((2, 3), (1, 1), (2, 3)),
|
||||
pytest.param((2, 3), (2, 1), (2, 3)),
|
||||
pytest.param((2, 3), (3, 1), None),
|
||||
pytest.param((2, 3), (1, 2), None),
|
||||
pytest.param((2, 3), (2, 2), None),
|
||||
pytest.param((2, 3), (3, 2), None),
|
||||
pytest.param((2, 3), (1, 3), (2, 3)),
|
||||
pytest.param((2, 3), (2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), None),
|
||||
pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
|
||||
# Tests cases taken from `numpy`
|
||||
# https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
|
||||
pytest.param((1, 2), (2,), (1, 2)),
|
||||
pytest.param((1, 1), (3, 4), (3, 4)),
|
||||
pytest.param((1, 3), (3, 1), (3, 3)),
|
||||
pytest.param((1, 0), (0, 0), (0, 0)),
|
||||
pytest.param((0, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 0), (0, 1), (0, 0)),
|
||||
pytest.param((1, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (1, 0), (1, 0)),
|
||||
pytest.param((1, 1), (0, 1), (0, 1)),
|
||||
pytest.param((), (0,), (0,)),
|
||||
pytest.param((0,), (0, 0), (0, 0)),
|
||||
pytest.param((0,), (0, 1), (0, 0)),
|
||||
pytest.param((1,), (0, 0), (0, 0)),
|
||||
pytest.param((2,), (0, 0), (0, 0)),
|
||||
pytest.param((), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (0,), (1, 0)),
|
||||
pytest.param((1,), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (1, 0), (1, 0)),
|
||||
pytest.param((), (1, 0), (1, 0)),
|
||||
pytest.param((), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (3,), (3,)),
|
||||
pytest.param((2,), (3, 2), (3, 2)),
|
||||
pytest.param((3,), (4,), None),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((1, 3, 4), (2, 3, 3), None),
|
||||
pytest.param((2,), (2, 3), None),
|
||||
],
|
||||
)
|
||||
def test_broadcast_shapes(shape1, shape2, expected_shape):
|
||||
"""Test function for `broadcast_shapes` helper"""
|
||||
assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
|
||||
assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape
|
||||
|
||||
@@ -6,6 +6,7 @@ import networkx as nx
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
@@ -244,6 +245,51 @@ return(%12)
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"x_shape,y_shape",
|
||||
[
|
||||
pytest.param((), ()),
|
||||
pytest.param((3,), ()),
|
||||
pytest.param((3,), (1,)),
|
||||
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((3,), (3,)),
|
||||
pytest.param((2, 3), ()),
|
||||
pytest.param((2, 3), (1,)),
|
||||
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3,)),
|
||||
pytest.param((2, 3), (1, 1)),
|
||||
pytest.param((2, 3), (2, 1)),
|
||||
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 3)),
|
||||
pytest.param((2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 1, 3), (1, 1, 1)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3)),
|
||||
],
|
||||
)
|
||||
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
|
||||
"""Test numpy tracing broadcasted tensors"""
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
f,
|
||||
{
|
||||
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
|
||||
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
|
||||
},
|
||||
)
|
||||
|
||||
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
|
||||
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
|
||||
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user