diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index bec4b7f57..05d21099f 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -1,7 +1,7 @@ """numpy tracing utilities.""" from copy import deepcopy from functools import partial -from typing import Any, Callable, Dict, List, Optional, Union, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast import numpy from numpy.typing import DTypeLike @@ -69,7 +69,15 @@ class NPTracer(BaseTracer): (len(kwargs) == 0), f"**kwargs are currently not supported for numpy functions, func: {func}", ) - sanitized_args = [self._sanitize(arg) for arg in args] + + # Fixme: Special case to be removed once #772 is done + if func is not numpy.reshape: + sanitized_args = [self._sanitize(arg) for arg in args] + else: + # In numpy.reshape, the second argument is the new shape + sanitized_args = [self._sanitize(args[0]), args[1]] + return tracing_func(self, sanitized_args[0], sanitized_args[1], **kwargs) + return tracing_func(self, *sanitized_args, **kwargs) def astype(self, numpy_dtype: DTypeLike, *args, **kwargs) -> "NPTracer": @@ -267,7 +275,7 @@ class NPTracer(BaseTracer): ) return output_tracer - def transpose(self, *args: "NPTracer", **_kwargs) -> "NPTracer": + def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer": """Trace numpy.transpose. Returns: @@ -285,7 +293,7 @@ class NPTracer(BaseTracer): arbitrary_func=numpy.transpose, output_dtype=first_arg_output.dtype, output_shape=first_arg_output.shape[::-1], - op_kwargs=deepcopy(_kwargs), + op_kwargs=deepcopy(kwargs), op_name="np.transpose", ) output_tracer = self.__class__( @@ -295,7 +303,7 @@ class NPTracer(BaseTracer): ) return output_tracer - def ravel(self, *args: "NPTracer", **_kwargs) -> "NPTracer": + def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer": """Trace numpy.ravel. Returns: @@ -313,7 +321,7 @@ class NPTracer(BaseTracer): arbitrary_func=numpy.ravel, output_dtype=first_arg_output.dtype, output_shape=(numpy.product(first_arg_output.shape),), - op_kwargs=deepcopy(_kwargs), + op_kwargs=deepcopy(kwargs), op_name="np.ravel", ) output_tracer = self.__class__( @@ -323,6 +331,54 @@ class NPTracer(BaseTracer): ) return output_tracer + def reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer": + """Trace numpy.reshape. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + + # FIXME: #772, restore reshape(self, *args, **kwargs) signature when possible, with mypy + # types + + # FIXME: #772, restore + # assert_true((num_args := len(args)) == 2, f"reshape expect 2 input got {num_args}") + # when possible + + assert_true((num_kwargs := len(kwargs)) == 0, f"reshape expect 0 kwargs got {num_kwargs}") + + first_arg_output = arg0.output + assert_true(isinstance(first_arg_output, TensorValue)) + first_arg_output = cast(TensorValue, first_arg_output) + + newshape = deepcopy(arg1) + + if isinstance(newshape, int): + # Make numpy.reshape(x, (170)) and numpy.reshape(x, 170) work, while classical form is + # numpy.reshape(x, (170,)) + newshape = (newshape,) + + # Check shape compatibility + assert_true( + numpy.product(newshape) == numpy.product(first_arg_output.shape), + f"shapes are not compatible (old shape {first_arg_output.shape}, new shape {newshape})", + ) + + traced_computation = GenericFunction( + input_base_value=first_arg_output, + arbitrary_func=numpy.reshape, + output_dtype=first_arg_output.dtype, + output_shape=newshape, + op_kwargs={"newshape": newshape}, + op_name="np.reshape", + ) + output_tracer = self.__class__( + [arg0], + traced_computation=traced_computation, + output_idx=0, + ) + return output_tracer + def __getitem__(self, item): if isinstance(item, tuple): item = tuple(process_indexing_element(indexing_element) for indexing_element in item) @@ -436,6 +492,7 @@ class NPTracer(BaseTracer): FUNC_ROUTING: Dict[Callable, Callable] = { numpy.dot: dot, numpy.transpose: transpose, + numpy.reshape: reshape, numpy.ravel: ravel, } diff --git a/tests/common/representation/test_intermediate.py b/tests/common/representation/test_intermediate.py index ab53b0a4e..221bfbe64 100644 --- a/tests/common/representation/test_intermediate.py +++ b/tests/common/representation/test_intermediate.py @@ -197,6 +197,17 @@ from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, En numpy.arange(15), id="GenericFunction, x ravel", ), + pytest.param( + ir.GenericFunction( + EncryptedTensor(Integer(32, False), shape=(3, 5)), + lambda x: numpy.reshape(x, (5, 3)), + Integer(32, False), + output_shape=(5, 3), + ), + [numpy.arange(15).reshape(3, 5)], + numpy.arange(15).reshape(5, 3), + id="GenericFunction, x reshape", + ), ], ) def test_evaluate( diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index fa9968680..f79aecab7 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -215,6 +215,7 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): ) +# pylint: disable=line-too-long @pytest.mark.parametrize( "lambda_f,params,ref_graph_str", [ @@ -223,14 +224,55 @@ def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): { "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), }, - "%0 = x\n%1 = np.transpose(%0)\nreturn(%1)\n", + """ +%0 = x # EncryptedTensor, shape=(3, 5)> +%1 = np.transpose(%0) # EncryptedTensor, shape=(5, 3)> +return(%1) +""".lstrip(), # noqa: E501 ), ( lambda x: numpy.ravel(x), { "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), }, - "%0 = x\n%1 = np.ravel(%0)\nreturn(%1)\n", + """ +%0 = x # EncryptedTensor, shape=(3, 5)> +%1 = np.ravel(%0) # EncryptedTensor, shape=(15,)> +return(%1) +""".lstrip(), # noqa: E501 + ), + ( + lambda x: numpy.reshape(x, (5, 3)), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(3, 5)), + }, + """ +%0 = x # EncryptedTensor, shape=(3, 5)> +%1 = np.reshape(%0) # EncryptedTensor, shape=(5, 3)> +return(%1) +""".lstrip(), # noqa: E501 + ), + ( + lambda x: numpy.reshape(x, (170,)), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)), + }, + """ +%0 = x # EncryptedTensor, shape=(17, 10)> +%1 = np.reshape(%0) # EncryptedTensor, shape=(170,)> +return(%1) +""".lstrip(), # noqa: E501 + ), + ( + lambda x: numpy.reshape(x, (170)), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(17, 10)), + }, + """ +%0 = x # EncryptedTensor, shape=(17, 10)> +%1 = np.reshape(%0) # EncryptedTensor, shape=(170,)> +return(%1) +""".lstrip(), # noqa: E501 ), ], ) @@ -240,7 +282,7 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_ draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph) + str_of_the_graph = get_printable_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -249,6 +291,9 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_ ) +# pylint: enable=line-too-long + + # Remark that the bitwidths are not particularly correct (eg, a MUL of a 17b times 23b # returning 23b), since they are replaced later by the real bitwidths computed on the # inputset diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index ffb0a9bbc..9059e4186 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -636,6 +636,13 @@ def test_nptracer_unsupported_operands(operation, tracer): (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])), ], ), + ( + lambda x: numpy.reshape(x, (5, 3)) + 42, + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + [ + (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)), + ], + ), ], ) def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples): @@ -649,3 +656,22 @@ def test_tracing_generic_function(function_to_trace, input_value, input_and_expe evaluated_output = node_results[output_node] assert isinstance(evaluated_output, type(expected_output)) assert numpy.array_equal(expected_output, evaluated_output) + + +@pytest.mark.parametrize( + "lambda_f,params", + [ + ( + lambda x: numpy.reshape(x, (5, 3)), + { + "x": EncryptedTensor(Integer(2, is_signed=False), shape=(7, 5)), + }, + ), + ], +) +def test_errors_with_generic_function(lambda_f, params): + "Test some errors with generic function" + with pytest.raises(AssertionError) as excinfo: + tracing.trace_numpy_function(lambda_f, params) + + assert "shapes are not compatible (old shape (7, 5), new shape (5, 3))" in str(excinfo.value)