diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 50adbf0dd..6d20a6aed 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -28,7 +28,7 @@ from ..internal.utils import assert_that from ..representation import Graph, Node, Operation from ..values import ClearScalar, EncryptedScalar from .node_converter import NodeConverter -from .utils import MAXIMUM_TLU_BIT_WIDTH +from .utils import MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS, MAXIMUM_TLU_BIT_WIDTH # pylint: enable=no-member,no-name-in-module @@ -240,11 +240,19 @@ class GraphConverter: if dtype.is_signed and first_signed_node is None: first_signed_node = node - if first_tlu_node is not None and max_bit_width > MAXIMUM_TLU_BIT_WIDTH: - offending_nodes[first_tlu_node] = [ - f"table lookups are only supported on circuits with " - f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers" - ] + if first_tlu_node is not None: + if max_bit_width > MAXIMUM_TLU_BIT_WIDTH: + offending_nodes[first_tlu_node] = [ + f"table lookups are only supported on circuits with " + f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers" + ] + + if first_signed_node is not None and max_bit_width > MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS: + offending_nodes[first_signed_node] = [ + f"signed integers are only supported " + f"up to {MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS}-bits " + f"on circuits with table lookups" + ] if len(offending_nodes) != 0: raise RuntimeError( diff --git a/concrete/numpy/mlir/utils.py b/concrete/numpy/mlir/utils.py index 26c542d6e..74a832853 100644 --- a/concrete/numpy/mlir/utils.py +++ b/concrete/numpy/mlir/utils.py @@ -13,7 +13,8 @@ from ..dtypes import Integer from ..internal.utils import assert_that from ..representation import Node, Operation -MAXIMUM_TLU_BIT_WIDTH = 8 +MAXIMUM_TLU_BIT_WIDTH = 16 +MAXIMUM_SIGNED_BIT_WIDTH_WITH_TLUS = 8 class HashableNdarray: diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 289fd62bb..319afe653 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -401,17 +401,17 @@ return %2 pytest.param( lambda x: np.abs(10 * np.sin(x + 300)).astype(np.int64), {"x": "encrypted"}, - range(200), + [200000], RuntimeError, """ Function you are trying to compile cannot be converted to MLIR: -%0 = x # EncryptedScalar +%0 = x # EncryptedScalar %1 = 300 # ClearScalar -%2 = add(%0, %1) # EncryptedScalar +%2 = add(%0, %1) # EncryptedScalar %3 = subgraph(%2) # EncryptedScalar -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 8-bit integers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers return %3 Subgraphs: @@ -426,6 +426,35 @@ Subgraphs: %5 = astype(%4, dtype=int_) # EncryptedScalar return %5 + """, # noqa: E501 + ), + pytest.param( + lambda x: (10 * np.sin(x + 300)).astype(np.int64), + {"x": "encrypted"}, + range(2**10, 2**11), + RuntimeError, + """ + +Function you are trying to compile cannot be converted to MLIR: + +%0 = x # EncryptedScalar +%1 = 300 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +%3 = subgraph(%2) # EncryptedScalar +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ signed integers are only supported up to 8-bits on circuits with table lookups +return %3 + +Subgraphs: + + %3 = subgraph(%2): + + %0 = 10 # ClearScalar + %1 = input # EncryptedScalar + %2 = sin(%1) # EncryptedScalar + %3 = multiply(%0, %2) # EncryptedScalar + %4 = astype(%3, dtype=int_) # EncryptedScalar + return %4 + """, # noqa: E501 ), ],