From 1e86c3b1e4db275c67e3eafab9277c9c06eb7e57 Mon Sep 17 00:00:00 2001 From: Umut Date: Wed, 1 Jun 2022 14:49:58 +0200 Subject: [PATCH] fix: don't allow signed clear inputs as tlu cannot be applied to them, add more signed input tests --- concrete/numpy/mlir/graph_converter.py | 4 ++++ tests/execution/test_add.py | 13 +++++++++++++ tests/execution/test_others.py | 8 ++++++++ tests/mlir/test_graph_converter.py | 15 +++++++++++++++ 4 files changed, 40 insertions(+) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index a687e358e..9c9e4da6c 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -69,6 +69,8 @@ class GraphConverter: assert_that(inputs[0] == output) if not isinstance(output.dtype, Integer): return "only integer inputs are supported" + if output.dtype.is_signed and output.is_clear: + return "only encrypted signed integer inputs are supported" else: assert_that(node.operation == Operation.Generic) @@ -345,6 +347,8 @@ class GraphConverter: input_dtype = cast(Integer, input_value.dtype) if input_dtype.is_signed: + assert_that(input_value.is_encrypted) + n = input_dtype.bit_width lut_range = np.arange(2**n) diff --git a/tests/execution/test_add.py b/tests/execution/test_add.py index f25379fe4..ad2f2056e 100644 --- a/tests/execution/test_add.py +++ b/tests/execution/test_add.py @@ -49,6 +49,15 @@ import concrete.numpy as cnp { "x": {"range": [0, 85], "status": "encrypted", "shape": (2, 3)}, }, + { + "x": {"range": [-50, 10], "status": "encrypted"}, + }, + { + "x": {"range": [-50, 10], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)}, + }, ], ) def test_constant_add(function, parameters, helpers): @@ -140,6 +149,10 @@ def test_constant_add(function, parameters, helpers): "x": {"range": [0, 60], "status": "encrypted", "shape": (2, 1)}, "y": {"range": [0, 60], "status": "encrypted", "shape": (3,)}, }, + { + "x": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)}, + "y": {"range": [-30, 30], "status": "encrypted", "shape": (3, 2)}, + }, ], ) def test_add(function, parameters, helpers): diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index e339cb968..880d35791 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -578,6 +578,14 @@ def deterministic_unary_function(x): }, id="(2.5 * round(np.sqrt(x), decimals=4)).astype(np.int64)", ), + pytest.param( + lambda x, y: cnp.LookupTable(list(range(32)))[x + y], + { + "x": {"status": "encrypted", "range": [-10, 10]}, + "y": {"status": "encrypted", "range": [-10, 10]}, + }, + id="cnp.LookupTable(list(range(32)))[x + y]", + ), ], ) def test_others(function, parameters, helpers): diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 81c6a8544..8546a5132 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -33,6 +33,21 @@ return (%2, %3) """, # noqa: E501 ), + pytest.param( + lambda x: x, + {"x": "clear"}, + range(-10, 10), + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR + +%0 = x # ClearScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only encrypted signed integer inputs are supported +return %0 + + """, # noqa: E501 + ), pytest.param( lambda x: x * 1.5, {"x": "encrypted"},