mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: add support for np.expand_dims
This commit is contained in:
@@ -111,6 +111,9 @@ class GraphConverter:
|
||||
if inputs[0].is_encrypted and inputs[1].is_encrypted:
|
||||
return "only dot product between encrypted and clear is supported"
|
||||
|
||||
elif name == "expand_dims":
|
||||
assert_that(len(inputs) == 1)
|
||||
|
||||
elif name == "index.static":
|
||||
assert_that(len(inputs) == 1)
|
||||
if not inputs[0].is_encrypted:
|
||||
|
||||
@@ -196,7 +196,7 @@ class NodeConverter:
|
||||
elif name == "ones":
|
||||
result = self._convert_ones()
|
||||
|
||||
elif name == "reshape":
|
||||
elif name in ["reshape", "expand_dims"]:
|
||||
result = self._convert_reshape()
|
||||
|
||||
elif name == "subtract":
|
||||
|
||||
@@ -326,6 +326,7 @@ class Node:
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"dot",
|
||||
"expand_dims",
|
||||
"index.static",
|
||||
"matmul",
|
||||
"multiply",
|
||||
|
||||
@@ -217,6 +217,7 @@ class Tracer:
|
||||
np.equal,
|
||||
np.exp,
|
||||
np.exp2,
|
||||
np.expand_dims,
|
||||
np.expm1,
|
||||
np.fabs,
|
||||
np.float_power,
|
||||
@@ -296,6 +297,9 @@ class Tracer:
|
||||
np.concatenate: {
|
||||
"axis",
|
||||
},
|
||||
np.expand_dims: {
|
||||
"axis",
|
||||
},
|
||||
np.ones_like: {
|
||||
"dtype",
|
||||
},
|
||||
|
||||
@@ -80,6 +80,7 @@ Some of these operations are not supported between two encrypted values. A detai
|
||||
* [np.equal](https://numpy.org/doc/stable/reference/generated/numpy.equal.html)
|
||||
* [np.exp](https://numpy.org/doc/stable/reference/generated/numpy.exp.html)
|
||||
* [np.exp2](https://numpy.org/doc/stable/reference/generated/numpy.exp2.html)
|
||||
* [np.expand\_dims](https://numpy.org/doc/stable/reference/generated/numpy.expand\_dims.html)
|
||||
* [np.expm1](https://numpy.org/doc/stable/reference/generated/numpy.expm1.html)
|
||||
* [np.fabs](https://numpy.org/doc/stable/reference/generated/numpy.fabs.html)
|
||||
* [np.float\_power](https://numpy.org/doc/stable/reference/generated/numpy.float\_power.html)
|
||||
|
||||
@@ -586,6 +586,55 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="cnp.LookupTable(list(range(32)))[x + y]",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=0),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=0)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=1),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=1)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=2),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=2)",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=(0, 1)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=(0, 1))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=(0, 2)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=(0, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=(1, 2)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=(1, 2))",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: np.expand_dims(x, axis=(0, 1, 2)),
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
|
||||
},
|
||||
id="np.expand_dims(x, axis=(0, 1, 2))",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_others(function, parameters, helpers):
|
||||
|
||||
Reference in New Issue
Block a user