refactor(debugging): rename get_printable_graph to format_operation_graph

This commit is contained in:
Umut
2021-11-08 12:51:20 +03:00
parent 6ea46d2cbe
commit f417246ea3
13 changed files with 48 additions and 44 deletions

View File

@@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union
import networkx as nx
from PIL import Image
from ..debugging import assert_true, draw_graph, get_printable_graph
from ..debugging import assert_true, draw_graph, format_operation_graph
from ..operator_graph import OPGraph
from ..representation.intermediate import IntermediateNode
from ..values import BaseValue
@@ -85,7 +85,7 @@ class CompilationArtifacts:
"""
drawing = draw_graph(operation_graph)
textual_representation = get_printable_graph(operation_graph, show_data_types=True)
textual_representation = format_operation_graph(operation_graph, show_data_types=True)
self.drawings_of_operation_graphs[name] = drawing
self.textual_representations_of_operation_graphs[name] = textual_representation

View File

@@ -1,4 +1,4 @@
"""Module for debugging."""
from .custom_assert import assert_true
from .drawing import draw_graph
from .printing import get_printable_graph
from .printing import format_operation_graph

View File

@@ -45,7 +45,7 @@ def shorten_a_constant(constant_data: str):
return short_content
def get_printable_graph(
def format_operation_graph(
opgraph: OPGraph,
show_data_types: bool = False,
highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None,

View File

@@ -5,7 +5,7 @@ from typing import List, Optional, Union
from zamalang import CompilerEngine
from .debugging import draw_graph, get_printable_graph
from .debugging import draw_graph, format_operation_graph
from .operator_graph import OPGraph
@@ -20,7 +20,7 @@ class FHECircuit:
self.engine = engine
def __str__(self):
return get_printable_graph(self.opgraph, show_data_types=True)
return format_operation_graph(self.opgraph, show_data_types=True)
def draw(
self,

View File

@@ -13,7 +13,7 @@ from ..data_types.dtypes_helpers import (
value_is_scalar,
value_is_unsigned_integer,
)
from ..debugging import get_printable_graph
from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_not_reached, assert_true
from ..operator_graph import OPGraph
from ..representation import intermediate
@@ -199,7 +199,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph):
f"max_bit_width of some nodes is too high for the current version of "
f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) "
f"which is not compatible with:\n"
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
+ format_operation_graph(
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
)
)
_set_all_bit_width(op_graph, max_bit_width)

View File

@@ -9,7 +9,7 @@ from loguru import logger
from ..compilation.artifacts import CompilationArtifacts
from ..data_types.floats import Float
from ..data_types.integers import Integer
from ..debugging import get_printable_graph
from ..debugging import format_operation_graph
from ..debugging.custom_assert import assert_true
from ..operator_graph import OPGraph
from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode
@@ -131,7 +131,7 @@ def convert_float_subgraph_to_fused_node(
float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes))
float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node])
printable_graph = get_printable_graph(
printable_graph = format_operation_graph(
float_subgraph_as_op_graph,
show_data_types=True,
highlighted_nodes=node_with_issues_for_fusing,

View File

@@ -2,7 +2,7 @@
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger
from ..common.debugging import draw_graph, get_printable_graph
from ..common.debugging import draw_graph, format_operation_graph
from ..common.extensions.multi_table import MultiLookupTable
from ..common.extensions.table import LookupTable
from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue

View File

@@ -11,7 +11,7 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in
from ..common.common_helpers import check_op_graph_is_integer_program
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import get_printable_graph
from ..common.debugging import format_operation_graph
from ..common.fhe_circuit import FHECircuit
from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS
from ..common.mlir.utils import (
@@ -244,7 +244,9 @@ def prepare_op_graph_for_mlir(op_graph):
if offending_nodes is not None:
raise RuntimeError(
"function you are trying to compile isn't supported for MLIR lowering\n\n"
+ get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes)
+ format_operation_graph(
op_graph, show_data_types=True, highlighted_nodes=offending_nodes
)
)
# Update bit_width for MLIR

View File

@@ -1,13 +1,13 @@
"""Test file for printing"""
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import get_printable_graph
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy.compile import compile_numpy_function_into_op_graph
def test_get_printable_graph_with_offending_nodes(default_compilation_configuration):
"""Test get_printable_graph with offending nodes"""
def test_format_operation_graph_with_offending_nodes(default_compilation_configuration):
"""Test format_operation_graph with offending nodes"""
def function(x):
return x + 42
@@ -21,10 +21,10 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]}
without_types = get_printable_graph(
without_types = format_operation_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
).strip()
with_types = get_printable_graph(
with_types = format_operation_graph(
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
).strip()
@@ -56,10 +56,10 @@ return(%2)
highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]}
without_types = get_printable_graph(
without_types = format_operation_graph(
opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes
).strip()
with_types = get_printable_graph(
with_types = format_operation_graph(
opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes
).strip()

View File

@@ -3,7 +3,7 @@
import filecmp
import concrete.numpy as hnp
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.debugging import draw_graph, format_operation_graph
def test_circuit_str(default_compilation_configuration):
@@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration):
inputset = [(i,) for i in range(2 ** 3)]
circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration)
assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True)
assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True)
def test_circuit_draw(default_compilation_configuration):

View File

@@ -8,7 +8,7 @@ import pytest
from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.debugging import draw_graph, format_operation_graph
from concrete.common.extensions.multi_table import MultiLookupTable
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
@@ -490,7 +490,7 @@ def test_compile_function_multiple_outputs(
# when we have the converter, we can check the MLIR
draw_graph(op_graph, show=False)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")
@@ -732,7 +732,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration):
default_compilation_configuration,
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
print(f"\n{str_of_the_graph}\n")
@@ -1134,7 +1134,7 @@ def test_compile_function_with_dot(
data_gen(max_for_ij, repeat),
default_compilation_configuration,
)
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
f"==================\nExpected \n{ref_graph_str}"

View File

@@ -4,7 +4,7 @@ import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.debugging import draw_graph, format_operation_graph
from concrete.common.extensions.table import LookupTable
from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@@ -154,13 +154,13 @@ return(%4)
],
)
def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
"Test get_printable_graph and draw_graph"
"Test format_operation_graph and draw_graph"
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
str_of_the_graph = format_operation_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -185,12 +185,12 @@ def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y):
],
)
def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
"Test get_printable_graph and draw_graph on graphs with direct table lookup"
"Test format_operation_graph and draw_graph on graphs with direct table lookup"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
str_of_the_graph = format_operation_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -213,12 +213,12 @@ def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str):
],
)
def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str):
"Test get_printable_graph and draw_graph on graphs with dot"
"Test format_operation_graph and draw_graph on graphs with dot"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph)
str_of_the_graph = format_operation_graph(graph)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -289,12 +289,12 @@ return(%1)
],
)
def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str):
"Test get_printable_graph and draw_graph on graphs with generic function"
"Test format_operation_graph and draw_graph on graphs with generic function"
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -343,11 +343,11 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_
],
)
def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
"""Test get_printable_graph with show_data_types"""
"""Test format_operation_graph with show_data_types"""
x, y = x_y
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -399,12 +399,12 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str):
],
)
def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str):
"""Test get_printable_graph with show_data_types on graphs with direct table lookup"""
"""Test format_operation_graph with show_data_types on graphs with direct table lookup"""
graph = tracing.trace_numpy_function(lambda_f, params)
draw_graph(graph, show=False)
str_of_the_graph = get_printable_graph(graph, show_data_types=True)
str_of_the_graph = format_operation_graph(graph, show_data_types=True)
assert str_of_the_graph == ref_graph_str, (
f"\n==================\nGot \n{str_of_the_graph}"
@@ -414,7 +414,7 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_
def test_numpy_long_constant():
"Test get_printable_graph with long constant"
"Test format_operation_graph with long constant"
def all_explicit_operations(x):
intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10))
@@ -440,7 +440,7 @@ def test_numpy_long_constant():
return(%8)
""".lstrip() # noqa: E501
str_of_the_graph = get_printable_graph(op_graph, show_data_types=True)
str_of_the_graph = format_operation_graph(op_graph, show_data_types=True)
assert str_of_the_graph == expected, (
f"\n==================\nGot \n{str_of_the_graph}"

View File

@@ -9,7 +9,7 @@ import pytest
from concrete.common.data_types.dtypes_helpers import broadcast_shapes
from concrete.common.data_types.floats import Float
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import get_printable_graph
from concrete.common.debugging import format_operation_graph
from concrete.common.representation import intermediate as ir
from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy import tracing
@@ -204,7 +204,7 @@ def test_numpy_tracing_tensors():
return(%12)
""".lstrip() # noqa: E501
assert get_printable_graph(op_graph, show_data_types=True) == expected
assert format_operation_graph(op_graph, show_data_types=True) == expected
def test_numpy_explicit_tracing_tensors():
@@ -243,7 +243,7 @@ def test_numpy_explicit_tracing_tensors():
return(%12)
""".lstrip() # noqa: E501
assert get_printable_graph(op_graph, show_data_types=True) == expected
assert format_operation_graph(op_graph, show_data_types=True) == expected
@pytest.mark.parametrize(