Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-08 19:44:57 -05:00
feat(tracing): add support for ndarray functions

- ndarray.flatten
- ndarray.__abs__
- ndarray.__neg__, __pos__ and __invert__
- ndarray.__rshift__ and __lshift__

refs #751
refs #218
commit a2a61a079f
parent baa8a97e98
committed by Benoit Chevallier
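As a quick orientation (an illustration, not part of the commit), the plain-numpy semantics that the traced operators below mirror:

    import numpy as np

    x = np.arange(15).reshape(3, 5)

    assert np.array_equal(x.flatten(), np.arange(15))  # ndarray.flatten
    assert np.array_equal(abs(-x), x)                  # __abs__ / __neg__
    assert np.array_equal(+x, x) and (+x is not x)     # __pos__ returns a copy
    assert np.array_equal(x << 3, x * 2 ** 3)          # __lshift__ as multiplication
    assert np.array_equal(x >> 1, x // 2 ** 1)         # __rshift__ as floor division
    assert np.array_equal(~x, -x - 1)                  # __invert__ on signed integers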
@@ -2,7 +2,7 @@
 from abc import ABC, abstractmethod
 from copy import deepcopy
-from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast
 
 from ..data_types import Float
 from ..data_types.base import BaseDataType
@@ -17,7 +17,7 @@ from ..representation.intermediate import (
     Mul,
     Sub,
 )
-from ..values import BaseValue
+from ..values import BaseValue, TensorValue
 
 
 class BaseTracer(ABC):
@@ -130,6 +130,17 @@ class BaseTracer(ABC):
     def __neg__(self) -> "BaseTracer":
         return 0 - self
 
+    def __pos__(self) -> "BaseTracer":
+        # Remark that we don't want to return 'self' since we want the result to be a copy, ie not
+        # a reference to the same object
+        return 0 + self
+
+    def __lshift__(self, shift) -> "BaseTracer":
+        return 2 ** shift * self
+
+    def __rshift__(self, shift) -> "BaseTracer":
+        return self // 2 ** shift
+
     def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer":
         if not self._supports_other_operand(other):
             return NotImplemented
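A note on the lowering above (commentary, not part of the commit): the tracer has no dedicated shift node, so shifts are rewritten with the integer identities x << s == x * 2**s and x >> s == x // 2**s. These hold for Python and numpy integers even when x is negative, since >> is an arithmetic shift and // floors. A minimal check:

    import numpy as np

    for x in (np.arange(16), np.arange(-8, 8)):  # nonnegative and signed inputs
        for s in range(4):
            assert np.array_equal(x << s, x * 2 ** s)
            assert np.array_equal(x >> s, x // 2 ** s)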
@@ -171,6 +182,45 @@
     # some changes
     __rmul__ = __mul__
 
+    def unary_ndarray_op(self, op_lambda, op_string: str):
+        """Trace an operator which maintains the shape, which will thus be replaced by a TLU.
+
+        Returns:
+            NPTracer: The output NPTracer containing the traced function
+        """
+        first_arg_output = self.output
+        assert_true(isinstance(first_arg_output, TensorValue))
+        first_arg_output = cast(TensorValue, first_arg_output)
+
+        out_dtype = first_arg_output.dtype
+        out_shape = first_arg_output.shape
+
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            first_arg_output.is_encrypted,
+            out_shape,
+        )
+
+        traced_computation = GenericFunction(
+            inputs=[deepcopy(first_arg_output)],
+            arbitrary_func=op_lambda,
+            output_value=generic_function_output_value,
+            op_kind="TLU",
+            op_name=f"{op_string}",
+        )
+        output_tracer = self.__class__(
+            [self],
+            traced_computation=traced_computation,
+            output_idx=0,
+        )
+        return output_tracer
+
+    def __abs__(self):
+        return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__")
+
+    def __invert__(self):
+        return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__")
+
     def __getitem__(self, item):
         traced_computation = IndexConstant(self.output, item)
         return self.__class__([self], traced_computation, 0)
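op_kind="TLU" marks these unary operators for compilation as table lookups: a shape-preserving elementwise function over a bounded integer domain can be precomputed into a table once, then evaluated by indexing. A rough illustration of the idea with a hypothetical helper (not the library's implementation):

    import numpy as np

    def apply_as_tlu(op, x, lo, hi):
        """Evaluate an elementwise op on integers in [lo, hi] via a lookup table."""
        table = np.array([op(v) for v in range(lo, hi + 1)])  # built once, offline
        return table[x - lo]                                  # pure indexing at runtime

    x = np.arange(-5, 6)
    assert np.array_equal(apply_as_tlu(abs, x, -5, 5), np.abs(x))
    assert np.array_equal(apply_as_tlu(lambda v: ~v, x, -5, 5), ~x)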
@@ -427,7 +427,7 @@ class NPTracer(BaseTracer):
         )
 
         traced_computation = GenericFunction(
-            inputs=[first_arg_output],
+            inputs=[deepcopy(first_arg_output)],
            arbitrary_func=numpy.reshape,
            output_value=generic_function_output_value,
            op_kind="Memory",
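The one-line change above is presumably defensive: without deepcopy, the GenericFunction node and the tracer would share one value object, so a later mutation through either reference would affect both. In plain Python terms:

    from copy import deepcopy

    shared = {"shape": (3, 5)}
    node_inputs = [shared]                     # aliased: one object, two names
    shared["shape"] = (5, 3)
    assert node_inputs[0]["shape"] == (5, 3)   # the node saw the mutation

    shared = {"shape": (3, 5)}
    node_inputs = [deepcopy(shared)]           # independent snapshot
    shared["shape"] = (5, 3)
    assert node_inputs[0]["shape"] == (3, 5)   # the node is unaffected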
@@ -442,6 +442,45 @@ class NPTracer(BaseTracer):
         )
         return output_tracer
 
+    def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer":
+        """Trace x.flatten.
+
+        Returns:
+            NPTracer: The output NPTracer containing the traced function
+        """
+        assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}")
+
+        first_arg_output = self.output
+        assert_true(isinstance(first_arg_output, TensorValue))
+        first_arg_output = cast(TensorValue, first_arg_output)
+
+        flatten_is_fusable = first_arg_output.ndim == 1
+
+        out_dtype = first_arg_output.dtype
+        out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),)
+
+        generic_function_output_value = TensorValue(
+            out_dtype,
+            first_arg_output.is_encrypted,
+            out_shape,
+        )
+
+        traced_computation = GenericFunction(
+            inputs=[deepcopy(first_arg_output)],
+            arbitrary_func=lambda x: x.flatten(),
+            output_value=generic_function_output_value,
+            op_kind="Memory",
+            op_kwargs=deepcopy(kwargs),
+            op_name="flatten",
+            op_attributes={"fusable": flatten_is_fusable},
+        )
+        output_tracer = self.__class__(
+            [self],
+            traced_computation=traced_computation,
+            output_idx=0,
+        )
+        return output_tracer
+
     def __getitem__(self, item):
         if isinstance(item, tuple):
             item = tuple(process_indexing_element(indexing_element) for indexing_element in item)
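Two details worth flagging in flatten (commentary, not part of the commit): the output shape collapses to one axis, (1,) for scalars and the element count otherwise, and the node is only marked fusable when the input is already 1D, where flattening is just a copy. Note also that numpy.product is a deprecated alias of numpy.prod in recent NumPy releases. A small sketch of the shape rule:

    import numpy as np

    def flatten_out_shape(shape):
        # mirrors the rule in the diff: scalars flatten to shape (1,)
        return (1,) if shape == () else (int(np.prod(shape)),)

    assert flatten_out_shape(()) == (1,)
    assert flatten_out_shape((3, 5)) == (15,)
    assert flatten_out_shape((7,)) == (7,)  # 1D input: the fusable case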
@@ -1,5 +1,7 @@
 """Test file for numpy tracing"""
 
+# pylint: disable=too-many-lines
+
 import inspect
 from copy import deepcopy
@@ -691,7 +693,11 @@ def subtest_tracing_calls(
     node_results = op_graph.evaluate({0: input_})
     evaluated_output = node_results[output_node]
     assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output)
-    assert numpy.array_equal(expected_output, evaluated_output)
+    if not numpy.array_equal(expected_output, evaluated_output):
+        print("Wrong result")
+        print(f"Expected: {expected_output}")
+        print(f"Got : {evaluated_output}")
+        raise AssertionError
 
 
 @pytest.mark.parametrize(
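Side note (not part of the commit): numpy.testing.assert_array_equal gives similar diagnostics out of the box, raising AssertionError with both arrays and the mismatch locations rendered in the message:

    import numpy

    numpy.testing.assert_array_equal(
        numpy.array([1, 2, 3]),
        numpy.array([1, 2, 4]),  # raises with a readable element-wise diff
    )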
@@ -831,6 +837,76 @@ def test_tracing_numpy_calls(
             ],
            marks=pytest.mark.xfail(strict=True, raises=AssertionError),
        ),
+        pytest.param(
+            lambda x: x.flatten(),
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(15),
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: abs(x),
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(15).reshape(3, 5),
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: +x,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(15).reshape(3, 5),
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: -x,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    (numpy.arange(15).reshape(3, 5)) * (-1),
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: ~x,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15).reshape(3, 5),
+                    numpy.arange(15).reshape(3, 5).__invert__(),
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: x << 3,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15),
+                    numpy.arange(15) * 8,
+                )
+            ],
+        ),
+        pytest.param(
+            lambda x: x >> 1,
+            [
+                (
+                    EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)),
+                    numpy.arange(15),
+                    numpy.arange(15) // 2,
+                )
+            ],
+        ),
     ],
 )
 def test_tracing_ndarray_calls(
@@ -858,3 +934,6 @@ def test_errors_with_generic_function(lambda_f, params):
         tracing.trace_numpy_function(lambda_f, params)
 
     assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)
+
+
+# pylint: enable=too-many-lines