diff --git a/concrete/numpy/np_fhe_compiler.py b/concrete/numpy/np_fhe_compiler.py index 29f7a0710..cafe12d32 100644 --- a/concrete/numpy/np_fhe_compiler.py +++ b/concrete/numpy/np_fhe_compiler.py @@ -2,10 +2,14 @@ from copy import deepcopy from enum import Enum, unique +from pathlib import Path from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from loguru import logger + from ..common.compilation import CompilationArtifacts, CompilationConfiguration from ..common.data_types import Integer +from ..common.debugging import draw_graph, format_operation_graph from ..common.fhe_circuit import FHECircuit from ..common.operator_graph import OPGraph from ..common.representation.intermediate import IntermediateNode @@ -127,6 +131,45 @@ class NPFHECompiler: assert self._op_graph is not None return self._op_graph(*args) + def __str__(self) -> str: + self._eval_on_current_inputset() + if self._op_graph is None: + warning_msg = ( + f"__str__ failed: OPGraph is None, {self.__class__.__name__} " + "needs evaluation on an inputset" + ) + logger.warning(warning_msg) + return warning_msg + return format_operation_graph(self._op_graph) + + def draw_graph( + self, + show: bool = False, + vertical: bool = True, + save_to: Optional[Path] = None, + ) -> Optional[str]: + """Draws operation graphs and optionally saves/shows the drawing. + + Args: + op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown + show (bool): if set to True, the drawing will be shown using matplotlib + vertical (bool): if set to True, the orientation will be vertical + save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else + it is saved in a temporary file + + Returns: + Optional[str]: if OPGraph was not None returns the path as a string of the file where + the drawn graph is saved + """ + self._eval_on_current_inputset() + if self._op_graph is None: + logger.warning( + f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} " + "needs evaluation on an inputset" + ) + return None + return draw_graph(self._op_graph, show, vertical, save_to) + def eval_on_inputset(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]) -> None: """Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds. diff --git a/tests/common/debugging/test_drawing.py b/tests/common/debugging/test_drawing.py index b9b33920f..34d3e6137 100644 --- a/tests/common/debugging/test_drawing.py +++ b/tests/common/debugging/test_drawing.py @@ -1,11 +1,13 @@ """Test file for drawing""" +import filecmp import tempfile from pathlib import Path from concrete.common.data_types.integers import Integer from concrete.common.debugging import draw_graph from concrete.common.values import EncryptedScalar +from concrete.numpy import NPFHECompiler from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds @@ -22,8 +24,23 @@ def test_draw_graph_with_saving(default_compilation_configuration): default_compilation_configuration, ) + compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration) + + assert (got := compiler.draw_graph()) is None, got + + compiler.eval_on_inputset(range(-5, 5)) + with tempfile.TemporaryDirectory() as tmp: output_directory = Path(tmp) output_file = output_directory.joinpath("test.png") draw_graph(op_graph, save_to=output_file) assert output_file.exists() + + output_file_compiler = output_directory.joinpath("test_compiler.png") + compiler_output_file = compiler.draw_graph(save_to=output_file_compiler) + assert compiler_output_file is not None + compiler_output_file = Path(compiler_output_file) + assert compiler_output_file == output_file_compiler + assert compiler_output_file.exists() + + assert filecmp.cmp(output_file, compiler_output_file) diff --git a/tests/common/debugging/test_formatting.py b/tests/common/debugging/test_formatting.py index 7c47c6ba6..455802c78 100644 --- a/tests/common/debugging/test_formatting.py +++ b/tests/common/debugging/test_formatting.py @@ -5,6 +5,7 @@ import numpy from concrete.common.data_types.integers import Integer, UnsignedInteger from concrete.common.debugging import format_operation_graph from concrete.common.values import EncryptedScalar +from concrete.numpy import NPFHECompiler from concrete.numpy.compile import ( compile_numpy_function, compile_numpy_function_into_op_graph_and_measure_bounds, @@ -98,9 +99,8 @@ def test_format_operation_graph_with_fusing(default_compilation_configuration): default_compilation_configuration, ) - assert ( - str(circuit) - == """ + assert (got := str(circuit)) == ( + """ %0 = x # EncryptedScalar %1 = 1 # ClearScalar @@ -122,4 +122,39 @@ Subgraphs: return %6 """.strip() - ), str(circuit) + ), got + + compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration) + + assert ( + got := str(compiler) + ) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got + + compiler.eval_on_inputset(range(2 ** 3)) + + # String is different here as the type that is first propagated to trace the opgraph is not the + # same + + assert (got := str(compiler)) == ( + """ + +%0 = x # EncryptedScalar +%1 = 1 # ClearScalar +%2 = add(%0, %1) # EncryptedScalar +%3 = subgraph(%2) # EncryptedScalar +return %3 + +Subgraphs: + + %3 = subgraph(%2): + + %0 = 10 # ClearScalar + %1 = 1 # ClearScalar + %2 = float_subgraph_input # EncryptedScalar + %3 = cos(%2) # EncryptedScalar + %4 = add(%3, %1) # EncryptedScalar + %5 = mul(%4, %0) # EncryptedScalar + %6 = astype(%5, dtype=uint32) # EncryptedScalar + return %6 +""".strip() + ), got