diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index d5a1a7fc4..493fdf4ce 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -576,6 +576,13 @@ class Tracer: return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape) + def round(self, decimals: int = 0) -> "Tracer": + """ + Trace numpy.ndarray.round(). + """ + + return Tracer._trace_numpy_operation(np.around, self, decimals=decimals) + def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer": """ Trace numpy.ndarray.transpose(). diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index bc3dfea99..34f21baf7 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -489,6 +489,13 @@ def deterministic_unary_function(x): }, id="round(np.sqrt(x))", ), + pytest.param( + lambda x: np.sqrt(x).round().astype(np.int64), + { + "x": {"status": "encrypted", "range": [0, 100]}, + }, + id="np.sqrt(x).round().astype(np.int64)", + ), pytest.param( lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64), {