From f1ed07d5809aaed2f470e5587e56bd7b035643f8 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 22 Nov 2021 18:14:23 +0100 Subject: [PATCH] refactor(tracing): add generic tracing functions for generic functions - remove nearly duplicated code when tracing GenericFunction for unary and binary operators refs #965 --- concrete/numpy/tracing.py | 89 ++++++------------- .../common/optimization/test_float_fusing.py | 5 +- 2 files changed, 32 insertions(+), 62 deletions(-) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 7c4144413..b9968ffca 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -153,79 +153,46 @@ class NPTracer(BaseTracer): return self.__class__([], NPConstant(constant_data), 0) @classmethod - def _unary_operator( - cls, unary_operator, unary_operator_string, *input_tracers: "NPTracer", **kwargs + def _np_operator( + cls, + numpy_operator, + numpy_operator_string, + numpy_operator_nin, + *input_tracers: "NPTracer", + **kwargs, ) -> "NPTracer": - """Trace an unary operator. + """Trace a numpy operator. Returns: NPTracer: The output NPTracer containing the traced function """ - assert_true(len(input_tracers) == 1) + assert_true(len(input_tracers) == numpy_operator_nin) + common_output_dtypes_and_shapes = ( get_numpy_function_output_dtype_and_shape_from_input_tracers( - unary_operator, + numpy_operator, *input_tracers, ) ) assert_true(len(common_output_dtypes_and_shapes) == 1) - output_dtype, output_shape = common_output_dtypes_and_shapes[0] - - generic_function_output_value = TensorValue( - output_dtype, input_tracers[0].output.is_encrypted, output_shape + variable_input_indices = [ + idx + for idx, pred in enumerate(input_tracers) + if not isinstance(pred.traced_computation, Constant) + ] + assert_true( + (non_constant_pred_count := len(variable_input_indices)) == 1, + f"Can only have 1 non constant predecessor in {cls._np_operator.__name__}, " + f"got {non_constant_pred_count} for operator {numpy_operator}", ) - traced_computation = GenericFunction( - inputs=[input_tracers[0].output], - arbitrary_func=unary_operator, - output_value=generic_function_output_value, - op_kind="TLU", - op_kwargs=deepcopy(kwargs), - op_name=unary_operator_string, - ) - output_tracer = cls( - input_tracers, - traced_computation=traced_computation, - output_idx=0, - ) - return output_tracer - - @classmethod - def _binary_operator( - cls, binary_operator, binary_operator_string, *input_tracers: "NPTracer", **kwargs - ) -> "NPTracer": - """Trace a binary operator, supposing one of the input is a constant. - - If no input is a constant, raises an error. - - Returns: - NPTracer: The output NPTracer containing the traced function - """ - assert_true(len(input_tracers) == 2) - - # One of the inputs has to be constant - if isinstance(input_tracers[0].traced_computation, Constant): - in_which_input_is_constant = 0 - elif isinstance(input_tracers[1].traced_computation, Constant): - in_which_input_is_constant = 1 - else: - raise NotImplementedError(f"Can't manage binary operator {binary_operator}") - - in_which_input_is_variable = 1 - in_which_input_is_constant - common_output_dtypes_and_shapes = ( - get_numpy_function_output_dtype_and_shape_from_input_tracers( - binary_operator, - *input_tracers, - ) - ) - assert_true(len(common_output_dtypes_and_shapes) == 1) - + variable_input_idx = variable_input_indices[0] output_dtype, output_shape = common_output_dtypes_and_shapes[0] generic_function_output_value = TensorValue( output_dtype, - input_tracers[in_which_input_is_variable].output.is_encrypted, + input_tracers[variable_input_idx].output.is_encrypted, output_shape, ) @@ -233,11 +200,11 @@ class NPTracer(BaseTracer): traced_computation = GenericFunction( inputs=[input_tracer.output for input_tracer in input_tracers], - arbitrary_func=binary_operator, + arbitrary_func=numpy_operator, output_value=generic_function_output_value, op_kind="TLU", op_kwargs=op_kwargs, - op_name=binary_operator_string, + op_name=numpy_operator_string, ) output_tracer = cls( input_tracers, @@ -607,8 +574,8 @@ def _get_unary_fun(function: numpy.ufunc): # We have to access this method to be able to build NPTracer.UFUNC_ROUTING # dynamically # pylint: disable=protected-access - return lambda *input_tracers, **kwargs: NPTracer._unary_operator( - function, f"{function.__name__}", *input_tracers, **kwargs + return lambda *input_tracers, **kwargs: NPTracer._np_operator( + function, f"{function.__name__}", 1, *input_tracers, **kwargs ) # pylint: enable=protected-access @@ -619,8 +586,8 @@ def _get_binary_fun(function: numpy.ufunc): # We have to access this method to be able to build NPTracer.UFUNC_ROUTING # dynamically # pylint: disable=protected-access - return lambda *input_tracers, **kwargs: NPTracer._binary_operator( - function, f"{function.__name__}", *input_tracers, **kwargs + return lambda *input_tracers, **kwargs: NPTracer._np_operator( + function, f"{function.__name__}", 2, *input_tracers, **kwargs ) # pylint: enable=protected-access diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 9d491c949..0153deaa4 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -687,7 +687,10 @@ def subtest_fuse_float_binary_operations_dont_support_two_variables(fun, tensor_ params_names = signature(function_to_trace).parameters.keys() - with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"): + with pytest.raises( + AssertionError, + match=r"Can only have 1 non constant predecessor in _np_operator, got 2 for operator", + ): trace_numpy_function( function_to_trace, {