From f417246ea39be9bc59a4d646cf444dfd6452821c Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 8 Nov 2021 12:51:20 +0300 Subject: [PATCH] refactor(debugging): rename get_printable_graph to format_operation_graph --- concrete/common/compilation/artifacts.py | 4 +-- concrete/common/debugging/__init__.py | 2 +- concrete/common/debugging/printing.py | 2 +- concrete/common/fhe_circuit.py | 4 +-- concrete/common/mlir/utils.py | 6 +++-- concrete/common/optimization/topological.py | 4 +-- concrete/numpy/__init__.py | 2 +- concrete/numpy/compile.py | 6 +++-- tests/common/debugging/test_printing.py | 14 +++++----- tests/common/test_fhe_circuit.py | 4 +-- tests/numpy/test_compile.py | 8 +++--- tests/numpy/test_debugging.py | 30 ++++++++++----------- tests/numpy/test_tracing.py | 6 ++--- 13 files changed, 48 insertions(+), 44 deletions(-) diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index 9c7178b00..35e7a8d4c 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -10,7 +10,7 @@ from typing import Any, Callable, Dict, Optional, Union import networkx as nx from PIL import Image -from ..debugging import assert_true, draw_graph, get_printable_graph +from ..debugging import assert_true, draw_graph, format_operation_graph from ..operator_graph import OPGraph from ..representation.intermediate import IntermediateNode from ..values import BaseValue @@ -85,7 +85,7 @@ class CompilationArtifacts: """ drawing = draw_graph(operation_graph) - textual_representation = get_printable_graph(operation_graph, show_data_types=True) + textual_representation = format_operation_graph(operation_graph, show_data_types=True) self.drawings_of_operation_graphs[name] = drawing self.textual_representations_of_operation_graphs[name] = textual_representation diff --git a/concrete/common/debugging/__init__.py b/concrete/common/debugging/__init__.py index 811bf62da..2532226ab 100644 --- a/concrete/common/debugging/__init__.py +++ b/concrete/common/debugging/__init__.py @@ -1,4 +1,4 @@ """Module for debugging.""" from .custom_assert import assert_true from .drawing import draw_graph -from .printing import get_printable_graph +from .printing import format_operation_graph diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 33ba7b0c9..795879d1e 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -45,7 +45,7 @@ def shorten_a_constant(constant_data: str): return short_content -def get_printable_graph( +def format_operation_graph( opgraph: OPGraph, show_data_types: bool = False, highlighted_nodes: Optional[Dict[IntermediateNode, List[str]]] = None, diff --git a/concrete/common/fhe_circuit.py b/concrete/common/fhe_circuit.py index 03b6b3623..40f2a5e60 100644 --- a/concrete/common/fhe_circuit.py +++ b/concrete/common/fhe_circuit.py @@ -5,7 +5,7 @@ from typing import List, Optional, Union from zamalang import CompilerEngine -from .debugging import draw_graph, get_printable_graph +from .debugging import draw_graph, format_operation_graph from .operator_graph import OPGraph @@ -20,7 +20,7 @@ class FHECircuit: self.engine = engine def __str__(self): - return get_printable_graph(self.opgraph, show_data_types=True) + return format_operation_graph(self.opgraph, show_data_types=True) def draw( self, diff --git a/concrete/common/mlir/utils.py b/concrete/common/mlir/utils.py index c2b7baebf..6ba3a1070 100644 --- a/concrete/common/mlir/utils.py +++ b/concrete/common/mlir/utils.py @@ -13,7 +13,7 @@ from ..data_types.dtypes_helpers import ( value_is_scalar, value_is_unsigned_integer, ) -from ..debugging import get_printable_graph +from ..debugging import format_operation_graph from ..debugging.custom_assert import assert_not_reached, assert_true from ..operator_graph import OPGraph from ..representation import intermediate @@ -199,7 +199,9 @@ def update_bit_width_for_mlir(op_graph: OPGraph): f"max_bit_width of some nodes is too high for the current version of " f"the compiler (maximum must be {ACCEPTABLE_MAXIMAL_BITWIDTH_FROM_CONCRETE_LIB}) " f"which is not compatible with:\n" - + get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes) + + format_operation_graph( + op_graph, show_data_types=True, highlighted_nodes=offending_nodes + ) ) _set_all_bit_width(op_graph, max_bit_width) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 391d19390..3e284f289 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -9,7 +9,7 @@ from loguru import logger from ..compilation.artifacts import CompilationArtifacts from ..data_types.floats import Float from ..data_types.integers import Integer -from ..debugging import get_printable_graph +from ..debugging import format_operation_graph from ..debugging.custom_assert import assert_true from ..operator_graph import OPGraph from ..representation.intermediate import Constant, GenericFunction, Input, IntermediateNode @@ -131,7 +131,7 @@ def convert_float_subgraph_to_fused_node( float_subgraph = nx.MultiDiGraph(op_graph.graph.subgraph(subgraph_all_nodes)) float_subgraph_as_op_graph = OPGraph.from_graph(float_subgraph, [], [terminal_node]) - printable_graph = get_printable_graph( + printable_graph = format_operation_graph( float_subgraph_as_op_graph, show_data_types=True, highlighted_nodes=node_with_issues_for_fusing, diff --git a/concrete/numpy/__init__.py b/concrete/numpy/__init__.py index 60df0c831..19dc1cb7b 100644 --- a/concrete/numpy/__init__.py +++ b/concrete/numpy/__init__.py @@ -2,7 +2,7 @@ from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Float, Float32, Float64, Integer, SignedInteger, UnsignedInteger -from ..common.debugging import draw_graph, get_printable_graph +from ..common.debugging import draw_graph, format_operation_graph from ..common.extensions.multi_table import MultiLookupTable from ..common.extensions.table import LookupTable from ..common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor, TensorValue diff --git a/concrete/numpy/compile.py b/concrete/numpy/compile.py index ce7736145..e99bdee63 100644 --- a/concrete/numpy/compile.py +++ b/concrete/numpy/compile.py @@ -11,7 +11,7 @@ from ..common.bounds_measurement.inputset_eval import eval_op_graph_bounds_on_in from ..common.common_helpers import check_op_graph_is_integer_program from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer -from ..common.debugging import get_printable_graph +from ..common.debugging import format_operation_graph from ..common.fhe_circuit import FHECircuit from ..common.mlir import V0_OPSET_CONVERSION_FUNCTIONS from ..common.mlir.utils import ( @@ -244,7 +244,9 @@ def prepare_op_graph_for_mlir(op_graph): if offending_nodes is not None: raise RuntimeError( "function you are trying to compile isn't supported for MLIR lowering\n\n" - + get_printable_graph(op_graph, show_data_types=True, highlighted_nodes=offending_nodes) + + format_operation_graph( + op_graph, show_data_types=True, highlighted_nodes=offending_nodes + ) ) # Update bit_width for MLIR diff --git a/tests/common/debugging/test_printing.py b/tests/common/debugging/test_printing.py index 622ef67fd..9e7bc0cae 100644 --- a/tests/common/debugging/test_printing.py +++ b/tests/common/debugging/test_printing.py @@ -1,13 +1,13 @@ """Test file for printing""" from concrete.common.data_types.integers import Integer -from concrete.common.debugging import get_printable_graph +from concrete.common.debugging import format_operation_graph from concrete.common.values import EncryptedScalar from concrete.numpy.compile import compile_numpy_function_into_op_graph -def test_get_printable_graph_with_offending_nodes(default_compilation_configuration): - """Test get_printable_graph with offending nodes""" +def test_format_operation_graph_with_offending_nodes(default_compilation_configuration): + """Test format_operation_graph with offending nodes""" def function(x): return x + 42 @@ -21,10 +21,10 @@ def test_get_printable_graph_with_offending_nodes(default_compilation_configurat highlighted_nodes = {opgraph.input_nodes[0]: ["foo"]} - without_types = get_printable_graph( + without_types = format_operation_graph( opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes ).strip() - with_types = get_printable_graph( + with_types = format_operation_graph( opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes ).strip() @@ -56,10 +56,10 @@ return(%2) highlighted_nodes = {opgraph.input_nodes[0]: ["foo"], opgraph.output_nodes[0]: ["bar", "baz"]} - without_types = get_printable_graph( + without_types = format_operation_graph( opgraph, show_data_types=False, highlighted_nodes=highlighted_nodes ).strip() - with_types = get_printable_graph( + with_types = format_operation_graph( opgraph, show_data_types=True, highlighted_nodes=highlighted_nodes ).strip() diff --git a/tests/common/test_fhe_circuit.py b/tests/common/test_fhe_circuit.py index 72145026b..70315cfda 100644 --- a/tests/common/test_fhe_circuit.py +++ b/tests/common/test_fhe_circuit.py @@ -3,7 +3,7 @@ import filecmp import concrete.numpy as hnp -from concrete.common.debugging import draw_graph, get_printable_graph +from concrete.common.debugging import draw_graph, format_operation_graph def test_circuit_str(default_compilation_configuration): @@ -17,7 +17,7 @@ def test_circuit_str(default_compilation_configuration): inputset = [(i,) for i in range(2 ** 3)] circuit = hnp.compile_numpy_function(f, {"x": x}, inputset, default_compilation_configuration) - assert str(circuit) == get_printable_graph(circuit.opgraph, show_data_types=True) + assert str(circuit) == format_operation_graph(circuit.opgraph, show_data_types=True) def test_circuit_draw(default_compilation_configuration): diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 9e6999df3..1c58d1e54 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -8,7 +8,7 @@ import pytest from concrete.common.compilation import CompilationConfiguration from concrete.common.data_types.integers import Integer, UnsignedInteger -from concrete.common.debugging import draw_graph, get_printable_graph +from concrete.common.debugging import draw_graph, format_operation_graph from concrete.common.extensions.multi_table import MultiLookupTable from concrete.common.extensions.table import LookupTable from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor @@ -490,7 +490,7 @@ def test_compile_function_multiple_outputs( # when we have the converter, we can check the MLIR draw_graph(op_graph, show=False) - str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) print(f"\n{str_of_the_graph}\n") @@ -732,7 +732,7 @@ def test_compile_function_with_direct_tlu(default_compilation_configuration): default_compilation_configuration, ) - str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) print(f"\n{str_of_the_graph}\n") @@ -1134,7 +1134,7 @@ def test_compile_function_with_dot( data_gen(max_for_ij, repeat), default_compilation_configuration, ) - str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" f"==================\nExpected \n{ref_graph_str}" diff --git a/tests/numpy/test_debugging.py b/tests/numpy/test_debugging.py index d7542ca3f..5933971dc 100644 --- a/tests/numpy/test_debugging.py +++ b/tests/numpy/test_debugging.py @@ -4,7 +4,7 @@ import numpy import pytest from concrete.common.data_types.integers import Integer -from concrete.common.debugging import draw_graph, get_printable_graph +from concrete.common.debugging import draw_graph, format_operation_graph from concrete.common.extensions.table import LookupTable from concrete.common.values import ClearScalar, EncryptedScalar, EncryptedTensor from concrete.numpy import tracing @@ -154,13 +154,13 @@ return(%4) ], ) def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y): - "Test get_printable_graph and draw_graph" + "Test format_operation_graph and draw_graph" x, y = x_y graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph) + str_of_the_graph = format_operation_graph(graph) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -185,12 +185,12 @@ def test_print_and_draw_graph(lambda_f, ref_graph_str, x_y): ], ) def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str): - "Test get_printable_graph and draw_graph on graphs with direct table lookup" + "Test format_operation_graph and draw_graph on graphs with direct table lookup" graph = tracing.trace_numpy_function(lambda_f, params) draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph) + str_of_the_graph = format_operation_graph(graph) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -213,12 +213,12 @@ def test_print_and_draw_graph_with_direct_tlu(lambda_f, params, ref_graph_str): ], ) def test_print_and_draw_graph_with_dot(lambda_f, params, ref_graph_str): - "Test get_printable_graph and draw_graph on graphs with dot" + "Test format_operation_graph and draw_graph on graphs with dot" graph = tracing.trace_numpy_function(lambda_f, params) draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph) + str_of_the_graph = format_operation_graph(graph) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -289,12 +289,12 @@ return(%1) ], ) def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_str): - "Test get_printable_graph and draw_graph on graphs with generic function" + "Test format_operation_graph and draw_graph on graphs with generic function" graph = tracing.trace_numpy_function(lambda_f, params) draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph, show_data_types=True) + str_of_the_graph = format_operation_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -343,11 +343,11 @@ def test_print_and_draw_graph_with_generic_function(lambda_f, params, ref_graph_ ], ) def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str): - """Test get_printable_graph with show_data_types""" + """Test format_operation_graph with show_data_types""" x, y = x_y graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) - str_of_the_graph = get_printable_graph(graph, show_data_types=True) + str_of_the_graph = format_operation_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -399,12 +399,12 @@ def test_print_with_show_data_types(lambda_f, x_y, ref_graph_str): ], ) def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_str): - """Test get_printable_graph with show_data_types on graphs with direct table lookup""" + """Test format_operation_graph with show_data_types on graphs with direct table lookup""" graph = tracing.trace_numpy_function(lambda_f, params) draw_graph(graph, show=False) - str_of_the_graph = get_printable_graph(graph, show_data_types=True) + str_of_the_graph = format_operation_graph(graph, show_data_types=True) assert str_of_the_graph == ref_graph_str, ( f"\n==================\nGot \n{str_of_the_graph}" @@ -414,7 +414,7 @@ def test_print_with_show_data_types_with_direct_tlu(lambda_f, params, ref_graph_ def test_numpy_long_constant(): - "Test get_printable_graph with long constant" + "Test format_operation_graph with long constant" def all_explicit_operations(x): intermediate = numpy.add(x, numpy.arange(100).reshape(10, 10)) @@ -440,7 +440,7 @@ def test_numpy_long_constant(): return(%8) """.lstrip() # noqa: E501 - str_of_the_graph = get_printable_graph(op_graph, show_data_types=True) + str_of_the_graph = format_operation_graph(op_graph, show_data_types=True) assert str_of_the_graph == expected, ( f"\n==================\nGot \n{str_of_the_graph}" diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index 60011f1c5..f84df3fc3 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -9,7 +9,7 @@ import pytest from concrete.common.data_types.dtypes_helpers import broadcast_shapes from concrete.common.data_types.floats import Float from concrete.common.data_types.integers import Integer -from concrete.common.debugging import get_printable_graph +from concrete.common.debugging import format_operation_graph from concrete.common.representation import intermediate as ir from concrete.common.values import ClearScalar, ClearTensor, EncryptedScalar, EncryptedTensor from concrete.numpy import tracing @@ -204,7 +204,7 @@ def test_numpy_tracing_tensors(): return(%12) """.lstrip() # noqa: E501 - assert get_printable_graph(op_graph, show_data_types=True) == expected + assert format_operation_graph(op_graph, show_data_types=True) == expected def test_numpy_explicit_tracing_tensors(): @@ -243,7 +243,7 @@ def test_numpy_explicit_tracing_tensors(): return(%12) """.lstrip() # noqa: E501 - assert get_printable_graph(op_graph, show_data_types=True) == expected + assert format_operation_graph(op_graph, show_data_types=True) == expected @pytest.mark.parametrize(