diff --git a/concrete/common/compilation/artifacts.py b/concrete/common/compilation/artifacts.py index f35347dd1..eaefd8f3f 100644 --- a/concrete/common/compilation/artifacts.py +++ b/concrete/common/compilation/artifacts.py @@ -175,9 +175,9 @@ class CompilationArtifacts: f.write(f"{name} :: {parameter}\n") drawings = self.drawings_of_operation_graphs.items() - for index, (name, drawing) in enumerate(drawings): + for index, (name, drawing_filename) in enumerate(drawings): identifier = CompilationArtifacts._identifier(index, name) - drawing.save(output_directory.joinpath(f"{identifier}.png")) + shutil.copy(drawing_filename, output_directory.joinpath(f"{identifier}.png")) textual_representations = self.textual_representations_of_operation_graphs.items() for index, (name, representation) in enumerate(textual_representations): diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index bcca75469..305b75ed7 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -43,19 +43,18 @@ def draw_graph( show: bool = False, vertical: bool = True, save_to: Optional[Path] = None, -) -> Image.Image: +) -> str: """Draws operation graphs and optionally saves/shows the drawing. Args: opgraph (OPGraph): the graph to be drawn and optionally saved/shown show (bool): if set to True, the drawing will be shown using matplotlib vertical (bool): if set to True, the orientation will be vertical - save_to (Optional[Path]): if specified, the drawn graph will be saved to this path + save_to (Optional[Path]): if specified, the drawn graph will be saved to this path; else + it is saved in a temporary file Returns: - Pillow Image of the drawn graph. - This is useful because you can use the drawing however you like. - (check https://pillow.readthedocs.io/en/stable/reference/Image.html for further information) + The path of the file where the drawn graph is saved """ @@ -90,19 +89,21 @@ def draw_graph( agraph.layout("dot") if save_to is None: - with tempfile.NamedTemporaryFile(suffix=".png") as tmp: - agraph.draw(tmp.name) - img = Image.open(tmp.name) + with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp: + save_to_str = str(tmp.name) else: - agraph.draw(save_to) - img = Image.open(save_to) + save_to_str = str(save_to) + + agraph.draw(save_to_str) if show: # pragma: no cover # We can't have coverage in this branch as `plt.show()` blocks and waits for user action. plt.close("all") plt.figure() + img = Image.open(save_to_str) plt.imshow(img) + img.close() plt.axis("off") plt.show() - return img + return save_to_str