From 039a632c724c9895f77a3fcab33b62e2b65d53d2 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 15 Aug 2022 16:32:58 +0200 Subject: [PATCH] feat: support all kind of subtraction operations --- concrete/numpy/mlir/graph_converter.py | 2 - concrete/numpy/mlir/node_converter.py | 19 ++++- tests/execution/test_sub.py | 103 +++++++++++++++++++++++++ tests/mlir/test_graph_converter.py | 17 ---- 4 files changed, 119 insertions(+), 22 deletions(-) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 4c49fccbc..ab58e0b04 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -137,8 +137,6 @@ class GraphConverter: elif name == "subtract": assert_that(len(inputs) == 2) - if not (inputs[0].is_clear and inputs[1].is_encrypted): - return "only subtraction of encrypted from clear is supported" elif name == "sum": assert_that(len(inputs) == 1) diff --git a/concrete/numpy/mlir/node_converter.py b/concrete/numpy/mlir/node_converter.py index 622a5b14d..22ea92229 100644 --- a/concrete/numpy/mlir/node_converter.py +++ b/concrete/numpy/mlir/node_converter.py @@ -840,10 +840,23 @@ class NodeConverter: resulting_type = NodeConverter.value_to_mlir_type(self.ctx, self.node.output) preds = self.preds - if self.one_of_the_inputs_is_a_tensor: - result = fhelinalg.SubIntEintOp(resulting_type, *preds).result + if self.all_of_the_inputs_are_encrypted: + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.SubEintOp(resulting_type, *preds).result + else: + result = fhe.SubEintOp(resulting_type, *preds).result + + elif self.node.inputs[0].is_clear: + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.SubIntEintOp(resulting_type, *preds).result + else: + result = fhe.SubIntEintOp(resulting_type, *preds).result + else: - result = fhe.SubIntEintOp(resulting_type, *preds).result + if self.one_of_the_inputs_is_a_tensor: + result = fhelinalg.SubEintIntOp(resulting_type, *preds).result + else: + result = fhe.SubEintIntOp(resulting_type, *preds).result return result diff --git a/tests/execution/test_sub.py b/tests/execution/test_sub.py index e132153bd..8c3fdd313 100644 --- a/tests/execution/test_sub.py +++ b/tests/execution/test_sub.py @@ -11,14 +11,26 @@ import concrete.numpy as cnp @pytest.mark.parametrize( "function", [ + pytest.param( + lambda x: x - 42, + id="x - 42", + ), pytest.param( lambda x: 42 - x, id="42 - x", ), + pytest.param( + lambda x: x - np.array([1, 2, 3]), + id="x - [1, 2, 3]", + ), pytest.param( lambda x: np.array([1, 2, 3]) - x, id="[1, 2, 3] - x", ), + pytest.param( + lambda x: x - np.array([[1, 2, 3], [4, 5, 6]]), + id="x - [[1, 2, 3], [4, 5, 6]]", + ), pytest.param( lambda x: np.array([[1, 2, 3], [4, 5, 6]]) - x, id="[[1, 2, 3], [4, 5, 6]] - x", @@ -54,3 +66,94 @@ def test_constant_sub(function, parameters, helpers): sample = helpers.generate_sample(parameters) helpers.check_execution(circuit, function, sample) + + +@pytest.mark.parametrize( + "function", + [ + pytest.param( + lambda x, y: x - y, + id="x - y", + ), + ], +) +@pytest.mark.parametrize( + "parameters", + [ + { + "x": {"range": [0, 60], "status": "clear"}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "clear"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "clear", "shape": (3,)}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 60], "status": "clear"}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 60], "status": "encrypted"}, + }, + { + "x": {"range": [0, 60], "status": "clear"}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "clear", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted"}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "clear", "shape": (3,)}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 60], "status": "clear", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "clear", "shape": (2, 1)}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)}, + "y": {"range": [0, 60], "status": "clear", "shape": (3,)}, + }, + { + "x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)}, + "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, + }, + ], +) +def test_sub(function, parameters, helpers): + """ + Test sub where both of the operators are dynamic. + """ + + parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters) + configuration = helpers.configuration() + + compiler = cnp.Compiler(function, parameter_encryption_statuses) + + inputset = helpers.generate_inputset(parameters) + circuit = compiler.compile(inputset, configuration) + + sample = helpers.generate_sample(parameters) + helpers.check_execution(circuit, function, sample) diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 47fd89e12..b826a090f 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -286,23 +286,6 @@ return %1 """, # noqa: E501 ), - pytest.param( - lambda x: x - 1, - {"x": "clear"}, - [np.random.randint(0, 2**3, size=(2, 3)) for _ in range(100)], - RuntimeError, - """ - -Function you are trying to compile cannot be converted to MLIR - -%0 = x # ClearTensor -%1 = 1 # ClearScalar -%2 = subtract(%0, %1) # ClearTensor -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only subtraction of encrypted from clear is supported -return %2 - - """, # noqa: E501 - ), pytest.param( lambda x: np.sum(x), {"x": "clear"},