From 50a6b06c374b9df716c4f63534bc3a15588ebc76 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Tue, 2 Nov 2021 18:03:41 +0100 Subject: [PATCH] feat(tracing): support x.reshape as well as np.reshape(x, ) closes #701 --- concrete/numpy/tracing.py | 53 ++++++++++++++++++++++++++++++++----- tests/numpy/test_compile.py | 14 ++++++++++ tests/numpy/test_tracing.py | 32 ++++++++++++++++++++++ 3 files changed, 93 insertions(+), 6 deletions(-) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 05d21099f..a3ecc05b9 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -251,7 +251,7 @@ class NPTracer(BaseTracer): ) return output_tracer - def dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer": + def numpy_dot(self, *args: "NPTracer", **_kwargs) -> "NPTracer": """Trace numpy.dot. Returns: @@ -275,7 +275,27 @@ class NPTracer(BaseTracer): ) return output_tracer + def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer": + """Trace x.dot. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + assert len(args) == 1 + arg0 = self._sanitize(args[0]) + assert_true(isinstance(arg0, NPTracer)) + arg0 = cast(NPTracer, arg0) + return self.numpy_dot(self, arg0, **kwargs) + def transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer": + """Trace x.transpose. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self.numpy_transpose(self, *args, **kwargs) + + def numpy_transpose(self, *args: "NPTracer", **kwargs) -> "NPTracer": """Trace numpy.transpose. Returns: @@ -304,6 +324,14 @@ class NPTracer(BaseTracer): return output_tracer def ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer": + """Trace x.ravel. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self.numpy_ravel(self, *args, **kwargs) + + def numpy_ravel(self, *args: "NPTracer", **kwargs) -> "NPTracer": """Trace numpy.ravel. Returns: @@ -331,7 +359,15 @@ class NPTracer(BaseTracer): ) return output_tracer - def reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer": + def reshape(self, arg1: Tuple[Any, ...], **kwargs) -> "NPTracer": + """Trace x.reshape. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self.numpy_reshape(self, arg1, **kwargs) + + def numpy_reshape(self, arg0: "NPTracer", arg1: Tuple[Any, ...], **kwargs) -> "NPTracer": """Trace numpy.reshape. Returns: @@ -388,6 +424,11 @@ class NPTracer(BaseTracer): return BaseTracer.__getitem__(self, item) def __matmul__(self, other): + """Trace numpy.matmul.""" + return self.__array_ufunc__(numpy.matmul, "__call__", self, other) + + def matmul(self, other): + """Trace x.matmul.""" return self.__array_ufunc__(numpy.matmul, "__call__", self, other) # Supported functions are either univariate or bivariate for which one of the two @@ -490,10 +531,10 @@ class NPTracer(BaseTracer): UFUNC_ROUTING: Dict[numpy.ufunc, Callable] = {} FUNC_ROUTING: Dict[Callable, Callable] = { - numpy.dot: dot, - numpy.transpose: transpose, - numpy.reshape: reshape, - numpy.ravel: ravel, + numpy.dot: numpy_dot, + numpy.transpose: numpy_transpose, + numpy.reshape: numpy_reshape, + numpy.ravel: numpy_ravel, } diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 10890fd62..4ce3caf37 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -938,6 +938,20 @@ return(%7) "return(%2)\n" ), ), + pytest.param( + lambda x: x.matmul(numpy.ones(shape=(2, 3), dtype=numpy.uint32)), + {"x": EncryptedTensor(Integer(3, is_signed=False), shape=(3, 2))}, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for i in range(10)], + ( + "function you are trying to compile isn't supported for MLIR lowering\n" + "\n" + "%0 = x # EncryptedTensor, shape=(3, 2)>\n" # noqa: E501 + "%1 = Constant([[1 1 1] [1 1 1]]) # ClearTensor, shape=(2, 3)>\n" # noqa: E501 + "%2 = MatMul(%0, %1) # EncryptedTensor, shape=(3, 3)>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ matrix multiplication is not supported for the time being\n" # noqa: E501 + "return(%2)\n" + ), + ), pytest.param( multi_lut, {"x": EncryptedTensor(UnsignedInteger(2), shape=(3, 2))}, diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 9059e4186..f1c7f7836 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -557,6 +557,14 @@ def test_trace_numpy_ufuncs_no_kwargs_no_extra_args(): ir.Dot, EncryptedScalar(Integer(64, True)), ), + pytest.param( + lambda x: x.dot(numpy.array([1, 2, 3, 4, 5], dtype=numpy.int64)), + { + "x": EncryptedTensor(Integer(64, is_signed=True), shape=(5,)), + }, + ir.Dot, + EncryptedScalar(Integer(64, True)), + ), ], ) def test_trace_numpy_dot(function_to_trace, inputs, expected_output_node, expected_output_value): @@ -613,6 +621,7 @@ def test_nptracer_unsupported_operands(operation, tracer): @pytest.mark.parametrize( "function_to_trace,input_value,input_and_expected_output_tuples", [ + # Indirect calls, like numpy.function(x, ...) ( lambda x: numpy.transpose(x), EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), @@ -643,6 +652,29 @@ def test_nptracer_unsupported_operands(operation, tracer): (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)), ], ), + # Direct calls, like x.function(...) + ( + lambda x: x.transpose() + 42, + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + [ + (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(3, 5).transpose()), + ], + ), + ( + lambda x: x.ravel(), + EncryptedTensor(Integer(4, is_signed=False), shape=(2, 2)), + [ + (numpy.arange(4), numpy.array([0, 1, 2, 3])), + (numpy.arange(4).reshape(2, 2), numpy.array([0, 1, 2, 3])), + ], + ), + ( + lambda x: x.reshape((5, 3)) + 42, + EncryptedTensor(Integer(32, is_signed=False), shape=(3, 5)), + [ + (numpy.arange(15).reshape(3, 5), numpy.arange(42, 57).reshape(5, 3)), + ], + ), ], ) def test_tracing_generic_function(function_to_trace, input_value, input_and_expected_output_tuples):