mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(debugging): rename get_printable_graph to format_operation_graph
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union
|
||||
import networkx as nx
|
||||
from PIL import Image
|
||||
|
||||
from ..debugging import assert_true, draw_graph, get_printable_graph
|
||||
from ..debugging import assert_true, draw_graph, format_operation_graph
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import IntermediateNode
|
||||
from ..values import BaseValue
|
||||
@@ -85,7 +85,7 @@ class CompilationArtifacts:
|
||||
"""
|
||||
|
||||
drawing = draw_graph(operation_graph)
|
||||
textual_representation = get_printable_graph(operation_graph, show_data_types=True)
|
||||
textual_representation = format_operation_graph(operation_graph, show_data_types=True)
|
||||
|
||||
self.drawings_of_operation_graphs[name] = drawing
|
||||
self.textual_representations_of_operation_graphs[name] = textual_representation
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Module for debugging."""
|
||||
from .custom_assert import assert_true
|
||||
from .drawing import draw_graph
|
||||
from .printing import get_printable_graph
|
||||
from .printing import format_operation_graph
|
||||
|
||||
@@ -45,7 +45,7 @@ def shorten_a_constant(constant_data: str):
|
||||
return short_content
|
||||
|
||||
|
||||
def get_printable_graph(
|
||||
def format_operation_graph(
|
||||
opgraph: OPGraph,
|
||||
show_data_types: bool = False,
|
||||
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
from .debugging import draw_graph, get_printable_graph
|
||||
from .debugging import draw_graph, format_operation_graph
|
||||
from .operator_graph import OPGraph
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class FHECircuit:
|
||||
self.engine = engine
|
||||
|
||||
def __str__(self):
|
||||
return get_printable_graph(self.opgraph, show_data_types=True)
|
||||
return format_operation_graph(self.opgraph, show_data_types=True)
|
||||
|
||||
def draw(
|
||||
self,
|
||||
|
||||
@@ -13,7 +13,7 @@ from ..data_types.dtypes_helpers import (
|
||||
value_is_scalar,
|
||||
value_is_unsigned_integer,
|
||||
)
|
||||
from ..debugging import get_printable_graph
|
||||
from ..debugging import format_operation_graph
|
||||
from ..debugging.custom_assert import assert_not_reached, assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation import intermediate
|
||||
@@ -199,7 +199,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
|
||||
f"max_bit_width of some nodes is too high for the current version of "
|
||||
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
|
||||
f"which is not compatible with:\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
+ format_operation_graph(
|
||||
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
|
||||
)
|
||||
)
|
||||
|
||||
_set_all_bit_width(op_graph, max_bit_width)
|
||||
|
||||
@@ -9,7 +9,7 @@ from loguru import logger
|
||||
from ..compilation.artifacts import CompilationArtifacts
|
||||
from ..data_types.floats import Float
|
||||
from ..data_types.integers import Integer
|
||||
from ..debugging import get_printable_graph
|
||||
from ..debugging import format_operation_graph
|
||||
from ..debugging.custom_assert import assert_true
|
||||
from ..operator_graph import OPGraph
|
||||
from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode
|
||||
@@ -131,7 +131,7 @@ def convert_float_subgraph_to_fused_node(
|
||||
float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
|
||||
float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
|
||||
|
||||
printable_graph = get_printable_graph(
|
||||
printable_graph = format_operation_graph(
|
||||
float_subgraph_as_op_graph,
|
||||
show_data_types=True,
|
||||
highlighted_nodes=node_with_issues_for_fusing,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
|
||||
from ..common.debugging import draw_graph, get_printable_graph
|
||||
from ..common.debugging import draw_graph, format_operation_graph
|
||||
from ..common.extensions.multi_table import MultiLookupTable
|
||||
from ..common.extensions.table import LookupTable
|
||||
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in
|
||||
from ..common.common_helpers import check_op_graph_is_integer_program
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import get_printable_graph
|
||||
from ..common.debugging import format_operation_graph
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
|
||||
from ..common.mlir.utils import (
|
||||
@@ -244,7 +244,9 @@ def prepare_op_graph_for_mlir(op_graph):
|
||||
if offending_nodes is not None:
|
||||
raise RuntimeError(
|
||||
"function you are trying to compile isn't supported for MLIR lowering\n\n"
|
||||
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
|
||||
+ format_operation_graph(
|
||||
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
|
||||
)
|
||||
)
|
||||
|
||||
# Update bit_width for MLIR
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
"""Test file for printing"""
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph
|
||||
|
||||
|
||||
def test_get_printable_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test get_printable_graph with offending nodes"""
|
||||
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
|
||||
"""Test format_operation_graph with offending nodes"""
|
||||
|
||||
def function(x):
|
||||
return x + 42
|
||||
@@ -21,10 +21,10 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = get_printable_graph(
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
@@ -56,10 +56,10 @@ return(%2)
|
||||
|
||||
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
|
||||
|
||||
without_types = get_printable_graph(
|
||||
without_types = format_operation_graph(
|
||||
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
with_types = get_printable_graph(
|
||||
with_types = format_operation_graph(
|
||||
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
|
||||
).strip()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import filecmp
|
||||
|
||||
import concrete.numpy as hnp
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
|
||||
|
||||
def test_circuit_str(default_compilation_configuration):
|
||||
@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
|
||||
inputset = [(i,) for i in range(2 ** 3)]
|
||||
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
|
||||
|
||||
assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True)
|
||||
assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True)
|
||||
|
||||
|
||||
def test_circuit_draw(default_compilation_configuration):
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from concrete.common.compilation import CompilationConfiguration
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.multi_table import MultiLookupTable
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
@@ -490,7 +490,7 @@ def test_compile_function_multiple_outputs(
|
||||
# when we have the converter, we can check the MLIR
|
||||
draw_graph(op_graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -732,7 +732,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
print(f"\n{str_of_the_graph}\n")
|
||||
|
||||
|
||||
@@ -1134,7 +1134,7 @@ def test_compile_function_with_dot(
|
||||
data_gen(max_for_ij, repeat),
|
||||
default_compilation_configuration,
|
||||
)
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
f"==================\nExpected \n{ref_graph_str}"
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph, get_printable_graph
|
||||
from concrete.common.debugging import draw_graph, format_operation_graph
|
||||
from concrete.common.extensions.table import LookupTable
|
||||
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -154,13 +154,13 @@ return(%4)
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
"Test get_printable_graph and draw_graph"
|
||||
"Test format_operation_graph and draw_graph"
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -185,12 +185,12 @@ def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with direct table lookup"
|
||||
"Test format_operation_graph and draw_graph on graphs with direct table lookup"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -213,12 +213,12 @@ def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with dot"
|
||||
"Test format_operation_graph and draw_graph on graphs with dot"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph)
|
||||
str_of_the_graph = format_operation_graph(graph)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -289,12 +289,12 @@ return(%1)
|
||||
],
|
||||
)
|
||||
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
|
||||
"Test get_printable_graph and draw_graph on graphs with generic function"
|
||||
"Test format_operation_graph and draw_graph on graphs with generic function"
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -343,11 +343,11 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
"""Test get_printable_graph with show_data_types"""
|
||||
"""Test format_operation_graph with show_data_types"""
|
||||
x, y = x_y
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -399,12 +399,12 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
|
||||
],
|
||||
)
|
||||
def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
|
||||
"""Test get_printable_graph with show_data_types on graphs with direct table lookup"""
|
||||
"""Test format_operation_graph with show_data_types on graphs with direct table lookup"""
|
||||
graph = tracing.trace_numpy_function(lambda_f, params)
|
||||
|
||||
draw_graph(graph, show=False)
|
||||
|
||||
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == ref_graph_str, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
@@ -414,7 +414,7 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_
|
||||
|
||||
|
||||
def test_numpy_long_constant():
|
||||
"Test get_printable_graph with long constant"
|
||||
"Test format_operation_graph with long constant"
|
||||
|
||||
def all_explicit_operations(x):
|
||||
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
|
||||
@@ -440,7 +440,7 @@ def test_numpy_long_constant():
|
||||
return(%8)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
|
||||
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
|
||||
|
||||
assert str_of_the_graph == expected, (
|
||||
f"\n==================\nGot \n{str_of_the_graph}"
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
|
||||
from concrete.common.data_types.floats import Float
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import get_printable_graph
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.representation import intermediate as ir
|
||||
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
|
||||
from concrete.numpy import tracing
|
||||
@@ -204,7 +204,7 @@ def test_numpy_tracing_tensors():
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
def test_numpy_explicit_tracing_tensors():
|
||||
@@ -243,7 +243,7 @@ def test_numpy_explicit_tracing_tensors():
|
||||
return(%12)
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
assert get_printable_graph(op_graph, show_data_types=True) == expected
|
||||
assert format_operation_graph(op_graph, show_data_types=True) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user