mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(debugging): accept several highlights per node when printing
refs #645
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
"""functions to print the different graphs we can generate in the package, eg to debug."""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -48,15 +48,16 @@ def shorten_a_constant(constant_data: str):
|
||||
def get_printable_graph(
|
||||
opgraph: OPGraph,
|
||||
show_data_types: bool = False,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
) -> str:
|
||||
"""Return a string representing a graph.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): The graph that we want to draw
|
||||
show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, str]]):
|
||||
The dict of nodes which will be highlighted and their corresponding messages
|
||||
show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see
|
||||
their width. Defaults to False.
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes
|
||||
which will be highlighted and their corresponding messages. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
@@ -145,8 +146,9 @@ def get_printable_graph(
|
||||
returned_str += f"{new_line}\n"
|
||||
|
||||
if node in highlighted_nodes:
|
||||
message = highlighted_nodes[node]
|
||||
returned_str += f"{'^' * len(new_line)} {message}\n"
|
||||
new_line_len = len(new_line)
|
||||
message = f"\n{' ' * new_line_len} ".join(highlighted_nodes[node])
|
||||
returned_str += f"{'^' * new_line_len} {message}\n"
|
||||
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""Utilities for MLIR conversion."""
|
||||
from typing import Dict, Optional, cast
|
||||
from typing import Dict, List, Optional, cast
|
||||
|
||||
from ..data_types import Integer
|
||||
from ..data_types.dtypes_helpers import (
|
||||
@@ -99,7 +99,7 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
|
||||
|
||||
def check_graph_values_compatibility_with_mlir(
|
||||
op_graph: OPGraph,
|
||||
) -> Optional[Dict[IntermediateNode, str]]:
|
||||
) -> Optional[Dict[IntermediateNode, List[str]]]:
|
||||
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
|
||||
|
||||
Args:
|
||||
@@ -115,7 +115,7 @@ def check_graph_values_compatibility_with_mlir(
|
||||
for node in op_graph.graph.nodes:
|
||||
is_output = node in op_graph.output_nodes.values()
|
||||
if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None:
|
||||
offending_nodes[node] = reason
|
||||
offending_nodes[node] = [reason]
|
||||
|
||||
return None if len(offending_nodes) == 0 else offending_nodes
|
||||
|
||||
@@ -162,9 +162,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
|
||||
# Check that current_node_out_bit_width is supported by the compiler
|
||||
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
|
||||
offending_nodes[
|
||||
node
|
||||
] = f"{current_node_out_bit_width} bits is not supported for the time being"
|
||||
offending_nodes[node] = [
|
||||
f"{current_node_out_bit_width} bits is not supported for the time being"
|
||||
]
|
||||
|
||||
if len(offending_nodes) != 0:
|
||||
raise RuntimeError(
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
@@ -54,7 +54,7 @@ return(%2)
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"}
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
@@ -72,6 +72,7 @@ return(%2)
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
@@ -86,6 +87,7 @@ return(%2)
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
|
||||
Reference in New Issue
Block a user