From d6ba76318f0bf6ac39d3e45222296946cb7c0d83 Mon Sep 17 00:00:00 2001 From: Umut Date: Fri, 14 Oct 2022 17:09:06 +0200 Subject: [PATCH] feat: add support for np.expand_dims --- concrete/numpy/mlir/graph_converter.py | 3 ++ concrete/numpy/mlir/node_converter.py | 2 +- concrete/numpy/representation/node.py | 1 + concrete/numpy/tracing/tracer.py | 4 +++ docs/getting-started/compatibility.md | 1 + tests/execution/test_others.py | 49 ++++++++++++++++++++++++++ 6 files changed, 59 insertions(+), 1 deletion(-) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 6d20a6aed..53880e6c5 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -111,6 +111,9 @@ class GraphConverter: if inputs[0].is_encrypted and inputs[1].is_encrypted: return "only dot product between encrypted and clear is supported" + elif name == "expand_dims": + assert_that(len(inputs) == 1) + elif name == "index.static": assert_that(len(inputs) == 1) if not inputs[0].is_encrypted: diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 505cf9eac..2304cb980 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -196,7 +196,7 @@ class NodeConverter: elif name == "ones": result = self._convert_ones() - elif name == "reshape": + elif name in ["reshape", "expand_dims"]: result = self._convert_reshape() elif name == "subtract": diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index e9423c15e..fd64b8eb6 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -326,6 +326,7 @@ class Node: "conv2d", "conv3d", "dot", + "expand_dims", "index.static", "matmul", "multiply", diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 31a6a44a5..d7f98bb51 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -217,6 +217,7 @@ class Tracer: np.equal, np.exp, np.exp2, + np.expand_dims, np.expm1, np.fabs, np.float_power, @@ -296,6 +297,9 @@ class Tracer: np.concatenate: { "axis", }, + np.expand_dims: { + "axis", + }, np.ones_like: { "dtype", }, diff --git a/docs/getting-started/compatibility.md b/docs/getting-started/compatibility.md index bf146d2c1..a9fab3086 100644 --- a/docs/getting-started/compatibility.md +++ b/docs/getting-started/compatibility.md @@ -80,6 +80,7 @@ Some of these operations are not supported between two encrypted values. A detai * [np.equal](https://numpy.org/doc/stable/reference/generated/numpy.equal.html) * [np.exp](https://numpy.org/doc/stable/reference/generated/numpy.exp.html) * [np.exp2](https://numpy.org/doc/stable/reference/generated/numpy.exp2.html) +* [np.expand\_dims](https://numpy.org/doc/stable/reference/generated/numpy.expand\_dims.html) * [np.expm1](https://numpy.org/doc/stable/reference/generated/numpy.expm1.html) * [np.fabs](https://numpy.org/doc/stable/reference/generated/numpy.fabs.html) * [np.float\_power](https://numpy.org/doc/stable/reference/generated/numpy.float\_power.html) diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index dff837a26..33703104a 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -586,6 +586,55 @@ def deterministic_unary_function(x): }, id="cnp.LookupTable(list(range(32)))[x + y]", ), + pytest.param( + lambda x: np.expand_dims(x, axis=0), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=0)", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=1), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=1)", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=2), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=2)", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=(0, 1)), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=(0, 1))", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=(0, 2)), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=(0, 2))", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=(1, 2)), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=(1, 2))", + ), + pytest.param( + lambda x: np.expand_dims(x, axis=(0, 1, 2)), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (3, 2)}, + }, + id="np.expand_dims(x, axis=(0, 1, 2))", + ), ], ) def test_others(function, parameters, helpers):