diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 6a945ec0e..f01034cdd 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -153,6 +153,7 @@ class Tracer: np.arctan, np.arctan2, np.arctanh, + np.around, np.bitwise_and, np.bitwise_or, np.bitwise_xor, @@ -220,6 +221,7 @@ class Tracer: np.reshape, np.right_shift, np.rint, + np.round_, np.sign, np.signbit, np.sin, @@ -239,6 +241,9 @@ class Tracer: } SUPPORTED_KWARGS: Dict[Any, Set[str]] = { + np.around: { + "decimals", + }, np.concatenate: { "axis", }, @@ -248,6 +253,9 @@ class Tracer: np.reshape: { "newshape", }, + np.round_: { + "decimals", + }, np.sum: { "axis", "keepdims", @@ -429,6 +437,12 @@ class Tracer: def __abs__(self): return Tracer._trace_numpy_operation(np.absolute, self) + def __round__(self, ndigits=None): + if ndigits is None: + return Tracer._trace_numpy_operation(np.around, self).astype(np.int64) + + return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits) + def __invert__(self): return Tracer._trace_numpy_operation(np.invert, self) diff --git a/docs/user/howto/numpy_support.md b/docs/user/howto/numpy_support.md index e7b718931..f1c03a7b7 100644 --- a/docs/user/howto/numpy_support.md +++ b/docs/user/howto/numpy_support.md @@ -14,6 +14,7 @@ List of supported functions: - arctan - arctan2 - arctanh +- around - bitwise_and - bitwise_or - bitwise_xor @@ -79,6 +80,7 @@ List of supported functions: - reshape - right_shift - rint +- round_ - sign - signbit - sin diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index d688eae61..6e3e8d160 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -482,6 +482,20 @@ def deterministic_unary_function(x): }, id="cnp.univariate(deterministic_unary_function)(x)", ), + pytest.param( + lambda x: round(np.sqrt(x)), + { + "x": {"status": "encrypted", "range": [0, 100], "shape": ()}, + }, + id="round(np.sqrt(x))", + ), + pytest.param( + lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64), + { + "x": {"status": "encrypted", "range": [0, 100], "shape": ()}, + }, + id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)", + ), ], ) def test_others(function, parameters, helpers): @@ -492,7 +506,7 @@ def test_others(function, parameters, helpers): # scalar # ------ - if "shape" not in parameters["x"]: + if "shape" not in parameters["x"] or parameters["x"]["shape"] == (): parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration() @@ -510,6 +524,9 @@ def test_others(function, parameters, helpers): if "shape" not in parameters["x"]: parameters["x"]["shape"] = (3, 2) + if parameters["x"]["shape"] == (): + return + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) configuration = helpers.configuration()