diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 688a4faef..0077ce794 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -1,12 +1,12 @@ """functions to print the different graphs we can generate in the package, eg to debug.""" -from typing import Any, Dict +from typing import Any, Dict, Optional import networkx as nx from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph -from ..representation.intermediate import Constant, Input, UnivariateFunction +from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction def output_data_type_to_string(node): @@ -39,18 +39,26 @@ def shorten_a_constant(constant_data: str): return short_content -def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: +def get_printable_graph( + opgraph: OPGraph, + show_data_types: bool = False, + highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None, +) -> str: """Return a string representing a graph. Args: opgraph (OPGraph): The graph that we want to draw - show_data_types (bool): Whether or not showing data_types of nodes, eg - to see their width + show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width + highlighted_nodes (Optional[Dict[IntermediateNode, str]]): + The dict of nodes which will be highlighted and their corresponding messages Returns: str: a string to print or save in a file """ assert_true(isinstance(opgraph, OPGraph)) + + highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {} + list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values()) graph = opgraph.graph @@ -127,6 +135,10 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: returned_str += f"{new_line}\n" + if node in highlighted_nodes: + message = highlighted_nodes[node] + returned_str += f"{'^' * len(new_line)} {message}\n" + map_table[node] = i i += 1 diff --git a/tests/common/debugging/test_printing.py b/tests/common/debugging/test_printing.py new file mode 100644 index 000000000..d183329ca --- /dev/null +++ b/tests/common/debugging/test_printing.py @@ -0,0 +1,91 @@ +"""Test file for printing""" + +from concrete.common.data_types.integers import Integer +from concrete.common.debugging import get_printable_graph +from concrete.common.values import EncryptedScalar +from concrete.numpy.compile import compile_numpy_function_into_op_graph + + +def test_get_printable_graph_with_offending_nodes(): + """Test get_printable_graph with offending nodes""" + + def function(x): + return x + 42 + + opgraph = compile_numpy_function_into_op_graph( + function, + {"x": EncryptedScalar(Integer(7, True))}, + [(i,) for i in range(-5, 5)], + ) + + highlighted_nodes = {opgraph.input_nodes[0]: "foo"} + + without_types = get_printable_graph( + opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes + ).strip() + with_types = get_printable_graph( + opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes + ).strip() + + assert ( + without_types + == """ + +%0 = x +^^^^^^ foo +%1 = Constant(42) +%2 = Add(%0, %1) +return(%2) + +""".strip() + ) + + assert ( + with_types + == """ + +%0 = x # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo +%1 = Constant(42) # ClearScalar> +%2 = Add(%0, %1) # EncryptedScalar> +return(%2) + +""".strip() + ) + + highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"} + + without_types = get_printable_graph( + opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes + ).strip() + with_types = get_printable_graph( + opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes + ).strip() + + assert ( + without_types + == """ + +%0 = x +^^^^^^ foo +%1 = Constant(42) +%2 = Add(%0, %1) +^^^^^^^^^^^^^^^^ bar +return(%2) + +""".strip() + ) + + assert ( + with_types + == """ + +%0 = x # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo +%1 = Constant(42) # ClearScalar> +%2 = Add(%0, %1) # EncryptedScalar> +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar +return(%2) + +""".strip() + )