From 7415dd07e1d65bd64624029869b400c435b934da Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 29 Aug 2022 10:24:05 +0200 Subject: [PATCH] feat: support unsigned levelled operations on large bit-widths --- concrete/numpy/__init__.py | 2 +- concrete/numpy/compilation/client.py | 6 +-- concrete/numpy/mlir/graph_converter.py | 27 +++++++++-- concrete/numpy/mlir/utils.py | 2 +- tests/conftest.py | 2 +- tests/execution/test_add.py | 9 ++++ tests/mlir/test_graph_converter.py | 64 +++++++++++++++++++------- 7 files changed, 84 insertions(+), 28 deletions(-) diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 6d326758d..db447abeb 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -16,5 +16,5 @@ from .compilation import ( compiler, ) from .extensions import LookupTable, array, one, ones, univariate, zero, zeros -from .mlir.utils import MAXIMUM_BIT_WIDTH +from .mlir.utils import MAXIMUM_TLU_BIT_WIDTH from .representation import Graph diff --git a/concrete/numpy/compilation/client.py b/concrete/numpy/compilation/client.py index a48a1451c..5b1885d26 100644 --- a/concrete/numpy/compilation/client.py +++ b/concrete/numpy/compilation/client.py @@ -165,7 +165,7 @@ class Client: arg >= 0, arg, 2 * (expected_max + 1) + arg, - ).astype(np.uint8) + ).astype(np.uint64) if not is_valid: actual_value = Value.of(arg, is_encrypted=is_encrypted) @@ -220,10 +220,10 @@ class Client: else: result = result.astype(np.longlong) # to prevent overflows in numpy sanititzed_result = np.where(result < (2 ** (n - 1)), result, result - (2**n)) - sanitized_results.append(sanititzed_result.astype(np.int8)) + sanitized_results.append(sanititzed_result.astype(np.int64)) else: sanitized_results.append( - result if isinstance(result, int) else result.astype(np.uint8) + result if isinstance(result, int) else result.astype(np.uint64) ) return sanitized_results[0] if len(sanitized_results) == 1 else tuple(sanitized_results) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index ae77b7c34..ca6ce5d71 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -28,7 +28,7 @@ from ..internal.utils import assert_that from ..representation import Graph, Node, Operation from ..values import ClearScalar, EncryptedScalar from .node_converter import NodeConverter -from .utils import MAXIMUM_BIT_WIDTH +from .utils import MAXIMUM_TLU_BIT_WIDTH # pylint: enable=no-member,no-name-in-module @@ -221,6 +221,10 @@ class GraphConverter: offending_nodes: Dict[Node, List[str]] = {} max_bit_width = 0 + + first_tlu_node = None + first_signed_node = None + for node in graph.graph.nodes: dtype = node.output.dtype assert_that(isinstance(dtype, Integer)) @@ -230,10 +234,23 @@ class GraphConverter: ) max_bit_width = max(max_bit_width, current_node_bit_width) - if current_node_bit_width > MAXIMUM_BIT_WIDTH: - offending_nodes[node] = [ - f"only up to {MAXIMUM_BIT_WIDTH}-bit integers are supported" - ] + if node.converted_to_table_lookup and first_tlu_node is None: + first_tlu_node = node + + if dtype.is_signed and first_signed_node is None: + first_signed_node = node + + if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH: + offending_nodes[first_tlu_node] = [ + f"table lookups are only supported on circuits with " + f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers" + ] + + if first_signed_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH: + offending_nodes[first_signed_node] = [ + f"signed values are only supported on circuits with " + f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers" + ] if len(offending_nodes) != 0: raise RuntimeError( diff --git a/concrete/numpy/mlir/utils.py b/concrete/numpy/mlir/utils.py index ad27f1534..26c542d6e 100644 --- a/concrete/numpy/mlir/utils.py +++ b/concrete/numpy/mlir/utils.py @@ -13,7 +13,7 @@ from ..dtypes import Integer from ..internal.utils import assert_that from ..representation import Node, Operation -MAXIMUM_BIT_WIDTH = 8 +MAXIMUM_TLU_BIT_WIDTH = 8 class HashableNdarray: diff --git a/tests/conftest.py b/tests/conftest.py index ee7c68aec..1ad55bfdb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -201,7 +201,7 @@ class Helpers: sample = [] for description in parameters.values(): - minimum, maximum = description.get("range", [0, 127]) + minimum, maximum = description.get("range", [0, (2**16) - 1]) if "shape" in description: shape = description["shape"] diff --git a/tests/execution/test_add.py b/tests/execution/test_add.py index ad2f2056e..ec07f4b8c 100644 --- a/tests/execution/test_add.py +++ b/tests/execution/test_add.py @@ -58,6 +58,15 @@ import concrete.numpy as cnp { "x": {"range": [-50, 10], "status": "encrypted", "shape": (2, 3)}, }, + { + "x": {"range": [0, 1000000], "status": "encrypted"}, + }, + { + "x": {"range": [0, 1000000], "status": "encrypted", "shape": (3,)}, + }, + { + "x": {"range": [0, 1000000], "status": "encrypted", "shape": (2, 3)}, + }, ], ) def test_constant_add(function, parameters, helpers): diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index db27379b7..e63c1c03b 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -344,23 +344,6 @@ Function you are trying to compile cannot be converted to MLIR %1 = [3] # ClearTensor %2 = maximum(%0, %1) # ClearTensor ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ one of the operands must be encrypted -return %2 - - """, # noqa: E501 - ), - pytest.param( - lambda x: x + 200, - {"x": "encrypted"}, - range(200), - RuntimeError, - """ - -Function you are trying to compile cannot be converted to MLIR: - -%0 = x # EncryptedScalar -%1 = 200 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ only up to 8-bit integers are supported return %2 """, # noqa: E501 @@ -413,6 +396,53 @@ Function you are trying to compile cannot be converted to MLIR return %2 + """, # noqa: E501 + ), + pytest.param( + lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64), + {"x": "encrypted"}, + range(200), + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR: + +%0 = x # EncryptedScalar +%1 = 300 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +%3 = subgraph(%2) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers +return %3 + +Subgraphs: + + %3 = subgraph(%2): + + %0 = 10 # ClearScalar + %1 = input # EncryptedScalar + %2 = sin(%1) # EncryptedScalar + %3 = multiply(%0, %2) # EncryptedScalar + %4 = absolute(%3) # EncryptedScalar + %5 = astype(%4, dtype=int_) # EncryptedScalar + return %5 + + """, # noqa: E501 + ), + pytest.param( + lambda x: x - 300, + {"x": "encrypted"}, + range(200), + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR: + +%0 = x # EncryptedScalar +%1 = 300 # ClearScalar +%2 = subtract(%0, %1) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed values are only supported on circuits with up to 8-bit integers +return %2 + """, # noqa: E501 ), ],