feat: add support for rounding

This commit is contained in:
Umut
2022-05-23 12:30:49 +02:00
parent 736742cf28
commit fc6836b766
3 changed files with 34 additions and 1 deletions

View File

@@ -153,6 +153,7 @@ class Tracer:
np.arctan,
np.arctan2,
np.arctanh,
np.around,
np.bitwise_and,
np.bitwise_or,
np.bitwise_xor,
@@ -220,6 +221,7 @@ class Tracer:
np.reshape,
np.right_shift,
np.rint,
np.round_,
np.sign,
np.signbit,
np.sin,
@@ -239,6 +241,9 @@ class Tracer:
}
SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
np.around: {
"decimals",
},
np.concatenate: {
"axis",
},
@@ -248,6 +253,9 @@ class Tracer:
np.reshape: {
"newshape",
},
np.round_: {
"decimals",
},
np.sum: {
"axis",
"keepdims",
@@ -429,6 +437,12 @@ class Tracer:
def __abs__(self):
return Tracer._trace_numpy_operation(np.absolute, self)
def __round__(self, ndigits=None):
if ndigits is None:
return Tracer._trace_numpy_operation(np.around, self).astype(np.int64)
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
def __invert__(self):
return Tracer._trace_numpy_operation(np.invert, self)

View File

@@ -14,6 +14,7 @@ List of supported functions:
- arctan
- arctan2
- arctanh
- around
- bitwise_and
- bitwise_or
- bitwise_xor
@@ -79,6 +80,7 @@ List of supported functions:
- reshape
- right_shift
- rint
- round_
- sign
- signbit
- sin

View File

@@ -482,6 +482,20 @@ def deterministic_unary_function(x):
},
id="cnp.univariate(deterministic_unary_function)(x)",
),
pytest.param(
lambda x: round(np.sqrt(x)),
{
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
},
id="round(np.sqrt(x))",
),
pytest.param(
lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64),
{
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
},
id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)",
),
],
)
def test_others(function, parameters, helpers):
@@ -492,7 +506,7 @@ def test_others(function, parameters, helpers):
# scalar
# ------
if "shape" not in parameters["x"]:
if "shape" not in parameters["x"] or parameters["x"]["shape"] == ():
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
@@ -510,6 +524,9 @@ def test_others(function, parameters, helpers):
if "shape" not in parameters["x"]:
parameters["x"]["shape"] = (3, 2)
if parameters["x"]["shape"] == ():
return
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()