mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(debugging): implementing draw_graph
draw_graph is the function to show on a plot the traced function-to-compile add edge numbers refs #38
This commit is contained in:
committed by
Arthur Meyre
parent
8925fbd2db
commit
1196b00c6b
@@ -1,2 +1,2 @@
|
||||
"""HDK's module for shared data structures and code"""
|
||||
from . import data_types, representation
|
||||
from . import data_types, debugging, representation
|
||||
|
||||
2
hdk/common/debugging/__init__.py
Normal file
2
hdk/common/debugging/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
"""HDK's module for debugging"""
|
||||
from .draw_graph import draw_graph
|
||||
196
hdk/common/debugging/draw_graph.py
Normal file
196
hdk/common/debugging/draw_graph.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""functions to draw the different graphs we can generate in the package, eg to debug"""
|
||||
from typing import Dict, List
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
|
||||
from hdk.common.representation import intermediate as ir
|
||||
|
||||
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
|
||||
|
||||
|
||||
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
|
||||
"""
|
||||
Returns a pos to be used later with eg nx.draw_networkx_nodes, so that nodes
|
||||
are ordered by depth from input along the x axis and have a uniform
|
||||
distribution along the y axis
|
||||
|
||||
Args:
|
||||
graph (nx.Graph): The graph that we want to draw
|
||||
x_delta (float): Parameter used to set the increment in x
|
||||
y_delta (float): Parameter used to set the increment in y
|
||||
|
||||
Returns:
|
||||
pos (Dict): the argument to use with eg nx.draw_networkx_nodes
|
||||
|
||||
"""
|
||||
|
||||
# FIXME: less variables
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
nodes_depth = {node: 0 for node in graph.nodes()}
|
||||
input_nodes = [node for node in graph.nodes() if len(list(graph.predecessors(node))) == 0]
|
||||
|
||||
# Init a layout so that unreachable nodes have a pos, avoids potential crashes wiht networkx
|
||||
# use a cheap layout
|
||||
pos = nx.random_layout(graph)
|
||||
|
||||
curr_x = 0.0
|
||||
curr_y = -(len(input_nodes) - 1) / 2 * y_delta
|
||||
|
||||
for in_node in input_nodes:
|
||||
pos[in_node] = (curr_x, curr_y)
|
||||
curr_y += y_delta
|
||||
|
||||
curr_x += x_delta
|
||||
|
||||
curr_nodes = input_nodes
|
||||
|
||||
current_depth = 0
|
||||
while len(curr_nodes) > 0:
|
||||
current_depth += 1
|
||||
next_nodes_set = set()
|
||||
for node in curr_nodes:
|
||||
next_nodes_set.update(graph.successors(node))
|
||||
|
||||
curr_nodes = list(next_nodes_set)
|
||||
for node in curr_nodes:
|
||||
nodes_depth[node] = current_depth
|
||||
|
||||
nodes_by_depth: Dict[int, List[int]] = {}
|
||||
for node, depth in nodes_depth.items():
|
||||
nodes_for_depth = nodes_by_depth.get(depth, [])
|
||||
nodes_for_depth.append(node)
|
||||
nodes_by_depth[depth] = nodes_for_depth
|
||||
|
||||
depths = sorted(nodes_by_depth.keys())
|
||||
|
||||
for depth in depths:
|
||||
nodes_at_depth = nodes_by_depth[depth]
|
||||
|
||||
curr_y = -(len(nodes_at_depth) - 1) / 2 * y_delta
|
||||
for node in nodes_at_depth:
|
||||
pos[node] = (curr_x, curr_y)
|
||||
curr_y += y_delta
|
||||
|
||||
curr_x += x_delta
|
||||
|
||||
# pylint: enable=too-many-locals
|
||||
return pos
|
||||
|
||||
|
||||
def draw_graph(
|
||||
graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True
|
||||
) -> None:
|
||||
"""
|
||||
Draw a graph
|
||||
|
||||
Args:
|
||||
graph (nx.DiGraph): The graph that we want to draw
|
||||
block_until_user_closes_graph (bool): if True, will wait the user to
|
||||
close the figure before continuing; False is useful for the CI tests
|
||||
draw_edge_numbers (bool): if True, add the edge number on the arrow
|
||||
linking nodes, eg to differentiate the x and y in a Sub coding
|
||||
(x - y). This option is not that useful for commutative ops, and
|
||||
may make the picture a bit too dense, so could be deactivated
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
|
||||
# FIXME: less variables
|
||||
# pylint: disable=too-many-locals
|
||||
|
||||
# Positions of the node
|
||||
pos = human_readable_layout(graph)
|
||||
|
||||
# Colors and labels
|
||||
color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()]
|
||||
|
||||
# For most types, we just pick the operation as the label, but for Input,
|
||||
# we take the name of the variable, ie the argument name of the function
|
||||
# to compile
|
||||
label_dict = {
|
||||
node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__
|
||||
for node in graph.nodes()
|
||||
}
|
||||
|
||||
# Draw nodes
|
||||
nx.draw_networkx_nodes(
|
||||
graph,
|
||||
pos,
|
||||
node_color=color_map,
|
||||
node_size=1000,
|
||||
alpha=1,
|
||||
)
|
||||
|
||||
# Draw labels
|
||||
nx.draw_networkx_labels(graph, pos, labels=label_dict)
|
||||
|
||||
current_axes = plt.gca()
|
||||
|
||||
# And draw edges in a way which works when we have two "equivalent edges",
|
||||
# ie from the same node A to the same node B, like to represent y = x + x
|
||||
already_done = set()
|
||||
|
||||
for e in graph.edges:
|
||||
|
||||
# If we already drew the different edges from e[0] to e[1], continue
|
||||
if (e[0], e[1]) in already_done:
|
||||
continue
|
||||
|
||||
already_done.add((e[0], e[1]))
|
||||
|
||||
edges = graph.get_edge_data(e[0], e[1])
|
||||
|
||||
# Draw the different edges from e[0] to e[1], continue
|
||||
for which, edge in enumerate(edges.values()):
|
||||
edge_index = edge["input_idx"]
|
||||
|
||||
# Draw the edge
|
||||
current_axes.annotate(
|
||||
"",
|
||||
xy=pos[e[0]],
|
||||
xycoords="data",
|
||||
xytext=pos[e[1]],
|
||||
textcoords="data",
|
||||
arrowprops=dict(
|
||||
arrowstyle="<-",
|
||||
color="0.5",
|
||||
shrinkA=5,
|
||||
shrinkB=5,
|
||||
patchA=None,
|
||||
patchB=None,
|
||||
connectionstyle="arc3,rad=rrr".replace("rrr", str(0.3 * which)),
|
||||
),
|
||||
)
|
||||
|
||||
if draw_edge_numbers:
|
||||
# Print the number of the node on the edge. This is a bit artisanal,
|
||||
# since it seems not possible to add the text directly on the
|
||||
# previously drawn arrow. So, more or less, we try to put a text at
|
||||
# a position which is close to pos[e[1]] and which varies a bit with
|
||||
# 'which'
|
||||
a, b = pos[e[0]]
|
||||
c, d = pos[e[1]]
|
||||
const_0 = 1
|
||||
const_1 = 2
|
||||
|
||||
current_axes.annotate(
|
||||
str(edge_index),
|
||||
xycoords="data",
|
||||
xy=(
|
||||
(const_0 * a + const_1 * c) / (const_0 + const_1),
|
||||
(const_0 * b + const_1 * d + 0.1 * which) / (const_0 + const_1),
|
||||
),
|
||||
textcoords="data",
|
||||
)
|
||||
|
||||
plt.axis("off")
|
||||
|
||||
# block_until_user_closes_graph is used as True for real users and False
|
||||
# for CI
|
||||
plt.show(block=block_until_user_closes_graph)
|
||||
|
||||
# pylint: enable=too-many-locals
|
||||
7
poetry.lock
generated
7
poetry.lock
generated
@@ -821,7 +821,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = ">=3.7,<3.10"
|
||||
content-hash = "2dbebb6e1d3b5f35cd48891bb9ed49b9d57d1f5659419e27150c5a3786d4b054"
|
||||
content-hash = "e0f2368e94fc45bae93f9695b3702a12f4d2c013caaff22012e3abfc4f7af6ab"
|
||||
|
||||
[metadata.files]
|
||||
alabaster = [
|
||||
@@ -1199,6 +1199,11 @@ pathspec = [
|
||||
{file = "pathspec-0.8.1.tar.gz", hash = "sha256:86379d6b86d75816baba717e64b1a3a3469deb93bb76d613c9ce79edc5cb68fd"},
|
||||
]
|
||||
pillow = [
|
||||
{file = "Pillow-8.3.1-1-cp36-cp36m-win_amd64.whl", hash = "sha256:fd7eef578f5b2200d066db1b50c4aa66410786201669fb76d5238b007918fb24"},
|
||||
{file = "Pillow-8.3.1-1-cp37-cp37m-win_amd64.whl", hash = "sha256:75e09042a3b39e0ea61ce37e941221313d51a9c26b8e54e12b3ececccb71718a"},
|
||||
{file = "Pillow-8.3.1-1-cp38-cp38-win_amd64.whl", hash = "sha256:c0e0550a404c69aab1e04ae89cca3e2a042b56ab043f7f729d984bf73ed2a093"},
|
||||
{file = "Pillow-8.3.1-1-cp39-cp39-win_amd64.whl", hash = "sha256:479ab11cbd69612acefa8286481f65c5dece2002ffaa4f9db62682379ca3bb77"},
|
||||
{file = "Pillow-8.3.1-1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:f156d6ecfc747ee111c167f8faf5f4953761b5e66e91a4e6767e548d0f80129c"},
|
||||
{file = "Pillow-8.3.1-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:196560dba4da7a72c5e7085fccc5938ab4075fd37fe8b5468869724109812edd"},
|
||||
{file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c9569049d04aaacd690573a0398dbd8e0bf0255684fee512b413c2142ab723"},
|
||||
{file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c088a000dfdd88c184cc7271bfac8c5b82d9efa8637cd2b68183771e3cf56f04"},
|
||||
|
||||
@@ -10,6 +10,7 @@ Sphinx = "^4.1.1"
|
||||
sphinx-rtd-theme = "^0.5.2"
|
||||
myst-parser = "^0.15.1"
|
||||
networkx = "^2.6.1"
|
||||
matplotlib = "^3.4.2"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
isort = "^5.9.2"
|
||||
|
||||
58
tests/hnumpy/test_debugging.py
Normal file
58
tests/hnumpy/test_debugging.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Test file for HDK's hnumpy debugging functions"""
|
||||
|
||||
import pytest
|
||||
|
||||
from hdk.common.data_types.integers import Integer
|
||||
from hdk.common.data_types.values import ClearValue, EncryptedValue
|
||||
from hdk.common.debugging import draw_graph
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f",
|
||||
[
|
||||
lambda x, y: x + y,
|
||||
lambda x, y: x + x - y * y * y + x,
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"x",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
# pytest.param(
|
||||
# EncryptedValue(Integer(64, is_signed=True)),
|
||||
# id="Encrypted int",
|
||||
# ),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=False)),
|
||||
# id="Clear uint",
|
||||
# ),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=True)),
|
||||
# id="Clear int",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"y",
|
||||
[
|
||||
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
|
||||
# pytest.param(
|
||||
# EncryptedValue(Integer(64, is_signed=True)),
|
||||
# id="Encrypted int",
|
||||
# ),
|
||||
pytest.param(
|
||||
ClearValue(Integer(64, is_signed=False)),
|
||||
id="Clear uint",
|
||||
),
|
||||
# pytest.param(
|
||||
# ClearValue(Integer(64, is_signed=True)),
|
||||
# id="Clear int",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
def test_hnumpy_draw_graph(lambda_f, x, y):
|
||||
"Test hnumpy draw_graph"
|
||||
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
|
||||
|
||||
draw_graph(graph, block_until_user_closes_graph=False)
|
||||
Reference in New Issue
Block a user