From ba33d427622e56de2e2da73585d585ded0b09233 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 4 Apr 2022 14:13:03 +0200 Subject: [PATCH] feat: implement transpose --- concrete/numpy/mlir/graph_converter.py | 9 +++-- concrete/numpy/mlir/node_converter.py | 17 +++++++++ concrete/numpy/representation/node.py | 1 + concrete/numpy/tracing/tracer.py | 25 +++++++++++++- docs/user/howto/numpy_support.md | 1 + tests/execution/test_transpose.py | 48 ++++++++++++++++++++++++++ tests/mlir/test_graph_converter.py | 16 +++++++++ tests/tracing/test_tracer.py | 12 +++++++ 8 files changed, 126 insertions(+), 3 deletions(-) create mode 100644 tests/execution/test_transpose.py diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 70c75877d..82d020b2c 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -45,7 +45,7 @@ class GraphConverter: None if node is convertible to MLIR, the reason for inconvertibility otherwise """ - # pylint: disable=too-many-branches,too-many-return-statements + # pylint: disable=too-many-branches,too-many-return-statements,too-many-statements inputs = node.inputs output = node.output @@ -121,6 +121,11 @@ class GraphConverter: if not inputs[0].is_encrypted: return "only encrypted sum is supported" + elif name == "transpose": + assert_that(len(inputs) == 1) + if not inputs[0].is_encrypted: + return "only encrypted transpose is supported" + else: variable_input_indices = [ idx @@ -135,7 +140,7 @@ class GraphConverter: return None - # pylint: enable=too-many-branches,too-many-return-statements + # pylint: enable=too-many-branches,too-many-return-statements,too-many-statements @staticmethod def _check_graph_convertibility(graph: Graph): diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index bb36e2d93..1146caa93 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -166,6 +166,9 @@ class NodeConverter: elif name == "sum": result = self.convert_sum() + elif name == "transpose": + result = self.convert_transpose() + else: result = self.convert_tlu() @@ -798,3 +801,17 @@ class NodeConverter: result = fhe.ApplyLookupTableEintOp(resulting_type, pred, lut).result return result + + def convert_transpose(self) -> OpResult: + """ + Convert "transpose" node to its corresponding MLIR representation. + + Returns: + OpResult: + in-memory MLIR representation corresponding to `self.node` + """ + + resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) + preds = self.preds + + return fhelinalg.TransposeOp(resulting_type, *preds).result diff --git a/concrete/numpy/representation/node.py b/concrete/numpy/representation/node.py index d40abf38b..575b8e427 100644 --- a/concrete/numpy/representation/node.py +++ b/concrete/numpy/representation/node.py @@ -313,4 +313,5 @@ class Node: "reshape", "subtract", "sum", + "transpose", ] diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 5e247731a..ef6ca2448 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -4,7 +4,7 @@ Declaration of `Tracer` class. import inspect from copy import deepcopy -from typing import Any, Callable, Dict, List, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import networkx as nx import numpy as np @@ -230,6 +230,7 @@ class Tracer: np.sum, np.tan, np.tanh, + np.transpose, np.true_divide, np.trunc, } @@ -345,6 +346,10 @@ class Tracer: sanitized_args = [self._sanitize(args[0])] if len(args) > 1: kwargs["newshape"] = args[1] + elif func is np.transpose: + sanitized_args = [self._sanitize(args[0])] + if len(args) > 1: + kwargs["axes"] = args[1] else: sanitized_args = [self._sanitize(arg) for arg in args] @@ -516,6 +521,16 @@ class Tracer: return Tracer._trace_numpy_operation(np.reshape, self, newshape=newshape) + def transpose(self, axes: Optional[Tuple[int, ...]] = None) -> "Tracer": + """ + Trace numpy.ndarray.transpose(). + """ + + if axes is None: + return Tracer._trace_numpy_operation(np.transpose, self) + + return Tracer._trace_numpy_operation(np.transpose, self, axes=axes) + def __getitem__( self, index: Union[int, np.integer, slice, Tuple[Union[int, np.integer, slice], ...]], @@ -583,3 +598,11 @@ class Tracer: """ return self.output.size + + @property + def T(self) -> "Tracer": # pylint: disable=invalid-name + """ + Trace numpy.ndarray.T. + """ + + return Tracer._trace_numpy_operation(np.transpose, self) diff --git a/docs/user/howto/numpy_support.md b/docs/user/howto/numpy_support.md index 13fd967ac..fa8e59dd3 100644 --- a/docs/user/howto/numpy_support.md +++ b/docs/user/howto/numpy_support.md @@ -89,6 +89,7 @@ List of supported functions: - sum - tan - tanh +- transpose - true_divide - trunc diff --git a/tests/execution/test_transpose.py b/tests/execution/test_transpose.py new file mode 100644 index 000000000..d39ed9d17 --- /dev/null +++ b/tests/execution/test_transpose.py @@ -0,0 +1,48 @@ +""" +Tests of execution of transpose operation. +""" + +import numpy as np +import pytest + +import concrete.numpy as cnp + + +@pytest.mark.parametrize( + "function,parameters", + [ + pytest.param( + lambda x: np.transpose(x), + { + "x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"}, + }, + ), + pytest.param( + lambda x: x.transpose(), + { + "x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"}, + }, + ), + pytest.param( + lambda x: x.T, + { + "x": {"shape": (3, 2), "range": [0, 10], "status": "encrypted"}, + }, + ), + ], +) +def test_transpose(function, parameters, helpers): + """ + Test transpose. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses, configuration) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 5983e8196..e5a449f89 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -309,6 +309,22 @@ Function you are trying to compile cannot be converted to MLIR: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported return %2 + """, # noqa: E501 + ), + pytest.param( + lambda x: np.transpose(x), + {"x": "clear"}, + [np.random.randint(0, 2, size=(3, 2)) for _ in range(100)], + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR + +%0 = x # ClearTensor +%1 = transpose(%0) # ClearTensor +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted transpose is supported +return %1 + """, # noqa: E501 ), ], diff --git a/tests/tracing/test_tracer.py b/tests/tracing/test_tracer.py index e07db0b63..7e06d65c3 100644 --- a/tests/tracing/test_tracer.py +++ b/tests/tracing/test_tracer.py @@ -25,6 +25,18 @@ from concrete.numpy.values import EncryptedTensor RuntimeError, "Function 'np.sum' is not supported with kwarg 'initial'", ), + pytest.param( + lambda x: np.transpose(x, (1, 0, 2)), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))}, + RuntimeError, + "Function 'np.transpose' is not supported with kwarg 'axes'", + ), + pytest.param( + lambda x: x.transpose((1, 0, 2)), + {"x": EncryptedTensor(UnsignedInteger(7), shape=(1, 2, 3))}, + RuntimeError, + "Function 'np.transpose' is not supported with kwarg 'axes'", + ), pytest.param( lambda x: np.multiply.outer(x, [1, 2, 3]), {"x": EncryptedTensor(UnsignedInteger(7), shape=(4,))},