mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: add support for np.ones_like and np.zeros_like
This commit is contained in:
@@ -210,6 +210,7 @@ class Tracer:
|
||||
np.negative,
|
||||
np.nextafter,
|
||||
np.not_equal,
|
||||
np.ones_like,
|
||||
np.positive,
|
||||
np.power,
|
||||
np.rad2deg,
|
||||
@@ -234,12 +235,16 @@ class Tracer:
|
||||
np.true_divide,
|
||||
np.trunc,
|
||||
np.where,
|
||||
np.zeros_like,
|
||||
}
|
||||
|
||||
SUPPORTED_KWARGS: Dict[Any, Set[str]] = {
|
||||
np.concatenate: {
|
||||
"axis",
|
||||
},
|
||||
np.ones_like: {
|
||||
"dtype",
|
||||
},
|
||||
np.reshape: {
|
||||
"newshape",
|
||||
},
|
||||
@@ -247,6 +252,9 @@ class Tracer:
|
||||
"axis",
|
||||
"keepdims",
|
||||
},
|
||||
np.zeros_like: {
|
||||
"dtype",
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@@ -279,6 +287,14 @@ class Tracer:
|
||||
f"Function 'np.{operation.__name__}' is not supported with kwarg '{kwarg}'"
|
||||
)
|
||||
|
||||
if operation == np.ones_like: # pylint: disable=comparison-with-callable
|
||||
dtype = kwargs.get("dtype", np.int64)
|
||||
return Tracer(Node.constant(np.ones(args[0].shape, dtype=dtype)), [])
|
||||
|
||||
if operation == np.zeros_like: # pylint: disable=comparison-with-callable
|
||||
dtype = kwargs.get("dtype", np.int64)
|
||||
return Tracer(Node.constant(np.zeros(args[0].shape, dtype=dtype)), [])
|
||||
|
||||
def sampler(arg: Any) -> Any:
|
||||
if isinstance(arg, tuple):
|
||||
return tuple(sampler(item) for item in arg)
|
||||
|
||||
@@ -69,6 +69,7 @@ List of supported functions:
|
||||
- negative
|
||||
- nextafter
|
||||
- not_equal
|
||||
- ones_like
|
||||
- positive
|
||||
- power
|
||||
- rad2deg
|
||||
@@ -93,6 +94,7 @@ List of supported functions:
|
||||
- true_divide
|
||||
- trunc
|
||||
- where
|
||||
- zeros_like
|
||||
<!--- gen_supported_ufuncs.py: inject supported operations [END] -->
|
||||
|
||||
# Shapes
|
||||
|
||||
@@ -447,6 +447,20 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
},
|
||||
id="np.where(x < 5, x * 3, x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + np.ones_like(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="x + np.ones_like(x)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + np.zeros_like(x),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="x + np.zeros_like(x)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
|
||||
Reference in New Issue
Block a user