From ecfde7b233638db500bce4c0ccc8f1a1a1a6f5d1 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 26 Oct 2021 15:06:08 +0200 Subject: [PATCH] refactor(debugging): accept several highlights per node when printing refs #645 --- concrete/common/debugging/printing.py | 16 +++++++++------- concrete/common/mlir/utils.py | 12 ++++++------ tests/common/debugging/test_printing.py | 6 ++++-- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index cfca8696b..13c334dcf 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -1,6 +1,6 @@ """functions to print the different graphs we can generate in the package, eg to debug.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import networkx as nx @@ -48,15 +48,16 @@ def shorten_a_constant(constant_data: str): def get_printable_graph( opgraph: OPGraph, show_data_types: bool = False, - highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None, + highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, ) -> str: """Return a string representing a graph. Args: opgraph (OPGraph): The graph that we want to draw - show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width - highlighted_nodes (Optional[Dict[IntermediateNode, str]]): - The dict of nodes which will be highlighted and their corresponding messages + show_data_types (bool, optional): Whether or not showing data_types of nodes, eg to see + their width. Defaults to False. + highlighted_nodes (Optional[Dict[IntermediateNode, List[str]]], optional): The dict of nodes + which will be highlighted and their corresponding messages. Defaults to None. Returns: str: a string to print or save in a file @@ -145,8 +146,9 @@ def get_printable_graph( returned_str += f"{new_line}\n" if node in highlighted_nodes: - message = highlighted_nodes[node] - returned_str += f"{'^' * len(new_line)} {message}\n" + new_line_len = len(new_line) + message = f"\n{' ' * new_line_len} ".join(highlighted_nodes[node]) + returned_str += f"{'^' * new_line_len} {message}\n" map_table[node] = i i += 1 diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index b46d9e355..39405af8e 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -1,5 +1,5 @@ """Utilities for MLIR conversion.""" -from typing import Dict, Optional, cast +from typing import Dict, List, Optional, cast from ..data_types import Integer from ..data_types.dtypes_helpers import ( @@ -99,7 +99,7 @@ def check_node_compatibility_with_mlir(node: IntermediateNode, is_output: bool) def check_graph_values_compatibility_with_mlir( op_graph: OPGraph, -) -> Optional[Dict[IntermediateNode, str]]: +) -> Optional[Dict[IntermediateNode, List[str]]]: """Make sure the graph outputs are unsigned integers, which is what the compiler supports. Args: @@ -115,7 +115,7 @@ def check_graph_values_compatibility_with_mlir( for node in op_graph.graph.nodes: is_output = node in op_graph.output_nodes.values() if (reason := check_node_compatibility_with_mlir(node, is_output)) is not None: - offending_nodes[node] = reason + offending_nodes[node] = [reason] return None if len(offending_nodes) == 0 else offending_nodes @@ -162,9 +162,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph): # Check that current_node_out_bit_width is supported by the compiler if current_node_out_bit_width > ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB: - offending_nodes[ - node - ] = f"{current_node_out_bit_width} bits is not supported for the time being" + offending_nodes[node] = [ + f"{current_node_out_bit_width} bits is not supported for the time being" + ] if len(offending_nodes) != 0: raise RuntimeError( diff --git a/tests/common/debugging/test_printing.py b/tests/common/debugging/test_printing.py index c191d7052..622ef67fd 100644 --- a/tests/common/debugging/test_printing.py +++ b/tests/common/debugging/test_printing.py @@ -19,7 +19,7 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat default_compilation_configuration, ) - highlighted_nodes = {opgraph.input_nodes[0]: "foo"} + highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]} without_types = get_printable_graph( opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes @@ -54,7 +54,7 @@ return(%2) """.strip() ) - highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"} + highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]} without_types = get_printable_graph( opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes @@ -72,6 +72,7 @@ return(%2) %1 = Constant(42) %2 = Add(%0, %1) ^^^^^^^^^^^^^^^^ bar + baz return(%2) """.strip() @@ -86,6 +87,7 @@ return(%2) %1 = Constant(42) # ClearScalar> %2 = Add(%0, %1) # EncryptedScalar> ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar + baz return(%2) """.strip()