mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-10 04:35:03 -05:00
105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
"""functions to draw the different graphs we can generate in the package, eg to debug."""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import matplotlib.pyplot as plt
|
|
import networkx as nx
|
|
from PIL import Image
|
|
|
|
from ..operator_graph import OPGraph
|
|
from ..representation import intermediate as ir
|
|
from ..representation.intermediate import ALL_IR_NODES
|
|
|
|
IR_NODE_COLOR_MAPPING = {
|
|
ir.Input: "blue",
|
|
ir.Constant: "cyan",
|
|
ir.Add: "red",
|
|
ir.Sub: "yellow",
|
|
ir.Mul: "green",
|
|
ir.ArbitraryFunction: "orange",
|
|
ir.Dot: "purple",
|
|
"ArbitraryFunction": "orange",
|
|
"TLU": "grey",
|
|
"output": "magenta",
|
|
}
|
|
|
|
_missing_nodes_in_mapping = ALL_IR_NODES - IR_NODE_COLOR_MAPPING.keys()
|
|
assert len(_missing_nodes_in_mapping) == 0, (
|
|
f"Missing IR node in IR_NODE_COLOR_MAPPING : "
|
|
f"{', '.join(sorted(str(node_type) for node_type in _missing_nodes_in_mapping))}"
|
|
)
|
|
|
|
del _missing_nodes_in_mapping
|
|
|
|
|
|
def draw_graph(
|
|
opgraph: OPGraph,
|
|
show: bool = False,
|
|
vertical: bool = True,
|
|
save_to: Optional[Path] = None,
|
|
) -> Image.Image:
|
|
"""Draws operation graphs and optionally saves/shows the drawing.
|
|
|
|
Args:
|
|
opgraph (OPGraph): the graph to be drawn and optionally saved/shown
|
|
show (bool): if set to True, the drawing will be shown using matplotlib
|
|
vertical (bool): if set to True, the orientation will be vertical
|
|
save_to (Optional[Path]): if specified, the drawn graph will be saved to this path
|
|
|
|
Returns:
|
|
Pillow Image of the drawn graph.
|
|
This is useful because you can use the drawing however you like.
|
|
(check https://pillow.readthedocs.io/en/stable/reference/Image.html for further information)
|
|
|
|
"""
|
|
|
|
def get_color(node, output_nodes):
|
|
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
|
if node in output_nodes:
|
|
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
|
elif isinstance(node, ir.ArbitraryFunction):
|
|
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
|
return value_to_return
|
|
|
|
graph = opgraph.graph
|
|
output_nodes = set(opgraph.output_nodes.values())
|
|
|
|
attributes = {
|
|
node: {
|
|
"label": node.label(),
|
|
"color": get_color(node, output_nodes),
|
|
"penwidth": 2, # double thickness for circles
|
|
"peripheries": 2 if node in output_nodes else 1, # double circle for output nodes
|
|
}
|
|
for node in graph.nodes
|
|
}
|
|
nx.set_node_attributes(graph, attributes)
|
|
|
|
for edge in graph.edges(keys=True):
|
|
idx = graph.edges[edge]["input_idx"]
|
|
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look
|
|
|
|
agraph = nx.nx_agraph.to_agraph(graph)
|
|
agraph.graph_attr["rankdir"] = "TB" if vertical else "LR"
|
|
agraph.layout("dot")
|
|
|
|
if save_to is None:
|
|
with tempfile.NamedTemporaryFile(suffix=".png") as tmp:
|
|
agraph.draw(tmp.name)
|
|
img = Image.open(tmp.name)
|
|
else:
|
|
agraph.draw(save_to)
|
|
img = Image.open(save_to)
|
|
|
|
if show: # pragma: no cover
|
|
# We can't have coverage in this branch as `plt.show()` blocks and waits for user action.
|
|
plt.close("all")
|
|
plt.figure()
|
|
plt.imshow(img)
|
|
plt.axis("off")
|
|
plt.show()
|
|
|
|
return img
|