From 56085204e6b3a9ced5ccdda5dc4da9e036b12bbc Mon Sep 17 00:00:00 2001 From: Umut Date: Tue, 17 May 2022 10:58:33 +0200 Subject: [PATCH] feat: add support for np.ones_like and np.zeros_like --- concrete/numpy/tracing/tracer.py | 16 ++++++++++++++++ docs/user/howto/numpy_support.md | 2 ++ tests/execution/test_others.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+) diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index b9d552bdf..6a945ec0e 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -210,6 +210,7 @@ class Tracer: np.negative, np.nextafter, np.not_equal, + np.ones_like, np.positive, np.power, np.rad2deg, @@ -234,12 +235,16 @@ class Tracer: np.true_divide, np.trunc, np.where, + np.zeros_like, } SUPPORTED_KWARGS: Dict[Any, Set[str]] = { np.concatenate: { "axis", }, + np.ones_like: { + "dtype", + }, np.reshape: { "newshape", }, @@ -247,6 +252,9 @@ class Tracer: "axis", "keepdims", }, + np.zeros_like: { + "dtype", + }, } @staticmethod @@ -279,6 +287,14 @@ class Tracer: f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'" ) + if operation == np.ones_like: # pylint: disable=comparison-with-callable + dtype = kwargs.get("dtype", np.int64) + return Tracer(Node.constant(np.ones(args[0].shape, dtype=dtype)), []) + + if operation == np.zeros_like: # pylint: disable=comparison-with-callable + dtype = kwargs.get("dtype", np.int64) + return Tracer(Node.constant(np.zeros(args[0].shape, dtype=dtype)), []) + def sampler(arg: Any) -> Any: if isinstance(arg, tuple): return tuple(sampler(item) for item in arg) diff --git a/docs/user/howto/numpy_support.md b/docs/user/howto/numpy_support.md index 6ba5813cd..e7b718931 100644 --- a/docs/user/howto/numpy_support.md +++ b/docs/user/howto/numpy_support.md @@ -69,6 +69,7 @@ List of supported functions: - negative - nextafter - not_equal +- ones_like - positive - power - rad2deg @@ -93,6 +94,7 @@ List of supported functions: - true_divide - trunc - where +- zeros_like # Shapes diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index eaffcbc8c..11c6b566e 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -447,6 +447,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): }, id="np.where(x < 5, x * 3, x)", ), + pytest.param( + lambda x: x + np.ones_like(x), + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="x + np.ones_like(x)", + ), + pytest.param( + lambda x: x + np.zeros_like(x), + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="x + np.zeros_like(x)", + ), ], ) def test_others(function, parameters, helpers):