mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(debugging): provide a way for highlighting nodes with custom messages during printing
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
"""functions to print the different graphs we can generate in the package, eg to debug."""
|
||||
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Constant, Input, UnivariateFunction
|
||||
from ..representation.intermediate import Constant, Input, IntermediateNode, UnivariateFunction
|
||||
|
||||
|
||||
def output_data_type_to_string(node):
|
||||
@@ -39,18 +39,26 @@ def shorten_a_constant(constant_data: str):
|
||||
return short_content
|
||||
|
||||
|
||||
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
def get_printable_graph(
|
||||
opgraph: OPGraph,
|
||||
show_data_types: bool = False,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, str]] = None,
|
||||
) -> str:
|
||||
"""Return a string representing a graph.
|
||||
|
||||
Args:
|
||||
opgraph (OPGraph): The graph that we want to draw
|
||||
show_data_types (bool): Whether or not showing data_types of nodes, eg
|
||||
to see their width
|
||||
show_data_types (bool): Whether or not showing data_types of nodes, eg to see their width
|
||||
highlighted_nodes (Optional[Dict[IntermediateNode, str]]):
|
||||
The dict of nodes which will be highlighted and their corresponding messages
|
||||
|
||||
Returns:
|
||||
str: a string to print or save in a file
|
||||
"""
|
||||
assert_true(isinstance(opgraph, OPGraph))
|
||||
|
||||
highlighted_nodes = highlighted_nodes if highlighted_nodes is not None else {}
|
||||
|
||||
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
|
||||
graph = opgraph.graph
|
||||
|
||||
@@ -127,6 +135,10 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
|
||||
returned_str += f"{new_line}\n"
|
||||
|
||||
if node in highlighted_nodes:
|
||||
message = highlighted_nodes[node]
|
||||
returned_str += f"{'^' * len(new_line)} {message}\n"
|
||||
|
||||
map_table[node] = i
|
||||
i += 1
|
||||
|
||||
|
||||
91
tests/common/debugging/test_printing.py
Normal file
91
tests/common/debugging/test_printing.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""Test file for printing"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_get_printable_graph_with_offending_nodes():
|
||||
"""Test get_printable_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: "foo"}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = get_printable_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: "foo", opgraph.output_nodes[0]: "bar"}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = get_printable_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
^^^^^^^^^^^^^^^^ bar
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 7 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
Reference in New Issue
Block a user