diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 97246d612..fbdc330a9 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -45,7 +45,9 @@ class NPTracer(BaseTracer): (len(kwargs) == 0), f"**kwargs are currently not supported for numpy ufuncs, ufunc: {ufunc}", ) - return tracing_func(*input_tracers, **kwargs) + # Create constant tracers when needed + sanitized_input_tracers = [self._sanitize(inp) for inp in input_tracers] + return tracing_func(*sanitized_input_tracers, **kwargs) raise NotImplementedError("Only __call__ method is supported currently") def __array_function__(self, func, _types, args, kwargs): @@ -163,6 +165,61 @@ class NPTracer(BaseTracer): ) return output_tracer + @classmethod + def _binary_operator( + cls, binary_operator, binary_operator_string, *input_tracers: "NPTracer", **kwargs + ) -> "NPTracer": + """Trace a binary operator, supposing one of the input is a constant. + + If no input is a constant, raises an error. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + custom_assert(len(input_tracers) == 2) + + # One of the inputs has to be constant + if isinstance(input_tracers[0].traced_computation, Constant): + in_which_input_is_constant = 0 + baked_constant = deepcopy(input_tracers[0].traced_computation.constant_data) + elif isinstance(input_tracers[1].traced_computation, Constant): + in_which_input_is_constant = 1 + baked_constant = deepcopy(input_tracers[1].traced_computation.constant_data) + else: + raise NotImplementedError(f"Can't manage binary operator {binary_operator}") + + in_which_input_is_variable = 1 - in_which_input_is_constant + + if in_which_input_is_constant == 0: + + def arbitrary_func(x, baked_constant, **kwargs): + return binary_operator(baked_constant, x, **kwargs) + + else: + + def arbitrary_func(x, baked_constant, **kwargs): + return binary_operator(x, baked_constant, **kwargs) + + common_output_dtypes = cls._manage_dtypes(binary_operator, *input_tracers) + custom_assert(len(common_output_dtypes) == 1) + + op_kwargs = deepcopy(kwargs) + op_kwargs["baked_constant"] = baked_constant + + traced_computation = ArbitraryFunction( + input_base_value=input_tracers[in_which_input_is_variable].output, + arbitrary_func=arbitrary_func, + output_dtype=common_output_dtypes[0], + op_kwargs=op_kwargs, + op_name=binary_operator_string, + ) + output_tracer = cls( + (input_tracers[in_which_input_is_variable],), + traced_computation=traced_computation, + output_index=0, + ) + return output_tracer + def dot(self, other_tracer: "NPTracer", **_kwargs) -> "NPTracer": """Trace numpy.dot. @@ -188,8 +245,9 @@ class NPTracer(BaseTracer): ) return output_tracer + # Supported functions are either univariate or bivariate for which one of the two + # sources is a constant LIST_OF_SUPPORTED_UFUNC: List[numpy.ufunc] = [ - # The commented functions are functions require more than a single argument numpy.absolute, # numpy.add, numpy.arccos, @@ -197,7 +255,7 @@ class NPTracer(BaseTracer): numpy.arcsin, numpy.arcsinh, numpy.arctan, - # numpy.arctan2, + numpy.arctan2, numpy.arctanh, # numpy.bitwise_and, # numpy.bitwise_or, @@ -216,7 +274,7 @@ class NPTracer(BaseTracer): numpy.exp2, numpy.expm1, numpy.fabs, - # numpy.float_power, + numpy.float_power, numpy.floor, # numpy.floor_divide, # numpy.fmax, @@ -289,7 +347,7 @@ class NPTracer(BaseTracer): } -def _get_fun(function: numpy.ufunc): +def _get_unary_fun(function: numpy.ufunc): """Wrap _unary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" # We have to access this method to be able to build NPTracer.UFUNC_ROUTING @@ -301,32 +359,41 @@ def _get_fun(function: numpy.ufunc): # pylint: enable=protected-access +def _get_binary_fun(function: numpy.ufunc): + """Wrap _binary_operator in a lambda to populate NPTRACER.UFUNC_ROUTING.""" + + # We have to access this method to be able to build NPTracer.UFUNC_ROUTING + # dynamically + # pylint: disable=protected-access + return lambda *input_tracers, **kwargs: NPTracer._binary_operator( + function, f"np.{function.__name__}", *input_tracers, **kwargs + ) + # pylint: enable=protected-access + + # We are populating NPTracer.UFUNC_ROUTING dynamically -NPTracer.UFUNC_ROUTING = {fun: _get_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC} +NPTracer.UFUNC_ROUTING = { + fun: _get_unary_fun(fun) for fun in NPTracer.LIST_OF_SUPPORTED_UFUNC if fun.nin == 1 +} + +NPTracer.UFUNC_ROUTING[numpy.arctan2] = _get_binary_fun(numpy.arctan2) +NPTracer.UFUNC_ROUTING[numpy.float_power] = _get_binary_fun(numpy.float_power) + # We are adding initial support for `np.array(...)` +,-,* `BaseTracer` # (note that this is not the proper complete handling of these functions) def _on_numpy_add(lhs, rhs): - if isinstance(lhs, BaseTracer): - return lhs.__add__(rhs) - - return rhs.__radd__(lhs) + return lhs.__add__(rhs) def _on_numpy_subtract(lhs, rhs): - if isinstance(lhs, BaseTracer): - return lhs.__sub__(rhs) - - return rhs.__rsub__(lhs) + return lhs.__sub__(rhs) def _on_numpy_multiply(lhs, rhs): - if isinstance(lhs, BaseTracer): - return lhs.__mul__(rhs) - - return rhs.__rmul__(lhs) + return lhs.__mul__(rhs) NPTracer.UFUNC_ROUTING[numpy.add] = _on_numpy_add diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index f10535aab..7f017ed7a 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -146,45 +146,127 @@ def test_tensor_no_fuse(): assert orig_num_nodes == fused_num_nodes -def test_fuse_float_operations_correctness(): - """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC - with fuse_float_operations.""" +def subtest_fuse_float_unary_operations_correctness(fun): + """Test a unary function with fuse_float_operations.""" - for fun in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: + # Some manipulation to avoid issues with domain of definitions of functions + if fun == numpy.arccosh: + input_list = [1, 2, 42, 44] + super_fun_list = [complex_fuse_direct_input] + elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: + input_list = [0, 0.1, 0.2] + super_fun_list = [complex_fuse_direct_input] + else: + input_list = [0, 2, 42, 44] + super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input] + + for super_fun in super_fun_list: + + for input_ in input_list: + + def get_function_to_trace(): + return lambda x, y: super_fun(fun, x, y) + + function_to_trace = get_function_to_trace() + + params_names = signature(function_to_trace).parameters.keys() + + op_graph = trace_numpy_function( + function_to_trace, + {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + ) + orig_num_nodes = len(op_graph.graph) + fuse_float_operations(op_graph) + fused_num_nodes = len(op_graph.graph) + + assert fused_num_nodes < orig_num_nodes + + input_ = numpy.int32(input_) + + num_params = len(params_names) + inputs = (input_,) * num_params + + assert function_to_trace(*inputs) == op_graph(*inputs) + + +def subtest_fuse_float_binary_operations_correctness(fun): + """Test a binary functions with fuse_float_operations, with a constant as a source.""" + + for i in range(4): + + # For bivariate functions: fix one of the inputs + if i == 0: + # With an integer in first position + def get_function_to_trace(): + return lambda x, y: fun(3, x + y).astype(numpy.int32) + + elif i == 1: + # With a float in first position + def get_function_to_trace(): + return lambda x, y: fun(2.3, x + y).astype(numpy.int32) + + elif i == 2: + # With an integer in second position + def get_function_to_trace(): + return lambda x, y: fun(x + y, 4).astype(numpy.int32) - if fun == numpy.arccosh: - input_list = [1, 2, 42, 44] - super_fun_list = [complex_fuse_direct_input] - elif fun in [numpy.arctanh, numpy.arccos, numpy.arcsin, numpy.arctan]: - input_list = [0, 0.1, 0.2] - super_fun_list = [complex_fuse_direct_input] else: - input_list = [0, 2, 42, 44] - super_fun_list = [complex_fuse_direct_input, complex_fuse_indirect_input] + # With a float in second position + def get_function_to_trace(): + return lambda x, y: fun(x + y, 5.7).astype(numpy.int32) - for super_fun in super_fun_list: + input_list = [0, 2, 42, 44] - for input_ in input_list: + for input_ in input_list: - def get_function_to_trace(): - return lambda x, y: super_fun(fun, x, y) + function_to_trace = get_function_to_trace() - function_to_trace = get_function_to_trace() + params_names = signature(function_to_trace).parameters.keys() - params_names = signature(function_to_trace).parameters.keys() + op_graph = trace_numpy_function( + function_to_trace, + {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + ) + orig_num_nodes = len(op_graph.graph) + fuse_float_operations(op_graph) + fused_num_nodes = len(op_graph.graph) - op_graph = trace_numpy_function( - function_to_trace, - {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, - ) - orig_num_nodes = len(op_graph.graph) - fuse_float_operations(op_graph) - fused_num_nodes = len(op_graph.graph) + assert fused_num_nodes < orig_num_nodes - assert fused_num_nodes < orig_num_nodes + input_ = numpy.int32(input_) - input_ = numpy.int32(input_) + num_params = len(params_names) + inputs = (input_,) * num_params - num_params = len(params_names) - inputs = (input_,) * num_params - assert function_to_trace(*inputs) == op_graph(*inputs) + assert function_to_trace(*inputs) == op_graph(*inputs) + + +def subtest_fuse_float_binary_operations_dont_support_two_variables(fun): + """Test a binary function with fuse_float_operations, with no constant as + a source.""" + + def get_function_to_trace(): + return lambda x, y: fun(x, y).astype(numpy.int32) + + function_to_trace = get_function_to_trace() + + params_names = signature(function_to_trace).parameters.keys() + + with pytest.raises(NotImplementedError, match=r"Can't manage binary operator"): + trace_numpy_function( + function_to_trace, + {param_name: EncryptedScalar(Integer(32, True)) for param_name in params_names}, + ) + + +@pytest.mark.parametrize("fun", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) +def test_ufunc_operations(fun): + """Test functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" + + if fun.nin == 1: + subtest_fuse_float_unary_operations_correctness(fun) + elif fun.nin == 2: + subtest_fuse_float_binary_operations_correctness(fun) + subtest_fuse_float_binary_operations_dont_support_two_variables(fun) + else: + raise NotImplementedError("Only unary and binary functions are tested for now") diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 9560622df..6a9998d95 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -195,11 +195,11 @@ def test_numpy_tracing_tensors(): %5 = x # EncryptedTensor, shape=(2, 2)> %6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> %7 = Add(5, 6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(7, 4) # EncryptedTensor, shape=(2, 2)> +%8 = Add(4, 7) # EncryptedTensor, shape=(2, 2)> %9 = Sub(3, 8) # EncryptedTensor, shape=(2, 2)> %10 = Sub(9, 2) # EncryptedTensor, shape=(2, 2)> %11 = Mul(10, 1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(11, 0) # EncryptedTensor, shape=(2, 2)> +%12 = Mul(0, 11) # EncryptedTensor, shape=(2, 2)> return(%12) """.lstrip() @@ -234,11 +234,11 @@ def test_numpy_explicit_tracing_tensors(): %5 = x # EncryptedTensor, shape=(2, 2)> %6 = Constant([[1 2] [3 4]]) # ClearTensor, shape=(2, 2)> %7 = Add(5, 6) # EncryptedTensor, shape=(2, 2)> -%8 = Add(7, 4) # EncryptedTensor, shape=(2, 2)> +%8 = Add(4, 7) # EncryptedTensor, shape=(2, 2)> %9 = Sub(3, 8) # EncryptedTensor, shape=(2, 2)> %10 = Sub(9, 2) # EncryptedTensor, shape=(2, 2)> %11 = Mul(10, 1) # EncryptedTensor, shape=(2, 2)> -%12 = Mul(11, 0) # EncryptedTensor, shape=(2, 2)> +%12 = Mul(0, 11) # EncryptedTensor, shape=(2, 2)> return(%12) """.lstrip() @@ -406,44 +406,43 @@ def test_tracing_astype( ), ], ) -def test_trace_numpy_supported_ufuncs(inputs, expected_output_node): +@pytest.mark.parametrize( + "function_to_trace_def", [f for f in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC if f.nin == 1] +) +def test_trace_numpy_supported_unary_ufuncs(inputs, expected_output_node, function_to_trace_def): """Function to trace supported numpy ufuncs""" - for function_to_trace_def in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: + # We really need a lambda (because numpy functions are not playing + # nice with inspect.signature), but pylint and flake8 are not happy + # with it + # pylint: disable=unnecessary-lambda,cell-var-from-loop + function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731 + # pylint: enable=unnecessary-lambda,cell-var-from-loop - # We really need a lambda (because numpy functions are not playing - # nice with inspect.signature), but pylint and flake8 are not happy - # with it - # pylint: disable=unnecessary-lambda,cell-var-from-loop - function_to_trace = lambda x: function_to_trace_def(x) # noqa: E731 - # pylint: enable=unnecessary-lambda,cell-var-from-loop + op_graph = tracing.trace_numpy_function(function_to_trace, inputs) - op_graph = tracing.trace_numpy_function(function_to_trace, inputs) + assert len(op_graph.output_nodes) == 1 + assert isinstance(op_graph.output_nodes[0], expected_output_node) + assert len(op_graph.output_nodes[0].outputs) == 1 - assert len(op_graph.output_nodes) == 1 - assert isinstance(op_graph.output_nodes[0], expected_output_node) - assert len(op_graph.output_nodes[0].outputs) == 1 + if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: + assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) + elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL: - if function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_FLOAT64: - assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Float(64)) - elif function_to_trace_def in LIST_OF_UFUNC_WHOSE_OUTPUT_IS_BOOL: + # Boolean function + assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar(Integer(8, is_signed=False)) + else: - # Boolean function - assert op_graph.output_nodes[0].outputs[0] == EncryptedScalar( - Integer(8, is_signed=False) - ) - else: + # Function keeping more or less input type + input_node_type = inputs["x"] - # Function keeping more or less input type - input_node_type = inputs["x"] + expected_output_node_type = deepcopy(input_node_type) - expected_output_node_type = deepcopy(input_node_type) + expected_output_node_type.dtype.bit_width = max( + expected_output_node_type.dtype.bit_width, 32 + ) - expected_output_node_type.dtype.bit_width = max( - expected_output_node_type.dtype.bit_width, 32 - ) - - assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type + assert op_graph.output_nodes[0].outputs[0] == expected_output_node_type def test_trace_numpy_ufuncs_not_supported(): @@ -516,15 +515,13 @@ def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expect assert op_graph.output_nodes[0].outputs[0] == expected_output_value -def test_nptracer_get_tracing_func_for_np_functions(): +@pytest.mark.parametrize("np_function", tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC) +def test_nptracer_get_tracing_func_for_np_functions(np_function): """Test NPTracer get_tracing_func_for_np_function""" - for np_function in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC: - expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function] + expected_tracing_func = tracing.NPTracer.UFUNC_ROUTING[np_function] - assert ( - tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func - ) + assert tracing.NPTracer.get_tracing_func_for_np_function(np_function) == expected_tracing_func def test_nptracer_get_tracing_func_for_np_functions_not_implemented():