feat: allow multiple graphs with the same name in debug artifacts

This commit is contained in:
Umut
2022-06-10 12:20:39 +02:00
parent 68e9ada9bf
commit 53e5dda732
3 changed files with 54 additions and 20 deletions

View File

@@ -7,7 +7,7 @@ import platform
import shutil
import subprocess
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import networkx as nx
@@ -26,8 +26,8 @@ class DebugArtifacts:
source_code: Optional[str]
parameter_encryption_statuses: Dict[str, str]
drawings_of_graphs: Dict[str, str]
textual_representations_of_graphs: Dict[str, str]
drawings_of_graphs: Dict[str, List[str]]
textual_representations_of_graphs: Dict[str, List[str]]
final_graph: Optional[Graph]
bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]]
@@ -94,12 +94,18 @@ class DebugArtifacts:
a representation of the function being compiled
"""
if name not in self.textual_representations_of_graphs:
self.textual_representations_of_graphs[name] = []
if name not in self.drawings_of_graphs:
self.drawings_of_graphs[name] = []
textual_representation = graph.format()
self.textual_representations_of_graphs[name] = textual_representation
self.textual_representations_of_graphs[name].append(textual_representation)
try:
drawing = graph.draw()
self.drawings_of_graphs[name] = str(drawing)
self.drawings_of_graphs[name].append(str(drawing))
except ImportError as error: # pragma: no cover
if "pygraphviz" in str(error):
pass
@@ -146,6 +152,8 @@ class DebugArtifacts:
Export the collected information to `self.output_directory`.
"""
# pylint: disable=too-many-branches
output_directory = self.output_directory
if output_directory.exists():
shutil.rmtree(output_directory)
@@ -197,16 +205,24 @@ class DebugArtifacts:
for name, parameter in self.parameter_encryption_statuses.items():
f.write(f"{name} :: {parameter}\n")
identifier = 0
drawings = self.drawings_of_graphs.items()
for index, (name, drawing_filename) in enumerate(drawings):
identifier = f"{index + 1}.{name}.graph"
shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png"))
for name, drawing_filenames in drawings:
for drawing_filename in drawing_filenames:
identifier += 1
output_path = output_directory.joinpath(f"{identifier}.{name}.graph.png")
shutil.copy(drawing_filename, output_path)
identifier = 0
textual_representations = self.textual_representations_of_graphs.items()
for index, (name, representation) in enumerate(textual_representations):
identifier = f"{index + 1}.{name}.graph"
with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f:
f.write(f"{representation}\n")
for name, representations in textual_representations:
for representation in representations:
identifier += 1
output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt")
with open(output_path, "w", encoding="utf-8") as f:
f.write(f"{representation}\n")
if self.bounds_of_the_final_graph is not None:
assert self.final_graph is not None
@@ -223,3 +239,5 @@ class DebugArtifacts:
if self.client_parameters is not None:
with open(output_directory.joinpath("client_parameters.json"), "wb") as f:
f.write(self.client_parameters)
# pylint: enable=too-many-branches

View File

@@ -5,6 +5,8 @@ Tests of `DebugArtifacts` class.
import tempfile
from pathlib import Path
import numpy as np
from concrete.numpy.compilation import DebugArtifacts, compiler
@@ -21,9 +23,11 @@ def test_artifacts_export(helpers):
@compiler({"x": "encrypted"})
def f(x):
return x + 10
a = ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64)
b = np.where(x < 5, x * 10, x + 10)
return a + b
inputset = range(100)
inputset = range(10)
f.compile(inputset, configuration, artifacts)
artifacts.export()
@@ -37,8 +41,14 @@ def test_artifacts_export(helpers):
assert (tmpdir / "1.initial.graph.txt").exists()
assert (tmpdir / "1.initial.graph.png").exists()
assert (tmpdir / "2.final.graph.txt").exists()
assert (tmpdir / "2.final.graph.png").exists()
assert (tmpdir / "2.after-fusing.graph.txt").exists()
assert (tmpdir / "2.after-fusing.graph.png").exists()
assert (tmpdir / "3.after-fusing.graph.txt").exists()
assert (tmpdir / "3.after-fusing.graph.png").exists()
assert (tmpdir / "4.final.graph.txt").exists()
assert (tmpdir / "4.final.graph.png").exists()
assert (tmpdir / "bounds.txt").exists()
assert (tmpdir / "mlir.txt").exists()
@@ -55,8 +65,14 @@ def test_artifacts_export(helpers):
assert (tmpdir / "1.initial.graph.txt").exists()
assert (tmpdir / "1.initial.graph.png").exists()
assert (tmpdir / "2.final.graph.txt").exists()
assert (tmpdir / "2.final.graph.png").exists()
assert (tmpdir / "2.after-fusing.graph.txt").exists()
assert (tmpdir / "2.after-fusing.graph.png").exists()
assert (tmpdir / "3.after-fusing.graph.txt").exists()
assert (tmpdir / "3.after-fusing.graph.png").exists()
assert (tmpdir / "4.final.graph.txt").exists()
assert (tmpdir / "4.final.graph.png").exists()
assert (tmpdir / "bounds.txt").exists()
assert (tmpdir / "mlir.txt").exists()

View File

@@ -46,7 +46,7 @@ def test_compiler_verbose_trace(helpers, capsys):
Computation Graph
------------------------------------------------
{str(list(artifacts.textual_representations_of_graphs.values())[-1])}
{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])}
------------------------------------------------
""".strip()
@@ -74,7 +74,7 @@ def test_compiler_verbose_compile(helpers, capsys):
Computation Graph
--------------------------------------------------------------------------------
{list(artifacts.textual_representations_of_graphs.values())[-1]}
{list(artifacts.textual_representations_of_graphs.values())[-1][-1]}
--------------------------------------------------------------------------------
MLIR