diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index f132466b6..5314e4e6b 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union, cast from ..data_types import Float from ..data_types.base import BaseDataType @@ -17,7 +17,7 @@ from ..representation.intermediate import ( Mul, Sub, ) -from ..values import BaseValue +from ..values import BaseValue, TensorValue class BaseTracer(ABC): @@ -130,6 +130,17 @@ class BaseTracer(ABC): def __neg__(self) -> "BaseTracer": return 0 - self + def __pos__(self) -> "BaseTracer": + # Remark that we don't want to return 'self' since we want the result to be a copy, ie not + # a reference to the same object + return 0 + self + + def __lshift__(self, shift) -> "BaseTracer": + return 2 ** shift * self + + def __rshift__(self, shift) -> "BaseTracer": + return self // 2 ** shift + def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -171,6 +182,45 @@ class BaseTracer(ABC): # some changes __rmul__ = __mul__ + def unary_ndarray_op(self, op_lambda, op_string: str): + """Trace an operator which maintains the shape, which will thus be replaced by a TLU. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + first_arg_output = self.output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + + out_dtype = first_arg_output.dtype + out_shape = first_arg_output.shape + + generic_function_output_value = TensorValue( + out_dtype, + first_arg_output.is_encrypted, + out_shape, + ) + + traced_computation = GenericFunction( + inputs=[deepcopy(first_arg_output)], + arbitrary_func=op_lambda, + output_value=generic_function_output_value, + op_kind="TLU", + op_name=f"{op_string}", + ) + output_tracer = self.__class__( + [self], + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + + def __abs__(self): + return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__") + + def __invert__(self): + return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__") + def __getitem__(self, item): traced_computation = IndexConstant(self.output, item) return self.__class__([self], traced_computation, 0) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index c3cdb9843..ca77827f7 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -427,7 +427,7 @@ class NPTracer(BaseTracer): ) traced_computation = GenericFunction( - inputs=[first_arg_output], + inputs=[deepcopy(first_arg_output)], arbitrary_func=numpy.reshape, output_value=generic_function_output_value, op_kind="Memory", @@ -442,6 +442,45 @@ class NPTracer(BaseTracer): ) return output_tracer + def flatten(self, *args: "NPTracer", **kwargs) -> "NPTracer": + """Trace x.flatten. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + assert_true((num_args := len(args)) == 0, f"flatten expect 0 input got {num_args}") + + first_arg_output = self.output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + + flatten_is_fusable = first_arg_output.ndim == 1 + + out_dtype = first_arg_output.dtype + out_shape = (1,) if first_arg_output.is_scalar else (numpy.product(first_arg_output.shape),) + + generic_function_output_value = TensorValue( + out_dtype, + first_arg_output.is_encrypted, + out_shape, + ) + + traced_computation = GenericFunction( + inputs=[deepcopy(first_arg_output)], + arbitrary_func=lambda x: x.flatten(), + output_value=generic_function_output_value, + op_kind="Memory", + op_kwargs=deepcopy(kwargs), + op_name="flatten", + op_attributes={"fusable": flatten_is_fusable}, + ) + output_tracer = self.__class__( + [self], + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + def __getitem__(self, item): if isinstance(item, tuple): item = tuple(process_indexing_element(indexing_element) for indexing_element in item) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index c6373a405..5b248faeb 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -1,5 +1,7 @@ """Test file for numpy tracing""" +# pylint: disable=too-many-lines + import inspect from copy import deepcopy @@ -691,7 +693,11 @@ def subtest_tracing_calls( node_results = op_graph.evaluate({0: input_}) evaluated_output = node_results[output_node] assert isinstance(evaluated_output, type(expected_output)), type(evaluated_output) - assert numpy.array_equal(expected_output, evaluated_output) + if not numpy.array_equal(expected_output, evaluated_output): + print("Wrong result") + print(f"Expected: {expected_output}") + print(f"Got : {evaluated_output}") + raise AssertionError @pytest.mark.parametrize( @@ -831,6 +837,76 @@ def test_tracing_numpy_calls( ], marks=pytest.mark.xfail(strict=True, raises=AssertionError), ), + pytest.param( + lambda x: x.flatten(), + [ + ( + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15), + ) + ], + ), + pytest.param( + lambda x: abs(x), + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5), + ) + ], + ), + pytest.param( + lambda x: +x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5), + ) + ], + ), + pytest.param( + lambda x: -x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + (numpy.arange(15).reshape(3, 5)) * (-1), + ) + ], + ), + pytest.param( + lambda x: ~x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5).__invert__(), + ) + ], + ), + pytest.param( + lambda x: x << 3, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15), + numpy.arange(15) * 8, + ) + ], + ), + pytest.param( + lambda x: x >> 1, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15), + numpy.arange(15) // 2, + ) + ], + ), ], ) def test_tracing_ndarray_calls( @@ -858,3 +934,6 @@ def test_errors_with_generic_function(lambda_f, params): tracing.trace_numpy_function(lambda_f, params) assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value) + + +# pylint: enable=too-many-lines