From 60d80793034e6b172f899629d5402bf63f7ab2f6 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 23 Nov 2021 14:20:54 +0100 Subject: [PATCH] feat: add clip support closes #965 closes #983 closes #984 --- concrete/numpy/tracing.py | 18 ++++++++++++++++++ tests/numpy/test_compile.py | 38 +++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index b9968ffca..adc9d5ef0 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -239,6 +239,23 @@ class NPTracer(BaseTracer): ) return output_tracer + def clip(self, *args: Union["NPTracer", Any], **kwargs) -> "NPTracer": + """Trace x.clip. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + sanitized_args = [cast(NPTracer, self._sanitize(arg)) for arg in args] + return self.numpy_clip(self, *sanitized_args, **kwargs) + + def numpy_clip(self, *args: "NPTracer", **kwargs) -> "NPTracer": + """Trace numpy.clip. + + Returns: + NPTracer: The output NPTracer containing the traced function + """ + return self._np_operator(numpy.clip, "clip", 3, *args, **kwargs) + def dot(self, *args: "NPTracer", **kwargs) -> "NPTracer": """Trace x.dot. @@ -565,6 +582,7 @@ class NPTracer(BaseTracer): numpy.transpose: numpy_transpose, numpy.reshape: numpy_reshape, numpy.ravel: numpy_ravel, + numpy.clip: numpy_clip, } diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index eb6c03673..52ac49a00 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -937,6 +937,44 @@ def test_compile_and_run_correctness__for_prog_with_tlu( ([2, 7, 1],), [4, 14, 2], ), + pytest.param( + lambda x: numpy.clip(x, 1, 5), + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), + }, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)], + ( + [ + [0, 7], + [6, 1], + [2, 5], + ], + ), + [ + [1, 5], + [5, 1], + [2, 5], + ], + ), + pytest.param( + lambda x: x.clip(1, 5), + { + "x": EncryptedTensor(UnsignedInteger(3), shape=(3, 2)), + }, + [(numpy.random.randint(0, 2 ** 3, size=(3, 2)),) for _ in range(10)], + ( + [ + [0, 7], + [6, 1], + [2, 5], + ], + ), + [ + [1, 5], + [5, 1], + [2, 5], + ], + ), ], ) def test_compile_and_run_tensor_correctness(