mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(tracing): enable implicit broadcasting for binary operations
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Callable, Union, cast
|
||||
from typing import Callable, Optional, Tuple, Union, cast
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..values import BaseValue, ClearTensor, EncryptedTensor, TensorValue
|
||||
@@ -212,21 +212,21 @@ def mix_tensor_values_determine_holding_dtype(
|
||||
isinstance(value2, TensorValue), f"Unsupported value2: {value2}, expected TensorValue"
|
||||
)
|
||||
|
||||
resulting_shape = broadcast_shapes(value1.shape, value2.shape)
|
||||
custom_assert(
|
||||
value1.shape == value2.shape,
|
||||
resulting_shape is not None,
|
||||
(
|
||||
f"Tensors have different shapes which is not supported.\n"
|
||||
f"Tensors have incompatible shapes which is not supported.\n"
|
||||
f"value1: {value1.shape}, value2: {value2.shape}"
|
||||
),
|
||||
)
|
||||
assert resulting_shape is not None # this is to make mypy happy
|
||||
|
||||
holding_type = find_type_to_hold_both_lossy(value1.dtype, value2.dtype)
|
||||
shape = value1.shape
|
||||
|
||||
if value1.is_encrypted or value2.is_encrypted:
|
||||
mixed_value = EncryptedTensor(dtype=holding_type, shape=shape)
|
||||
mixed_value = EncryptedTensor(dtype=holding_type, shape=resulting_shape)
|
||||
else:
|
||||
mixed_value = ClearTensor(dtype=holding_type, shape=shape)
|
||||
mixed_value = ClearTensor(dtype=holding_type, shape=resulting_shape)
|
||||
|
||||
return mixed_value
|
||||
|
||||
@@ -344,3 +344,37 @@ def is_data_type_compatible_with(
|
||||
|
||||
combination = find_type_to_hold_both_lossy(dtype, other)
|
||||
return other == combination
|
||||
|
||||
|
||||
def broadcast_shapes(shape1: Tuple[int, ...], shape2: Tuple[int, ...]) -> Optional[Tuple[int, ...]]:
|
||||
"""Broadcast two shapes into a single shape.
|
||||
|
||||
We are mimicing the exact semantics of broadcasting in numpy.
|
||||
You can learn more about it here: https://numpy.org/doc/stable/user/theory.broadcasting.html
|
||||
|
||||
Args:
|
||||
shape1 (Tuple[int, ...]): first shape to broadcast
|
||||
shape2 (Tuple[int, ...]): second shape to broadcast
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[int, ...]]: None if the shapes are not broadcastable else broadcasted shape
|
||||
"""
|
||||
|
||||
result = []
|
||||
for size1, size2 in zip(shape1[::-1], shape2[::-1]):
|
||||
if size1 != size2 and size1 != 1 and size2 != 1 and size1 != 0 and size2 != 0:
|
||||
return None
|
||||
|
||||
if size1 == 0 or size2 == 0:
|
||||
result.append(0)
|
||||
else:
|
||||
result.append(max(size1, size2))
|
||||
|
||||
if len(result) < len(shape1):
|
||||
for i in reversed(range(len(shape1) - len(result))):
|
||||
result.append(shape1[i])
|
||||
elif len(result) < len(shape2):
|
||||
for i in reversed(range(len(shape2) - len(result))):
|
||||
result.append(shape2[i])
|
||||
|
||||
return tuple(reversed(result))
|
||||
|
||||
@@ -3,6 +3,7 @@ import pytest
|
||||
|
||||
from concrete.common.data_types.base import BaseDataType
|
||||
from concrete.common.data_types.dtypes_helpers import (
|
||||
broadcast_shapes,
|
||||
find_type_to_hold_both_lossy,
|
||||
mix_values_determine_holding_dtype,
|
||||
value_is_encrypted_scalar_integer,
|
||||
@@ -236,3 +237,63 @@ def test_fail_mix_values_determine_holding_dtype():
|
||||
DummyValue(Integer(32, True), True),
|
||||
DummyValue(Integer(32, True), True),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"shape1,shape2,expected_shape",
|
||||
[
|
||||
pytest.param((), (), ()),
|
||||
pytest.param((3,), (), (3,)),
|
||||
pytest.param((3,), (1,), (3,)),
|
||||
pytest.param((3,), (2,), None),
|
||||
pytest.param((3,), (3,), (3,)),
|
||||
pytest.param((2, 3), (), (2, 3)),
|
||||
pytest.param((2, 3), (1,), (2, 3)),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((2, 3), (3,), (2, 3)),
|
||||
pytest.param((2, 3), (1, 1), (2, 3)),
|
||||
pytest.param((2, 3), (2, 1), (2, 3)),
|
||||
pytest.param((2, 3), (3, 1), None),
|
||||
pytest.param((2, 3), (1, 2), None),
|
||||
pytest.param((2, 3), (2, 2), None),
|
||||
pytest.param((2, 3), (3, 2), None),
|
||||
pytest.param((2, 3), (1, 3), (2, 3)),
|
||||
pytest.param((2, 3), (2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), None),
|
||||
pytest.param((2, 1, 3), (1, 1, 1), (2, 1, 3)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1), (2, 4, 3)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3), (2, 4, 3)),
|
||||
# Tests cases taken from `numpy`
|
||||
# https://github.com/numpy/numpy/blob/623bc1fae1d47df24e7f1e29321d0c0ba2771ce0/numpy/lib/tests/test_stride_tricks.py#L296-L351
|
||||
pytest.param((1, 2), (2,), (1, 2)),
|
||||
pytest.param((1, 1), (3, 4), (3, 4)),
|
||||
pytest.param((1, 3), (3, 1), (3, 3)),
|
||||
pytest.param((1, 0), (0, 0), (0, 0)),
|
||||
pytest.param((0, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 0), (0, 1), (0, 0)),
|
||||
pytest.param((1, 1), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (1, 0), (1, 0)),
|
||||
pytest.param((1, 1), (0, 1), (0, 1)),
|
||||
pytest.param((), (0,), (0,)),
|
||||
pytest.param((0,), (0, 0), (0, 0)),
|
||||
pytest.param((0,), (0, 1), (0, 0)),
|
||||
pytest.param((1,), (0, 0), (0, 0)),
|
||||
pytest.param((2,), (0, 0), (0, 0)),
|
||||
pytest.param((), (0, 0), (0, 0)),
|
||||
pytest.param((1, 1), (0,), (1, 0)),
|
||||
pytest.param((1,), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (1, 0), (1, 0)),
|
||||
pytest.param((), (1, 0), (1, 0)),
|
||||
pytest.param((), (0, 1), (0, 1)),
|
||||
pytest.param((1,), (3,), (3,)),
|
||||
pytest.param((2,), (3, 2), (3, 2)),
|
||||
pytest.param((3,), (4,), None),
|
||||
pytest.param((2, 3), (2,), None),
|
||||
pytest.param((1, 3, 4), (2, 3, 3), None),
|
||||
pytest.param((2,), (2, 3), None),
|
||||
],
|
||||
)
|
||||
def test_broadcast_shapes(shape1, shape2, expected_shape):
|
||||
"""Test function for `broadcast_shapes` helper"""
|
||||
assert broadcast_shapes(shape1=shape1, shape2=shape2) == expected_shape
|
||||
assert broadcast_shapes(shape1=shape2, shape2=shape1) == expected_shape
|
||||
|
||||
@@ -6,6 +6,7 @@ import networkx as nx
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
@@ -244,6 +245,51 @@ return(%12)
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"x_shape,y_shape",
|
||||
[
|
||||
pytest.param((), ()),
|
||||
pytest.param((3,), ()),
|
||||
pytest.param((3,), (1,)),
|
||||
pytest.param((3,), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((3,), (3,)),
|
||||
pytest.param((2, 3), ()),
|
||||
pytest.param((2, 3), (1,)),
|
||||
pytest.param((2, 3), (2,), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3,)),
|
||||
pytest.param((2, 3), (1, 1)),
|
||||
pytest.param((2, 3), (2, 1)),
|
||||
pytest.param((2, 3), (3, 1), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (2, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (3, 2), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 3), (1, 3)),
|
||||
pytest.param((2, 3), (2, 3)),
|
||||
pytest.param((2, 3), (3, 3), marks=pytest.mark.xfail(raises=AssertionError, strict=True)),
|
||||
pytest.param((2, 1, 3), (1, 1, 1)),
|
||||
pytest.param((2, 1, 3), (1, 4, 1)),
|
||||
pytest.param((2, 1, 3), (2, 4, 3)),
|
||||
],
|
||||
)
|
||||
def test_numpy_tracing_broadcasted_tensors(x_shape, y_shape):
|
||||
"""Test numpy tracing broadcasted tensors"""
|
||||
|
||||
def f(x, y):
|
||||
return x + y
|
||||
|
||||
op_graph = tracing.trace_numpy_function(
|
||||
f,
|
||||
{
|
||||
"x": EncryptedTensor(Integer(3, True), shape=x_shape),
|
||||
"y": EncryptedTensor(Integer(3, True), shape=y_shape),
|
||||
},
|
||||
)
|
||||
|
||||
assert op_graph.input_nodes[0].outputs[0].shape == x_shape
|
||||
assert op_graph.input_nodes[1].outputs[0].shape == y_shape
|
||||
assert op_graph.output_nodes[0].outputs[0].shape == broadcast_shapes(x_shape, y_shape)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function_to_trace,op_graph_expected_output_type,input_and_expected_output_tuples",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user