feat(debugging): provide a way for highlighting nodes with custom messages during printing

This commit is contained in:
Umut
2021-10-15 11:36:31 +03:00
parent 753ab5b6a2
commit 4c9c49ecd2
2 changed files with 108 additions and 5 deletions

View File

@@ -1,12 +1,12 @@
"""functions to print the different graphs we can generate in the package, eg to debug."""
from typing import Any, Dict
from typing import Any, Dict, Optional
import networkx as nx
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, Input, UnivariateFunction
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
def output_data_type_to_string(node):
@@ -39,18 +39,26 @@ def shorten_a_constant(constant_data: str):
return short_content
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
def get_printable_graph(
opgraph: OPGraph,
show_data_types: bool = False,
highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None,
) -> str:
"""Return a string representing a graph.
Args:
opgraph (OPGraph): The graph that we want to draw
show_data_types (bool): Whether or not showing data_types of nodes, eg
to see their width
show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width
highlighted_nodes (Optional[Dict[IntermediateNode, str]]):
The dict of nodes which will be highlighted and their corresponding messages
Returns:
str: a string to print or save in a file
"""
assert_true(isinstance(opgraph, OPGraph))
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
graph = opgraph.graph
@@ -127,6 +135,10 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
returned_str += f"{new_line}\n"
if node in highlighted_nodes:
message = highlighted_nodes[node]
returned_str += f"{'^' * len(new_line)} {message}\n"
map_table[node] = i
i += 1

View File

@@ -0,0 +1,91 @@
"""Test file for printing"""
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import get_printable_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_get_printable_graph_with_offending_nodes():
"""Test get_printable_graph with offending nodes"""
def function(x):
return x + 42
opgraph = compile_numpy_function_into_op_graph(
function,
{"x": EncryptedScalar(Integer(7, True))},
[(i,) for i in range(-5, 5)],
)
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}
without_types = get_printable_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
).strip()
with_types = get_printable_graph(
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
).strip()
assert (
without_types
== """
%0 = x
^^^^^^ foo
%1 = Constant(42)
%2 = Add(%0, %1)
return(%2)
""".strip()
)
assert (
with_types
== """
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
return(%2)
""".strip()
)
highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"}
without_types = get_printable_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
).strip()
with_types = get_printable_graph(
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
).strip()
assert (
without_types
== """
%0 = x
^^^^^^ foo
%1 = Constant(42)
%2 = Add(%0, %1)
^^^^^^^^^^^^^^^^ bar
return(%2)
""".strip()
)
assert (
with_types
== """
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
return(%2)
""".strip()
)