mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add support for rounding
This commit is contained in:
@@ -153,6 +153,7 @@ class Tracer:
|
||||
np.arctan,
|
||||
np.arctan2,
|
||||
np.arctanh,
|
||||
np.around,
|
||||
np.bitwise_and,
|
||||
np.bitwise_or,
|
||||
np.bitwise_xor,
|
||||
@@ -220,6 +221,7 @@ class Tracer:
|
||||
np.reshape,
|
||||
np.right_shift,
|
||||
np.rint,
|
||||
np.round_,
|
||||
np.sign,
|
||||
np.signbit,
|
||||
np.sin,
|
||||
@@ -239,6 +241,9 @@ class Tracer:
|
||||
}
|
||||
|
||||
SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
|
||||
np.around: {
|
||||
"decimals",
|
||||
},
|
||||
np.concatenate: {
|
||||
"axis",
|
||||
},
|
||||
@@ -248,6 +253,9 @@ class Tracer:
|
||||
np.reshape: {
|
||||
"newshape",
|
||||
},
|
||||
np.round_: {
|
||||
"decimals",
|
||||
},
|
||||
np.sum: {
|
||||
"axis",
|
||||
"keepdims",
|
||||
@@ -429,6 +437,12 @@ class Tracer:
|
||||
def __abs__(self):
|
||||
return Tracer._trace_numpy_operation(np.absolute, self)
|
||||
|
||||
def __round__(self, ndigits=None):
|
||||
if ndigits is None:
|
||||
return Tracer._trace_numpy_operation(np.around, self).astype(np.int64)
|
||||
|
||||
return Tracer._trace_numpy_operation(np.around, self, decimals=ndigits)
|
||||
|
||||
def __invert__(self):
|
||||
return Tracer._trace_numpy_operation(np.invert, self)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ List of supported functions:
|
||||
- arctan
|
||||
- arctan2
|
||||
- arctanh
|
||||
- around
|
||||
- bitwise_and
|
||||
- bitwise_or
|
||||
- bitwise_xor
|
||||
@@ -79,6 +80,7 @@ List of supported functions:
|
||||
- reshape
|
||||
- right_shift
|
||||
- rint
|
||||
- round_
|
||||
- sign
|
||||
- signbit
|
||||
- sin
|
||||
|
||||
@@ -482,6 +482,20 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="cnp.univariate(deterministic_unary_function)(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: round(np.sqrt(x)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
|
||||
},
|
||||
id="round(np.sqrt(x))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: (2.5 * round(np.sqrt(x), ndigits=4)).astype(np.int64),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 100], "shape": ()},
|
||||
},
|
||||
id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
@@ -492,7 +506,7 @@ def test_others(function, parameters, helpers):
|
||||
# scalar
|
||||
# ------
|
||||
|
||||
if "shape" not in parameters["x"]:
|
||||
if "shape" not in parameters["x"] or parameters["x"]["shape"] == ():
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@@ -510,6 +524,9 @@ def test_others(function, parameters, helpers):
|
||||
if "shape" not in parameters["x"]:
|
||||
parameters["x"]["shape"] = (3, 2)
|
||||
|
||||
if parameters["x"]["shape"] == ():
|
||||
return
|
||||
|
||||
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
|
||||
configuration = helpers.configuration()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user