mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: easily print and draw NPFHECompiler underlying OPGraph
closes #1075
This commit is contained in:
@@ -2,10 +2,14 @@
|
||||
|
||||
from copy import deepcopy
|
||||
from enum import Enum, unique
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from ..common.compilation import CompilationArtifacts, CompilationConfiguration
|
||||
from ..common.data_types import Integer
|
||||
from ..common.debugging import draw_graph, format_operation_graph
|
||||
from ..common.fhe_circuit import FHECircuit
|
||||
from ..common.operator_graph import OPGraph
|
||||
from ..common.representation.intermediate import IntermediateNode
|
||||
@@ -127,6 +131,45 @@ class NPFHECompiler:
|
||||
assert self._op_graph is not None
|
||||
return self._op_graph(*args)
|
||||
|
||||
def __str__(self) -> str:
|
||||
self._eval_on_current_inputset()
|
||||
if self._op_graph is None:
|
||||
warning_msg = (
|
||||
f"__str__ failed: OPGraph is None, {self.__class__.__name__} "
|
||||
"needs evaluation on an inputset"
|
||||
)
|
||||
logger.warning(warning_msg)
|
||||
return warning_msg
|
||||
return format_operation_graph(self._op_graph)
|
||||
|
||||
def draw_graph(
|
||||
self,
|
||||
show: bool = False,
|
||||
vertical: bool = True,
|
||||
save_to: Optional[Path] = None,
|
||||
) -> Optional[str]:
|
||||
"""Draws operation graphs and optionally saves/shows the drawing.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): the operation graph to be drawn and optionally saved/shown
|
||||
show (bool): if set to True, the drawing will be shown using matplotlib
|
||||
vertical (bool): if set to True, the orientation will be vertical
|
||||
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else
|
||||
it is saved in a temporary file
|
||||
|
||||
Returns:
|
||||
Optional[str]: if OPGraph was not None returns the path as a string of the file where
|
||||
the drawn graph is saved
|
||||
"""
|
||||
self._eval_on_current_inputset()
|
||||
if self._op_graph is None:
|
||||
logger.warning(
|
||||
f"{self.draw_graph.__name__} failed: OPGraph is None, {self.__class__.__name__} "
|
||||
"needs evaluation on an inputset"
|
||||
)
|
||||
return None
|
||||
return draw_graph(self._op_graph, show, vertical, save_to)
|
||||
|
||||
def eval_on_inputset(self, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]) -> None:
|
||||
"""Evaluate the underlying function on an inputset in one go, populates OPGraph and bounds.
|
||||
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
"""Test file for drawing"""
|
||||
|
||||
import filecmp
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import draw_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy import NPFHECompiler
|
||||
from concrete.numpy.compile import compile_numpy_function_into_op_graph_and_measure_bounds
|
||||
|
||||
|
||||
@@ -22,8 +24,23 @@ def test_draw_graph_with_saving(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
|
||||
|
||||
assert (got := compiler.draw_graph()) is None, got
|
||||
|
||||
compiler.eval_on_inputset(range(-5, 5))
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
output_directory = Path(tmp)
|
||||
output_file = output_directory.joinpath("test.png")
|
||||
draw_graph(op_graph, save_to=output_file)
|
||||
assert output_file.exists()
|
||||
|
||||
output_file_compiler = output_directory.joinpath("test_compiler.png")
|
||||
compiler_output_file = compiler.draw_graph(save_to=output_file_compiler)
|
||||
assert compiler_output_file is not None
|
||||
compiler_output_file = Path(compiler_output_file)
|
||||
assert compiler_output_file == output_file_compiler
|
||||
assert compiler_output_file.exists()
|
||||
|
||||
assert filecmp.cmp(output_file, compiler_output_file)
|
||||
|
||||
@@ -5,6 +5,7 @@ import numpy
|
||||
from concrete.common.data_types.integers import Integer, UnsignedInteger
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.values import EncryptedScalar
|
||||
from concrete.numpy import NPFHECompiler
|
||||
from concrete.numpy.compile import (
|
||||
compile_numpy_function,
|
||||
compile_numpy_function_into_op_graph_and_measure_bounds,
|
||||
@@ -98,9 +99,8 @@ def test_format_operation_graph_with_fusing(default_compilation_configuration):
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(circuit)
|
||||
== """
|
||||
assert (got := str(circuit)) == (
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint5>
|
||||
%1 = 1 # ClearScalar<uint6>
|
||||
@@ -122,4 +122,39 @@ Subgraphs:
|
||||
return %6
|
||||
|
||||
""".strip()
|
||||
), str(circuit)
|
||||
), got
|
||||
|
||||
compiler = NPFHECompiler(function, {"x": "encrypted"}, default_compilation_configuration)
|
||||
|
||||
assert (
|
||||
got := str(compiler)
|
||||
) == "__str__ failed: OPGraph is None, NPFHECompiler needs evaluation on an inputset", got
|
||||
|
||||
compiler.eval_on_inputset(range(2 ** 3))
|
||||
|
||||
# String is different here as the type that is first propagated to trace the opgraph is not the
|
||||
# same
|
||||
|
||||
assert (got := str(compiler)) == (
|
||||
"""
|
||||
|
||||
%0 = x # EncryptedScalar<uint3>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = add(%0, %1) # EncryptedScalar<uint4>
|
||||
%3 = subgraph(%2) # EncryptedScalar<uint5>
|
||||
return %3
|
||||
|
||||
Subgraphs:
|
||||
|
||||
%3 = subgraph(%2):
|
||||
|
||||
%0 = 10 # ClearScalar<uint4>
|
||||
%1 = 1 # ClearScalar<uint1>
|
||||
%2 = float_subgraph_input # EncryptedScalar<uint1>
|
||||
%3 = cos(%2) # EncryptedScalar<float64>
|
||||
%4 = add(%3, %1) # EncryptedScalar<float64>
|
||||
%5 = mul(%4, %0) # EncryptedScalar<float64>
|
||||
%6 = astype(%5, dtype=uint32) # EncryptedScalar<uint5>
|
||||
return %6
|
||||
""".strip()
|
||||
), got
|
||||
|
||||
Reference in New Issue
Block a user