mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
feat: allow multiple graphs with the same name in debug artifacts
This commit is contained in:
@@ -7,7 +7,7 @@ import platform
|
||||
import shutil
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -26,8 +26,8 @@ class DebugArtifacts:
|
||||
source_code: Optional[str]
|
||||
parameter_encryption_statuses: Dict[str, str]
|
||||
|
||||
drawings_of_graphs: Dict[str, str]
|
||||
textual_representations_of_graphs: Dict[str, str]
|
||||
drawings_of_graphs: Dict[str, List[str]]
|
||||
textual_representations_of_graphs: Dict[str, List[str]]
|
||||
|
||||
final_graph: Optional[Graph]
|
||||
bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]]
|
||||
@@ -94,12 +94,18 @@ class DebugArtifacts:
|
||||
a representation of the function being compiled
|
||||
"""
|
||||
|
||||
if name not in self.textual_representations_of_graphs:
|
||||
self.textual_representations_of_graphs[name] = []
|
||||
|
||||
if name not in self.drawings_of_graphs:
|
||||
self.drawings_of_graphs[name] = []
|
||||
|
||||
textual_representation = graph.format()
|
||||
self.textual_representations_of_graphs[name] = textual_representation
|
||||
self.textual_representations_of_graphs[name].append(textual_representation)
|
||||
|
||||
try:
|
||||
drawing = graph.draw()
|
||||
self.drawings_of_graphs[name] = str(drawing)
|
||||
self.drawings_of_graphs[name].append(str(drawing))
|
||||
except ImportError as error: # pragma: no cover
|
||||
if "pygraphviz" in str(error):
|
||||
pass
|
||||
@@ -146,6 +152,8 @@ class DebugArtifacts:
|
||||
Export the collected information to `self.output_directory`.
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
|
||||
output_directory = self.output_directory
|
||||
if output_directory.exists():
|
||||
shutil.rmtree(output_directory)
|
||||
@@ -197,16 +205,24 @@ class DebugArtifacts:
|
||||
for name, parameter in self.parameter_encryption_statuses.items():
|
||||
f.write(f"{name} :: {parameter}\n")
|
||||
|
||||
identifier = 0
|
||||
|
||||
drawings = self.drawings_of_graphs.items()
|
||||
for index, (name, drawing_filename) in enumerate(drawings):
|
||||
identifier = f"{index + 1}.{name}.graph"
|
||||
shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))
|
||||
for name, drawing_filenames in drawings:
|
||||
for drawing_filename in drawing_filenames:
|
||||
identifier += 1
|
||||
output_path = output_directory.joinpath(f"{identifier}.{name}.graph.png")
|
||||
shutil.copy(drawing_filename, output_path)
|
||||
|
||||
identifier = 0
|
||||
|
||||
textual_representations = self.textual_representations_of_graphs.items()
|
||||
for index, (name, representation) in enumerate(textual_representations):
|
||||
identifier = f"{index + 1}.{name}.graph"
|
||||
with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
|
||||
f.write(f"{representation}\n")
|
||||
for name, representations in textual_representations:
|
||||
for representation in representations:
|
||||
identifier += 1
|
||||
output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt")
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
f.write(f"{representation}\n")
|
||||
|
||||
if self.bounds_of_the_final_graph is not None:
|
||||
assert self.final_graph is not None
|
||||
@@ -223,3 +239,5 @@ class DebugArtifacts:
|
||||
if self.client_parameters is not None:
|
||||
with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
|
||||
f.write(self.client_parameters)
|
||||
|
||||
# pylint: enable=too-many-branches
|
||||
|
||||
@@ -5,6 +5,8 @@ Tests of `DebugArtifacts` class.
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from concrete.numpy.compilation import DebugArtifacts, compiler
|
||||
|
||||
|
||||
@@ -21,9 +23,11 @@ def test_artifacts_export(helpers):
|
||||
|
||||
@compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return x + 10
|
||||
a = ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64)
|
||||
b = np.where(x < 5, x * 10, x + 10)
|
||||
return a + b
|
||||
|
||||
inputset = range(100)
|
||||
inputset = range(10)
|
||||
f.compile(inputset, configuration, artifacts)
|
||||
|
||||
artifacts.export()
|
||||
@@ -37,8 +41,14 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "1.initial.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "2.final.graph.txt").exists()
|
||||
assert (tmpdir / "2.final.graph.png").exists()
|
||||
assert (tmpdir / "2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "2.after-fusing.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "3.after-fusing.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "4.final.graph.txt").exists()
|
||||
assert (tmpdir / "4.final.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
@@ -55,8 +65,14 @@ def test_artifacts_export(helpers):
|
||||
assert (tmpdir / "1.initial.graph.txt").exists()
|
||||
assert (tmpdir / "1.initial.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "2.final.graph.txt").exists()
|
||||
assert (tmpdir / "2.final.graph.png").exists()
|
||||
assert (tmpdir / "2.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "2.after-fusing.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "3.after-fusing.graph.txt").exists()
|
||||
assert (tmpdir / "3.after-fusing.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "4.final.graph.txt").exists()
|
||||
assert (tmpdir / "4.final.graph.png").exists()
|
||||
|
||||
assert (tmpdir / "bounds.txt").exists()
|
||||
assert (tmpdir / "mlir.txt").exists()
|
||||
|
||||
@@ -46,7 +46,7 @@ def test_compiler_verbose_trace(helpers, capsys):
|
||||
|
||||
Computation Graph
|
||||
------------------------------------------------
|
||||
{str(list(artifacts.textual_representations_of_graphs.values())[-1])}
|
||||
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
|
||||
------------------------------------------------
|
||||
|
||||
""".strip()
|
||||
@@ -74,7 +74,7 @@ def test_compiler_verbose_compile(helpers, capsys):
|
||||
|
||||
Computation Graph
|
||||
--------------------------------------------------------------------------------
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1]}
|
||||
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
MLIR
|
||||
|
||||
Reference in New Issue
Block a user