feat: easily print and draw NPFHECompiler underlying OPGraph

closes #1075
This commit is contained in:
Arthur Meyre
2021-12-06 09:54:14 +01:00
parent 5aad8c50ac
commit a9a8cdb223
3 changed files with 99 additions and 4 deletions

View File

@@ -2,10 +2,14 @@
from copy import deepcopy
from enum import Enum, unique
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
from loguru import logger
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
from ..common.data_types import Integer
from ..common.debugging import draw_graph, format_operation_graph
from ..common.fhe_circuit import FHECircuit
from ..common.operator_graph import OPGraph
from ..common.representation.intermediate import IntermediateNode
@@ -127,6 +131,45 @@ class NPFHECompiler:
assert self._op_graph is not None
return self._op_graph(*args)
def __str__(self) -> str:
self._eval_on_current_inputset()
if self._op_graph is None:
warning_msg = (
f"__str__ failed: OPGraph is None, {self.__class__.__name__} "
"needs evaluation on an inputset"
)
logger.warning(warning_msg)
return warning_msg
return format_operation_graph(self._op_graph)
def draw_graph(
self,
show: bool = False,
vertical: bool = True,
save_to: Optional[Path] = None,
) -> Optional[str]:
"""Draws operation graphs and optionally saves/shows the drawing.
Args:
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
show (bool): if set to True, the drawing will be shown using matplotlib
vertical (bool): if set to True, the orientation will be vertical
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
it is saved in a temporary file
Returns:
Optional[str]: if OPGraph was not None returns the path as a string of the file where
the drawn graph is saved
"""
self._eval_on_current_inputset()
if self._op_graph is None:
logger.warning(
f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} "
"needs evaluation on an inputset"
)
return None
return draw_graph(self._op_graph, show, vertical, save_to)
def eval_on_inputset(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]) -> None:
"""Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds.

View File

@@ -1,11 +1,13 @@
"""Test file for drawing"""
import filecmp
import tempfile
from pathlib import Path
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy import NPFHECompiler
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
@@ -22,8 +24,23 @@ def test_draw_graph_with_saving(default_compilation_configuration):
default_compilation_configuration,
)
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
assert (got := compiler.draw_graph()) is None, got
compiler.eval_on_inputset(range(-5, 5))
with tempfile.TemporaryDirectory() as tmp:
output_directory = Path(tmp)
output_file = output_directory.joinpath("test.png")
draw_graph(op_graph, save_to=output_file)
assert output_file.exists()
output_file_compiler = output_directory.joinpath("test_compiler.png")
compiler_output_file = compiler.draw_graph(save_to=output_file_compiler)
assert compiler_output_file is not None
compiler_output_file = Path(compiler_output_file)
assert compiler_output_file == output_file_compiler
assert compiler_output_file.exists()
assert filecmp.cmp(output_file, compiler_output_file)

View File

@@ -5,6 +5,7 @@ import numpy
from concrete.common.data_types.integers import Integer, UnsignedInteger
from concrete.common.debugging import format_operation_graph
from concrete.common.values import EncryptedScalar
from concrete.numpy import NPFHECompiler
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph_and_measure_bounds,
@@ -98,9 +99,8 @@ def test_format_operation_graph_with_fusing(default_compilation_configuration):
default_compilation_configuration,
)
assert (
str(circuit)
== """
assert (got := str(circuit)) == (
"""
%0 = x # EncryptedScalar<uint5>
%1 = 1 # ClearScalar<uint6>
@@ -122,4 +122,39 @@ Subgraphs:
return %6
""".strip()
), str(circuit)
), got
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
assert (
got := str(compiler)
) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got
compiler.eval_on_inputset(range(2 ** 3))
# String is different here as the type that is first propagated to trace the opgraph is not the
# same
assert (got := str(compiler)) == (
"""
%0 = x # EncryptedScalar<uint3>
%1 = 1 # ClearScalar<uint1>
%2 = add(%0, %1) # EncryptedScalar<uint4>
%3 = subgraph(%2) # EncryptedScalar<uint5>
return %3
Subgraphs:
%3 = subgraph(%2):
%0 = 10 # ClearScalar<uint4>
%1 = 1 # ClearScalar<uint1>
%2 = float_subgraph_input # EncryptedScalar<uint1>
%3 = cos(%2) # EncryptedScalar<float64>
%4 = add(%3, %1) # EncryptedScalar<float64>
%5 = mul(%4, %0) # EncryptedScalar<float64>
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
return %6
""".strip()
), got