mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add support for np.ndarray.round() method
This commit is contained in:
@@ -576,6 +576,13 @@ class Tracer:
|
||||
|
||||
return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)
|
||||
|
||||
def round(self, decimals: int = 0) -> "Tracer":
|
||||
"""
|
||||
Trace numpy.ndarray.round().
|
||||
"""
|
||||
|
||||
return Tracer._trace_numpy_operation(np.around, self, decimals=decimals)
|
||||
|
||||
def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
|
||||
"""
|
||||
Trace numpy.ndarray.transpose().
|
||||
|
||||
@@ -489,6 +489,13 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="round(np.sqrt(x))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.sqrt(x).round().astype(np.int64),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 100]},
|
||||
},
|
||||
id="np.sqrt(x).round().astype(np.int64)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64),
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user