From 155e5b2941d5c4f97508ade447712ff38762fa0c Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 2 Jan 2023 14:44:58 +0100 Subject: [PATCH] refactor: improve error message of circuits which has TLU and more than 16-bits at the same time --- concrete/numpy/mlir/graph_converter.py | 20 +++++++++++++++++--- tests/mlir/test_graph_converter.py | 3 ++- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/concrete/numpy/mlir/graph_converter.py b/concrete/numpy/mlir/graph_converter.py index 05a71fb2d..aa9c7dff9 100644 --- a/concrete/numpy/mlir/graph_converter.py +++ b/concrete/numpy/mlir/graph_converter.py @@ -236,18 +236,21 @@ class GraphConverter: offending_nodes: Dict[Node, List[str]] = {} max_bit_width = 0 + max_bit_width_node = None first_tlu_node = None first_signed_node = None - for node in graph.graph.nodes: + for node in nx.lexicographical_topological_sort(graph.graph): dtype = node.output.dtype assert_that(isinstance(dtype, Integer)) current_node_bit_width = ( dtype.bit_width - 1 if node.output.is_clear else dtype.bit_width ) - max_bit_width = max(max_bit_width, current_node_bit_width) + if max_bit_width < current_node_bit_width: + max_bit_width = current_node_bit_width + max_bit_width_node = node if node.converted_to_table_lookup and first_tlu_node is None: first_tlu_node = node @@ -257,9 +260,20 @@ class GraphConverter: if first_tlu_node is not None: if max_bit_width > MAXIMUM_TLU_BIT_WIDTH: + assert max_bit_width_node is not None + offending_nodes[max_bit_width_node] = [ + ( + { + Operation.Input: f"this input is {max_bit_width}-bits", + Operation.Constant: f"this constant is {max_bit_width}-bits", + Operation.Generic: f"this operation results in {max_bit_width}-bits", + }[max_bit_width_node.operation] + ), + max_bit_width_node.location, + ] offending_nodes[first_tlu_node] = [ f"table lookups are only supported on circuits with " - f"up to {MAXIMUM_TLU_BIT_WIDTH}-bit integers", + f"up to {MAXIMUM_TLU_BIT_WIDTH}-bits", first_tlu_node.location, ] diff --git a/tests/mlir/test_graph_converter.py b/tests/mlir/test_graph_converter.py index 3b1c0cc1c..1fbfac691 100644 --- a/tests/mlir/test_graph_converter.py +++ b/tests/mlir/test_graph_converter.py @@ -381,10 +381,11 @@ return %2 Function you are trying to compile cannot be converted to MLIR: %0 = x # EncryptedScalar ∈ [200000, 200000] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ this input is 18-bits %1 = 300 # ClearScalar ∈ [300, 300] %2 = add(%0, %1) # EncryptedScalar ∈ [200300, 200300] %3 = subgraph(%2) # EncryptedScalar ∈ [9, 9] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bit integers +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ table lookups are only supported on circuits with up to 16-bits return %3 Subgraphs: