mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(debugging): rename get_printable_graph to format_operation_graph
This commit is contained in:
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
@@ -490,7 +490,7 @@ def test_compile_function_multiple_outputs(
|
||||
# when we have the converter, we can check the MLIR
|
||||
draw_graph(op_graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -732,7 +732,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -1134,7 +1134,7 @@ def test_compile_function_with_dot(
|
||||
data_gen(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -154,13 +154,13 @@ return(%4)
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
"Test get_printable_graph and draw_graph"
|
||||
"Test format_operation_graph and draw_graph"
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -185,12 +185,12 @@ def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with direct table lookup"
|
||||
"Test format_operation_graph and draw_graph on graphs with direct table lookup"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -213,12 +213,12 @@ def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with dot"
|
||||
"Test format_operation_graph and draw_graph on graphs with dot"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -289,12 +289,12 @@ return(%1)
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with generic function"
|
||||
"Test format_operation_graph and draw_graph on graphs with generic function"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -343,11 +343,11 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"""Test get_printable_graph with show_data_types"""
|
||||
"""Test format_operation_graph with show_data_types"""
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -399,12 +399,12 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"""Test get_printable_graph with show_data_types on graphs with direct table lookup"""
|
||||
"""Test format_operation_graph with show_data_types on graphs with direct table lookup"""
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -414,7 +414,7 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_
|
||||
|
||||
|
||||
def test_numpy_long_constant():
|
||||
"Test get_printable_graph with long constant"
|
||||
"Test format_operation_graph with long constant"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
|
||||
@@ -440,7 +440,7 @@ def test_numpy_long_constant():
|
||||
return(%8)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == expected, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -204,7 +204,7 @@ def test_numpy_tracing_tensors():
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
def test_numpy_explicit_tracing_tensors():
|
||||
@@ -243,7 +243,7 @@ def test_numpy_explicit_tracing_tensors():
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user