refactor(debugging): accept several highlights per node when printing

refs #645
This commit is contained in:
Arthur Meyre
2021-10-26 15:06:08 +02:00
parent 70fbac7188
commit ecfde7b233
3 changed files with 19 additions and 15 deletions

View File

@@ -1,6 +1,6 @@
"""functions to print the different graphs we can generate in the package, eg to debug."""
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
import networkx as nx
@@ -48,15 +48,16 @@ def shorten_a_constant(constant_data: str):
def get_printable_graph(
opgraph: OPGraph,
show_data_types: bool = False,
highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None,
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
) -> str:
"""Return a string representing a graph.
Args:
opgraph (OPGraph): The graph that we want to draw
show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width
highlighted_nodes (Optional[Dict[IntermediateNode, str]]):
The dict of nodes which will be highlighted and their corresponding messages
show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see
their width. Defaults to False.
highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes
which will be highlighted and their corresponding messages. Defaults to None.
Returns:
str: a string to print or save in a file
@@ -145,8 +146,9 @@ def get_printable_graph(
returned_str += f"{new_line}\n"
if node in highlighted_nodes:
message = highlighted_nodes[node]
returned_str += f"{'^' * len(new_line)} {message}\n"
new_line_len = len(new_line)
message = f"\n{' ' * new_line_len} ".join(highlighted_nodes[node])
returned_str += f"{'^' * new_line_len} {message}\n"
map_table[node] = i
i += 1

View File

@@ -1,5 +1,5 @@
"""Utilities for MLIR conversion."""
from typing import Dict, Optional, cast
from typing import Dict, List, Optional, cast
from ..data_types import Integer
from ..data_types.dtypes_helpers import (
@@ -99,7 +99,7 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool)
def check_graph_values_compatibility_with_mlir(
op_graph: OPGraph,
) -> Optional[Dict[IntermediateNode, str]]:
) -> Optional[Dict[IntermediateNode, List[str]]]:
"""Make sure the graph outputs are unsigned integers, which is what the compiler supports.
Args:
@@ -115,7 +115,7 @@ def check_graph_values_compatibility_with_mlir(
for node in op_graph.graph.nodes:
is_output = node in op_graph.output_nodes.values()
if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None:
offending_nodes[node] = reason
offending_nodes[node] = [reason]
return None if len(offending_nodes) == 0 else offending_nodes
@@ -162,9 +162,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
# Check that current_node_out_bit_width is supported by the compiler
if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB:
offending_nodes[
node
] = f"{current_node_out_bit_width} bits is not supported for the time being"
offending_nodes[node] = [
f"{current_node_out_bit_width} bits is not supported for the time being"
]
if len(offending_nodes) != 0:
raise RuntimeError(

View File

@@ -19,7 +19,7 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat
default_compilation_configuration,
)
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
without_types = get_printable_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
@@ -54,7 +54,7 @@ return(%2)
""".strip()
)
highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"}
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
without_types = get_printable_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
@@ -72,6 +72,7 @@ return(%2)
%1 = Constant(42)
%2 = Add(%0, %1)
^^^^^^^^^^^^^^^^ bar
baz
return(%2)
""".strip()
@@ -86,6 +87,7 @@ return(%2)
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
baz
return(%2)
""".strip()