From 8564930fc86d6ed3453a274b91e0a20276d76d02 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 23 Jan 2023 16:03:10 +0100 Subject: [PATCH] feat: add support for np.squeeze --- concrete/numpy/mlir/graph_converter.py | 3 +++ concrete/numpy/mlir/node_converter.py | 22 ++++++++++++++++ concrete/numpy/representation/node.py | 1 + concrete/numpy/tracing/tracer.py | 4 +++ tests/execution/test_others.py | 35 ++++++++++++++++++++++++++ 5 files changed, 65 insertions(+) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index aa9c7dff9..5d5950a90 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -151,6 +151,9 @@ class GraphConverter: if not inputs[0].is_encrypted: return "only encrypted reshape is supported" + elif name == "squeeze": + assert_that(len(inputs) == 1) + elif name == "subtract": assert_that(len(inputs) == 2) diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index c9a59b3ce..54d251486 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -170,6 +170,7 @@ class NodeConverter: "negative": self._convert_neg, "ones": self._convert_ones, "reshape": self._convert_reshape, + "squeeze": self._convert_squeeze, "subtract": self._convert_sub, "sum": self._convert_sum, "transpose": self._convert_transpose, @@ -884,6 +885,27 @@ class NodeConverter: ), ).result + def _convert_squeeze(self) -> OpResult: + """ + Convert "squeeze" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + # because of the tracing logic, we have the correct output shape + + # if the output shape is (), it means (1, 1, ..., 1, 1) is squeezed + # and the result is a scalar, so we need to do indexing, not reshape + if self.node.output.shape == (): + assert all(size == 1 for size in self.node.inputs[0].shape) + self.node.properties["kwargs"]["index"] = (0,) * self.node.inputs[0].ndim + return self._convert_static_indexing() + + # otherwise, a simple reshape would work as we already have the correct shape + return self._convert_reshape() + def _convert_sub(self) -> OpResult: """ Convert "subtract" node to its corresponding MLIR representation. diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index 50919c3dc..35c00df52 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -381,6 +381,7 @@ class Node: "negative", "ones", "reshape", + "squeeze", "subtract", "sum", "transpose", diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 8a5fe56c7..e3ae9f9de 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -287,6 +287,7 @@ class Tracer: np.spacing, np.sqrt, np.square, + np.squeeze, np.subtract, np.sum, np.tan, @@ -320,6 +321,9 @@ class Tracer: np.round_: { "decimals", }, + np.squeeze: { + "axis", + }, np.sum: { "axis", "keepdims", diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 87a5d6358..fb45bb002 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -642,6 +642,41 @@ def deterministic_unary_function(x): }, id="x ** 3", ), + pytest.param( + lambda x: np.squeeze(x), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)}, + }, + id="np.squeeze(x)", + ), + pytest.param( + lambda x: np.squeeze(x, axis=2), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)}, + }, + id="np.squeeze(x, axis=2)", + ), + pytest.param( + lambda x: np.squeeze(x, axis=(0, 4)), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 2, 1, 3, 1, 4)}, + }, + id="np.squeeze(x, axis=(0, 4))", + ), + pytest.param( + lambda x: np.squeeze(x), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)}, + }, + id="np.squeeze(x) where x.shape == (1, 1, 1)", + ), + pytest.param( + lambda x: np.squeeze(x, axis=1), + { + "x": {"status": "encrypted", "range": [-10, 10], "shape": (1, 1, 1)}, + }, + id="np.squeeze(x, axis=1) where x.shape == (1, 1, 1)", + ), ], ) def test_others(function, parameters, helpers):