From bc90ed37ff7621a34d9bbe1c57ff0bd322dd3805 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Thu, 21 Oct 2021 19:27:07 +0200 Subject: [PATCH] chore(debugging): show problems in a clearer way with highlighted_nodes --- concrete/common/mlir/utils.py | 19 +++++++++++-------- tests/numpy/test_compile.py | 13 +++++++++---- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index 172e133c4..d8694709f 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -10,6 +10,7 @@ from ..data_types.dtypes_helpers import ( value_is_scalar, value_is_unsigned_integer, ) +from ..debugging import get_printable_graph from ..debugging.custom_assert import assert_not_reached, assert_true from ..operator_graph import OPGraph from ..representation import intermediate @@ -134,7 +135,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph): op_graph: graph to update bit_width for """ max_bit_width = 0 - offending_list = [] + offending_nodes = {} for node in op_graph.graph.nodes: for value_out in node.outputs: if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out): @@ -152,18 +153,20 @@ def update_bit_width_for_mlir(op_graph: OPGraph): # Check that current_node_out_bit_width is supported by the compiler if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB: - offending_list.append((node, current_node_out_bit_width)) + offending_nodes[ + node + ] = f"{current_node_out_bit_width} bits is not supported for the time being" - _set_all_bit_width(op_graph, max_bit_width) - - # Check that the max_bit_width is supported by the compiler - if len(offending_list) != 0: + if len(offending_nodes) != 0: raise RuntimeError( f"max_bit_width of some nodes is too high for the current version of " - f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB} " - f"which is not compatible with {offending_list})" + f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) " + f"which is not compatible with:\n" + + get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes) ) + _set_all_bit_width(op_graph, max_bit_width) + def extend_direct_lookup_tables(op_graph: OPGraph): """Extend direct lookup tables to the maximum length the input bit width can support. diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index e26a82b13..bf583f13e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1049,12 +1049,17 @@ def test_compile_too_high_bitwidth(default_compilation_configuration): default_compilation_configuration, ) + # pylint: disable=line-too-long assert ( - "max_bit_width of some nodes is too high for the current version of the " - "compiler (maximum must be 7 which is not compatible with" in str(excinfo.value) + str(excinfo.value) + == "max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with:\n" # noqa: E501 + "%0 = x # EncryptedScalar>\n" # noqa: E501 + "%1 = y # EncryptedScalar>\n" # noqa: E501 + "%2 = Add(%0, %1) # EncryptedScalar>\n" # noqa: E501 + "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being\n" # noqa: E501 + "return(%2)\n" ) - - assert str(excinfo.value).endswith(", 8)])") + # pylint: enable=line-too-long # Just ok input_ranges = [(0, 99), (0, 28)]