From 05cacc8744db0d52c02ce45707eab0abb6ae8e76 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Wed, 17 Nov 2021 11:20:44 +0100 Subject: [PATCH] feat(tracing): let's trace k<>x closes #915 --- concrete/common/tracing/base_tracer.py | 197 +++++++++++++++---------- tests/numpy/test_tracing.py | 28 +++- 2 files changed, 139 insertions(+), 86 deletions(-) diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 5314e4e6b..cfd58a43f 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -110,6 +110,89 @@ class BaseTracer(ABC): return output_tracers + def _helper_for_unary_functions(self, op_lambda: Callable, op_name: str) -> "BaseTracer": + """Trace a unary operator which maintains the shape, which will thus be replaced by a TLU. + + Returns: + BaseTracer: The output NPTracer containing the traced function + """ + first_arg_output = self.output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + + out_dtype = first_arg_output.dtype + out_shape = first_arg_output.shape + + generic_function_output_value = TensorValue( + out_dtype, + first_arg_output.is_encrypted, + out_shape, + ) + + traced_computation = GenericFunction( + inputs=[deepcopy(first_arg_output)], + arbitrary_func=op_lambda, + output_value=generic_function_output_value, + op_kind="TLU", + op_name=f"{op_name}", + ) + output_tracer = self.__class__( + [self], + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + + def _helper_for_binary_functions_with_one_cst_input( + self, + lhs: Union["BaseTracer", Any], + rhs: Union["BaseTracer", Any], + op_lambda: Callable, + op_name: str, + output_dtype: Optional[BaseDataType] = None, + ) -> "BaseTracer": + """Trace a binary operator which maintains the shape, when one input is a constant. + + This function is helpful to convert an operation with two inputs, one of which being a + constant, into a TLU, while maintaining the constant somewhere in the graph, eg to simplify + debugging. + + Returns: + BaseTracer: The output NPTracer containing the traced function + """ + if isinstance(lhs, BaseTracer): + if not self._supports_other_operand(rhs): + return NotImplemented + elif isinstance(rhs, BaseTracer): + if not self._supports_other_operand(lhs): + return NotImplemented + + sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]] + + # One of the inputs has to be constant + if not ( + isinstance(sanitized_inputs[0].traced_computation, Constant) + or isinstance(sanitized_inputs[1].traced_computation, Constant) + ): + raise NotImplementedError(f"Can't manage binary operator {op_name}") + + sanitized_input_values = [san_input.output for san_input in sanitized_inputs] + output_value = self._get_mix_values_func()(*sanitized_input_values) + if output_dtype is not None: + output_value.dtype = deepcopy(output_dtype) + + traced_computation = GenericFunction( + inputs=sanitized_input_values, + arbitrary_func=op_lambda, + output_value=output_value, + op_kind="TLU", + op_name=op_name, + ) + + result_tracer = self.__class__(sanitized_inputs, traced_computation, 0) + + return result_tracer + def __add__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): return NotImplemented @@ -135,11 +218,31 @@ class BaseTracer(ABC): # a reference to the same object return 0 + self - def __lshift__(self, shift) -> "BaseTracer": - return 2 ** shift * self + def _lshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer": + return self._helper_for_binary_functions_with_one_cst_input( + lhs, rhs, lambda x, y: x << y, "lshift" + ) - def __rshift__(self, shift) -> "BaseTracer": - return self // 2 ** shift + def __lshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x << shift + return self._lshift(self, other) + + def __rlshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # cst << x + return self._lshift(other, self) + + def _rshift(self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any]) -> "BaseTracer": + return self._helper_for_binary_functions_with_one_cst_input( + lhs, rhs, lambda x, y: x >> y, "rshift" + ) + + def __rshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # x >> shift + return self._rshift(self, other) + + def __rrshift__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": + # cst >> x + return self._rshift(other, self) def __sub__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": if not self._supports_other_operand(other): @@ -182,94 +285,22 @@ class BaseTracer(ABC): # some changes __rmul__ = __mul__ - def unary_ndarray_op(self, op_lambda, op_string: str): - """Trace an operator which maintains the shape, which will thus be replaced by a TLU. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - first_arg_output = self.output - assert_true(isinstance(first_arg_output, TensorValue)) - first_arg_output = cast(TensorValue, first_arg_output) - - out_dtype = first_arg_output.dtype - out_shape = first_arg_output.shape - - generic_function_output_value = TensorValue( - out_dtype, - first_arg_output.is_encrypted, - out_shape, - ) - - traced_computation = GenericFunction( - inputs=[deepcopy(first_arg_output)], - arbitrary_func=op_lambda, - output_value=generic_function_output_value, - op_kind="TLU", - op_name=f"{op_string}", - ) - output_tracer = self.__class__( - [self], - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - def __abs__(self): - return self.unary_ndarray_op(lambda x: x.__abs__(), "__abs__") + return self._helper_for_unary_functions(lambda x: x.__abs__(), "__abs__") def __invert__(self): - return self.unary_ndarray_op(lambda x: x.__invert__(), "__invert__") + return self._helper_for_unary_functions(lambda x: x.__invert__(), "__invert__") def __getitem__(self, item): traced_computation = IndexConstant(self.output, item) return self.__class__([self], traced_computation, 0) - def _div_common( - self, - lhs: Union["BaseTracer", Any], - rhs: Union["BaseTracer", Any], - div_op: Callable, - op_name: str, - output_dtype: Optional[BaseDataType] = None, - ) -> "BaseTracer": - if isinstance(lhs, BaseTracer): - if not self._supports_other_operand(rhs): - return NotImplemented - elif isinstance(rhs, BaseTracer): - if not self._supports_other_operand(lhs): - return NotImplemented - - sanitized_inputs = [self._sanitize(inp) for inp in [lhs, rhs]] - - # One of the inputs has to be constant - if not ( - isinstance(sanitized_inputs[0].traced_computation, Constant) - or isinstance(sanitized_inputs[1].traced_computation, Constant) - ): - raise NotImplementedError(f"Can't manage binary operator {op_name}") - - sanitized_input_values = [san_input.output for san_input in sanitized_inputs] - output_value = self._get_mix_values_func()(*sanitized_input_values) - if output_dtype is not None: - output_value.dtype = deepcopy(output_dtype) - - traced_computation = GenericFunction( - inputs=sanitized_input_values, - arbitrary_func=div_op, - output_value=output_value, - op_kind="TLU", - op_name=op_name, - ) - - result_tracer = self.__class__(sanitized_inputs, traced_computation, 0) - - return result_tracer - def _truediv( self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] ) -> "BaseTracer": - return self._div_common(lhs, rhs, lambda x, y: x / y, "truediv", Float(64)) + return self._helper_for_binary_functions_with_one_cst_input( + lhs, rhs, lambda x, y: x / y, "truediv", Float(64) + ) def __truediv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": return self._truediv(self, other) @@ -280,7 +311,9 @@ class BaseTracer(ABC): def _floordiv( self, lhs: Union["BaseTracer", Any], rhs: Union["BaseTracer", Any] ) -> "BaseTracer": - return self._div_common(lhs, rhs, lambda x, y: x // y, "floordiv") + return self._helper_for_binary_functions_with_one_cst_input( + lhs, rhs, lambda x, y: x // y, "floordiv" + ) def __floordiv__(self, other: Union["BaseTracer", Any]) -> "BaseTracer": return self._floordiv(self, other) diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 7d5786252..9001555ed 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -701,8 +701,8 @@ def test_tracing_numpy_calls( [ ( EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15), - numpy.arange(15) * 8, + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) * 8, ) ], ), @@ -711,8 +711,28 @@ def test_tracing_numpy_calls( [ ( EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), - numpy.arange(15), - numpy.arange(15) // 2, + numpy.arange(15).reshape(3, 5), + numpy.arange(15).reshape(3, 5) // 2, + ) + ], + ), + pytest.param( + lambda x: 2 << x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5) % 8, + 2 << (numpy.arange(15).reshape(3, 5) % 8), + ) + ], + ), + pytest.param( + lambda x: 256 >> x, + [ + ( + EncryptedTensor(Integer(32, is_signed=True), shape=(3, 5)), + numpy.arange(15).reshape(3, 5) % 8, + 256 >> (numpy.arange(15).reshape(3, 5) % 8), ) ], ),