feat: add support for np.expand_dims

This commit is contained in:
Umut
2022-10-14 17:09:06 +02:00
parent ea34aeba09
commit d6ba76318f
6 changed files with 59 additions and 1 deletions

View File

@@ -111,6 +111,9 @@ class GraphConverter:
if inputs[0].is_encrypted and inputs[1].is_encrypted:
return "only dot product between encrypted and clear is supported"
elif name == "expand_dims":
assert_that(len(inputs) == 1)
elif name == "index.static":
assert_that(len(inputs) == 1)
if not inputs[0].is_encrypted:

View File

@@ -196,7 +196,7 @@ class NodeConverter:
elif name == "ones":
result = self._convert_ones()
elif name == "reshape":
elif name in ["reshape", "expand_dims"]:
result = self._convert_reshape()
elif name == "subtract":

View File

@@ -326,6 +326,7 @@ class Node:
"conv2d",
"conv3d",
"dot",
"expand_dims",
"index.static",
"matmul",
"multiply",

View File

@@ -217,6 +217,7 @@ class Tracer:
np.equal,
np.exp,
np.exp2,
np.expand_dims,
np.expm1,
np.fabs,
np.float_power,
@@ -296,6 +297,9 @@ class Tracer:
np.concatenate: {
"axis",
},
np.expand_dims: {
"axis",
},
np.ones_like: {
"dtype",
},

View File

@@ -80,6 +80,7 @@ Some of these operations are not supported between two encrypted values. A detai
* [np.equal](https://numpy.org/doc/stable/reference/generated/numpy.equal.html)
* [np.exp](https://numpy.org/doc/stable/reference/generated/numpy.exp.html)
* [np.exp2](https://numpy.org/doc/stable/reference/generated/numpy.exp2.html)
* [np.expand\_dims](https://numpy.org/doc/stable/reference/generated/numpy.expand\_dims.html)
* [np.expm1](https://numpy.org/doc/stable/reference/generated/numpy.expm1.html)
* [np.fabs](https://numpy.org/doc/stable/reference/generated/numpy.fabs.html)
* [np.float\_power](https://numpy.org/doc/stable/reference/generated/numpy.float\_power.html)

View File

@@ -586,6 +586,55 @@ def deterministic_unary_function(x):
},
id="cnp.LookupTable(list(range(32)))[x + y]",
),
pytest.param(
lambda x: np.expand_dims(x, axis=0),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=0)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=1),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=1)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=2),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=2)",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 1)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=(0, 1))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=(0, 2))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(1, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=(1, 2))",
),
pytest.param(
lambda x: np.expand_dims(x, axis=(0, 1, 2)),
{
"x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)},
},
id="np.expand_dims(x, axis=(0, 1, 2))",
),
],
)
def test_others(function, parameters, helpers):