diff --git a/concrete/numpy/compilation/artifacts.py b/concrete/numpy/compilation/artifacts.py index b99fab395..6a8dbd703 100644 --- a/concrete/numpy/compilation/artifacts.py +++ b/concrete/numpy/compilation/artifacts.py @@ -7,7 +7,7 @@ import platform import shutil import subprocess from pathlib import Path -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import networkx as nx @@ -26,8 +26,8 @@ class DebugArtifacts: source_code: Optional[str] parameter_encryption_statuses: Dict[str, str] - drawings_of_graphs: Dict[str, str] - textual_representations_of_graphs: Dict[str, str] + drawings_of_graphs: Dict[str, List[str]] + textual_representations_of_graphs: Dict[str, List[str]] final_graph: Optional[Graph] bounds_of_the_final_graph: Optional[Dict[Node, Dict[str, Any]]] @@ -94,12 +94,18 @@ class DebugArtifacts: a representation of the function being compiled """ + if name not in self.textual_representations_of_graphs: + self.textual_representations_of_graphs[name] = [] + + if name not in self.drawings_of_graphs: + self.drawings_of_graphs[name] = [] + textual_representation = graph.format() - self.textual_representations_of_graphs[name] = textual_representation + self.textual_representations_of_graphs[name].append(textual_representation) try: drawing = graph.draw() - self.drawings_of_graphs[name] = str(drawing) + self.drawings_of_graphs[name].append(str(drawing)) except ImportError as error: # pragma: no cover if "pygraphviz" in str(error): pass @@ -146,6 +152,8 @@ class DebugArtifacts: Export the collected information to `self.output_directory`. """ + # pylint: disable=too-many-branches + output_directory = self.output_directory if output_directory.exists(): shutil.rmtree(output_directory) @@ -197,16 +205,24 @@ class DebugArtifacts: for name, parameter in self.parameter_encryption_statuses.items(): f.write(f"{name} :: {parameter}\n") + identifier = 0 + drawings = self.drawings_of_graphs.items() - for index, (name, drawing_filename) in enumerate(drawings): - identifier = f"{index + 1}.{name}.graph" - shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png")) + for name, drawing_filenames in drawings: + for drawing_filename in drawing_filenames: + identifier += 1 + output_path = output_directory.joinpath(f"{identifier}.{name}.graph.png") + shutil.copy(drawing_filename, output_path) + + identifier = 0 textual_representations = self.textual_representations_of_graphs.items() - for index, (name, representation) in enumerate(textual_representations): - identifier = f"{index + 1}.{name}.graph" - with open(output_directory.joinpath(f"{identifier}.txt"), "w", encoding="utf-8") as f: - f.write(f"{representation}\n") + for name, representations in textual_representations: + for representation in representations: + identifier += 1 + output_path = output_directory.joinpath(f"{identifier}.{name}.graph.txt") + with open(output_path, "w", encoding="utf-8") as f: + f.write(f"{representation}\n") if self.bounds_of_the_final_graph is not None: assert self.final_graph is not None @@ -223,3 +239,5 @@ class DebugArtifacts: if self.client_parameters is not None: with open(output_directory.joinpath("client_parameters.json"), "wb") as f: f.write(self.client_parameters) + + # pylint: enable=too-many-branches diff --git a/tests/compilation/test_artifacts.py b/tests/compilation/test_artifacts.py index 710749285..cb72735eb 100644 --- a/tests/compilation/test_artifacts.py +++ b/tests/compilation/test_artifacts.py @@ -5,6 +5,8 @@ Tests of `DebugArtifacts` class. import tempfile from pathlib import Path +import numpy as np + from concrete.numpy.compilation import DebugArtifacts, compiler @@ -21,9 +23,11 @@ def test_artifacts_export(helpers): @compiler({"x": "encrypted"}) def f(x): - return x + 10 + a = ((np.sin(x) ** 2) + (np.cos(x) ** 2)).astype(np.int64) + b = np.where(x < 5, x * 10, x + 10) + return a + b - inputset = range(100) + inputset = range(10) f.compile(inputset, configuration, artifacts) artifacts.export() @@ -37,8 +41,14 @@ def test_artifacts_export(helpers): assert (tmpdir / "1.initial.graph.txt").exists() assert (tmpdir / "1.initial.graph.png").exists() - assert (tmpdir / "2.final.graph.txt").exists() - assert (tmpdir / "2.final.graph.png").exists() + assert (tmpdir / "2.after-fusing.graph.txt").exists() + assert (tmpdir / "2.after-fusing.graph.png").exists() + + assert (tmpdir / "3.after-fusing.graph.txt").exists() + assert (tmpdir / "3.after-fusing.graph.png").exists() + + assert (tmpdir / "4.final.graph.txt").exists() + assert (tmpdir / "4.final.graph.png").exists() assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() @@ -55,8 +65,14 @@ def test_artifacts_export(helpers): assert (tmpdir / "1.initial.graph.txt").exists() assert (tmpdir / "1.initial.graph.png").exists() - assert (tmpdir / "2.final.graph.txt").exists() - assert (tmpdir / "2.final.graph.png").exists() + assert (tmpdir / "2.after-fusing.graph.txt").exists() + assert (tmpdir / "2.after-fusing.graph.png").exists() + + assert (tmpdir / "3.after-fusing.graph.txt").exists() + assert (tmpdir / "3.after-fusing.graph.png").exists() + + assert (tmpdir / "4.final.graph.txt").exists() + assert (tmpdir / "4.final.graph.png").exists() assert (tmpdir / "bounds.txt").exists() assert (tmpdir / "mlir.txt").exists() diff --git a/tests/compilation/test_decorator.py b/tests/compilation/test_decorator.py index dae690a85..8ddda22bf 100644 --- a/tests/compilation/test_decorator.py +++ b/tests/compilation/test_decorator.py @@ -46,7 +46,7 @@ def test_compiler_verbose_trace(helpers, capsys): Computation Graph ------------------------------------------------ -{str(list(artifacts.textual_representations_of_graphs.values())[-1])} +{str(list(artifacts.textual_representations_of_graphs.values())[-1][-1])} ------------------------------------------------ """.strip() @@ -74,7 +74,7 @@ def test_compiler_verbose_compile(helpers, capsys): Computation Graph -------------------------------------------------------------------------------- -{list(artifacts.textual_representations_of_graphs.values())[-1]} +{list(artifacts.textual_representations_of_graphs.values())[-1][-1]} -------------------------------------------------------------------------------- MLIR