mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
chore(debugging): show problems in a clearer way with highlighted_nodes
This commit is contained in:
committed by
Benoit Chevallier
parent
0b864afb76
commit
bc90ed37ff
@@ -10,6 +10,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_scalar,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
from ..debugging import get_printable_graph
|
||||
from ..debugging.custom_assert import assert_not_reached, assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate
|
||||
@@ -134,7 +135,7 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
op_graph: graph to update bit_width for
|
||||
"""
|
||||
max_bit_width = 0
|
||||
offending_list = []
|
||||
offending_nodes = {}
|
||||
for node in op_graph.graph.nodes:
|
||||
for value_out in node.outputs:
|
||||
if value_is_clear_scalar_integer(value_out) or value_is_clear_tensor_integer(value_out):
|
||||
@@ -152,18 +153,20 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
|
||||
# Check that current_node_out_bit_width is supported by the compiler
|
||||
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
|
||||
offending_list.append((node, current_node_out_bit_width))
|
||||
offending_nodes[
|
||||
node
|
||||
] = f"{current_node_out_bit_width} bits is not supported for the time being"
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
# Check that the max_bit_width is supported by the compiler
|
||||
if len(offending_list) != 0:
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
f"max_bit_width of some nodes is too high for the current version of "
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB} "
|
||||
f"which is not compatible with {offending_list})"
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
|
||||
f"which is not compatible with:\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
)
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
|
||||
def extend_direct_lookup_tables(op_graph: OPGraph):
|
||||
"""Extend direct lookup tables to the maximum length the input bit width can support.
|
||||
|
||||
@@ -1049,12 +1049,17 @@ def test_compile_too_high_bitwidth(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
assert (
|
||||
"max_bit_width of some nodes is too high for the current version of the "
|
||||
"compiler (maximum must be 7 which is not compatible with" in str(excinfo.value)
|
||||
str(excinfo.value)
|
||||
== "max_bit_width of some nodes is too high for the current version of the compiler (maximum must be 7) which is not compatible with:\n" # noqa: E501
|
||||
"%0 = x # EncryptedScalar<Integer<unsigned, 7 bits>>\n" # noqa: E501
|
||||
"%1 = y # EncryptedScalar<Integer<unsigned, 5 bits>>\n" # noqa: E501
|
||||
"%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 8 bits>>\n" # noqa: E501
|
||||
"^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 8 bits is not supported for the time being\n" # noqa: E501
|
||||
"return(%2)\n"
|
||||
)
|
||||
|
||||
assert str(excinfo.value).endswith(", 8)])")
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
# Just ok
|
||||
input_ranges = [(0, 99), (0, 28)]
|
||||
|
||||
Reference in New Issue
Block a user