feat: add support for np.ndarray.round() method

This commit is contained in:
Umut
2022-06-14 10:20:13 +02:00
parent ada4369283
commit 11819fcf2f
2 changed files with 14 additions and 0 deletions

View File

@@ -576,6 +576,13 @@ class Tracer:
return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape)
def round(self, decimals: int = 0) -> "Tracer":
"""
Trace numpy.ndarray.round().
"""
return Tracer._trace_numpy_operation(np.around, self, decimals=decimals)
def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer":
"""
Trace numpy.ndarray.transpose().

View File

@@ -489,6 +489,13 @@ def deterministic_unary_function(x):
},
id="round(np.sqrt(x))",
),
pytest.param(
lambda x: np.sqrt(x).round().astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 100]},
},
id="np.sqrt(x).round().astype(np.int64)",
),
pytest.param(
lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64),
{