mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(debugging): re-write graph formatting
This commit is contained in:
78
tests/common/debugging/test_formatting.py
Normal file
78
tests/common/debugging/test_formatting.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Test file for formatting"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_format_operation_graph_with_multiple_edges(default_compilation_configuration):
|
||||
"""Test format_operation_graph with multiple edges"""
|
||||
|
||||
def function(x):
|
||||
return x + x
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(4, True))},
|
||||
[(i,) for i in range(0, 10)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
formatted_graph = format_operation_graph(opgraph)
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<uint4>
|
||||
%1 = add(%0, %0) # EncryptedScalar<uint5>
|
||||
return %1
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
formatted_graph = format_operation_graph(opgraph, highlighted_nodes=highlighted_nodes).strip()
|
||||
assert (
|
||||
formatted_graph
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<int4>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = 42 # ClearScalar<uint6>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint6>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return %2
|
||||
|
||||
""".strip()
|
||||
)
|
||||
@@ -1,94 +0,0 @@
|
||||
"""Test file for printing"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
|
||||
opgraph = compile_numpy_function_into_op_graph(
|
||||
function,
|
||||
{"x": EncryptedScalar(Integer(7, True))},
|
||||
[(i,) for i in range(-5, 5)],
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
assert (
|
||||
without_types
|
||||
== """
|
||||
|
||||
%0 = x
|
||||
^^^^^^ foo
|
||||
%1 = Constant(42)
|
||||
%2 = Add(%0, %1)
|
||||
^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
|
||||
assert (
|
||||
with_types
|
||||
== """
|
||||
|
||||
%0 = x # EncryptedScalar<Integer<signed, 4 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ foo
|
||||
%1 = Constant(42) # ClearScalar<Integer<unsigned, 6 bits>>
|
||||
%2 = Add(%0, %1) # EncryptedScalar<Integer<unsigned, 6 bits>>
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ bar
|
||||
baz
|
||||
return(%2)
|
||||
|
||||
""".strip()
|
||||
)
|
||||
Reference in New Issue
Block a user