feat(debugging): implementing draw_graph

draw_graph is the function to show on a plot the traced function-to-compile
add edge numbers
refs #38
This commit is contained in:
Benoit Chevallier-Mames
2021-07-28 17:38:27 +02:00
committed by Arthur Meyre
parent 8925fbd2db
commit 1196b00c6b
6 changed files with 264 additions and 2 deletions

View File

@@ -1,2 +1,2 @@
"""HDK's module for shared data structures and code"""
from . import data_types, representation
from . import data_types, debugging, representation

View File

@@ -0,0 +1,2 @@
"""HDK's module for debugging"""
from .draw_graph import draw_graph

View File

@@ -0,0 +1,196 @@
"""functions to draw the different graphs we can generate in the package, eg to debug"""
from typing import Dict, List
import matplotlib.pyplot as plt
import networkx as nx
from hdk.common.representation import intermediate as ir
IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"}
def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict:
"""
Returns a pos to be used later with eg nx.draw_networkx_nodes, so that nodes
are ordered by depth from input along the x axis and have a uniform
distribution along the y axis
Args:
graph (nx.Graph): The graph that we want to draw
x_delta (float): Parameter used to set the increment in x
y_delta (float): Parameter used to set the increment in y
Returns:
pos (Dict): the argument to use with eg nx.draw_networkx_nodes
"""
# FIXME: less variables
# pylint: disable=too-many-locals
nodes_depth = {node: 0 for node in graph.nodes()}
input_nodes = [node for node in graph.nodes() if len(list(graph.predecessors(node))) == 0]
# Init a layout so that unreachable nodes have a pos, avoids potential crashes wiht networkx
# use a cheap layout
pos = nx.random_layout(graph)
curr_x = 0.0
curr_y = -(len(input_nodes) - 1) / 2 * y_delta
for in_node in input_nodes:
pos[in_node] = (curr_x, curr_y)
curr_y += y_delta
curr_x += x_delta
curr_nodes = input_nodes
current_depth = 0
while len(curr_nodes) > 0:
current_depth += 1
next_nodes_set = set()
for node in curr_nodes:
next_nodes_set.update(graph.successors(node))
curr_nodes = list(next_nodes_set)
for node in curr_nodes:
nodes_depth[node] = current_depth
nodes_by_depth: Dict[int, List[int]] = {}
for node, depth in nodes_depth.items():
nodes_for_depth = nodes_by_depth.get(depth, [])
nodes_for_depth.append(node)
nodes_by_depth[depth] = nodes_for_depth
depths = sorted(nodes_by_depth.keys())
for depth in depths:
nodes_at_depth = nodes_by_depth[depth]
curr_y = -(len(nodes_at_depth) - 1) / 2 * y_delta
for node in nodes_at_depth:
pos[node] = (curr_x, curr_y)
curr_y += y_delta
curr_x += x_delta
# pylint: enable=too-many-locals
return pos
def draw_graph(
graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True
) -> None:
"""
Draw a graph
Args:
graph (nx.DiGraph): The graph that we want to draw
block_until_user_closes_graph (bool): if True, will wait the user to
close the figure before continuing; False is useful for the CI tests
draw_edge_numbers (bool): if True, add the edge number on the arrow
linking nodes, eg to differentiate the x and y in a Sub coding
(x - y). This option is not that useful for commutative ops, and
may make the picture a bit too dense, so could be deactivated
Returns:
None
"""
# FIXME: less variables
# pylint: disable=too-many-locals
# Positions of the node
pos = human_readable_layout(graph)
# Colors and labels
color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()]
# For most types, we just pick the operation as the label, but for Input,
# we take the name of the variable, ie the argument name of the function
# to compile
label_dict = {
node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__
for node in graph.nodes()
}
# Draw nodes
nx.draw_networkx_nodes(
graph,
pos,
node_color=color_map,
node_size=1000,
alpha=1,
)
# Draw labels
nx.draw_networkx_labels(graph, pos, labels=label_dict)
current_axes = plt.gca()
# And draw edges in a way which works when we have two "equivalent edges",
# ie from the same node A to the same node B, like to represent y = x + x
already_done = set()
for e in graph.edges:
# If we already drew the different edges from e[0] to e[1], continue
if (e[0], e[1]) in already_done:
continue
already_done.add((e[0], e[1]))
edges = graph.get_edge_data(e[0], e[1])
# Draw the different edges from e[0] to e[1], continue
for which, edge in enumerate(edges.values()):
edge_index = edge["input_idx"]
# Draw the edge
current_axes.annotate(
"",
xy=pos[e[0]],
xycoords="data",
xytext=pos[e[1]],
textcoords="data",
arrowprops=dict(
arrowstyle="<-",
color="0.5",
shrinkA=5,
shrinkB=5,
patchA=None,
patchB=None,
connectionstyle="arc3,rad=rrr".replace("rrr", str(0.3 * which)),
),
)
if draw_edge_numbers:
# Print the number of the node on the edge. This is a bit artisanal,
# since it seems not possible to add the text directly on the
# previously drawn arrow. So, more or less, we try to put a text at
# a position which is close to pos[e[1]] and which varies a bit with
# 'which'
a, b = pos[e[0]]
c, d = pos[e[1]]
const_0 = 1
const_1 = 2
current_axes.annotate(
str(edge_index),
xycoords="data",
xy=(
(const_0 * a + const_1 * c) / (const_0 + const_1),
(const_0 * b + const_1 * d + 0.1 * which) / (const_0 + const_1),
),
textcoords="data",
)
plt.axis("off")
# block_until_user_closes_graph is used as True for real users and False
# for CI
plt.show(block=block_until_user_closes_graph)
# pylint: enable=too-many-locals

7
poetry.lock generated
View File

@@ -821,7 +821,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes
[metadata]
lock-version = "1.1"
python-versions = ">=3.7,<3.10"
content-hash = "2dbebb6e1d3b5f35cd48891bb9ed49b9d57d1f5659419e27150c5a3786d4b054"
content-hash = "e0f2368e94fc45bae93f9695b3702a12f4d2c013caaff22012e3abfc4f7af6ab"
[metadata.files]
alabaster = [
@@ -1199,6 +1199,11 @@ pathspec = [
{file = "pathspec-0.8.1.tar.gz", hash = "sha256:86379d6b86d75816baba717e64b1a3a3469deb93bb76d613c9ce79edc5cb68fd"},
]
pillow = [
{file = "Pillow-8.3.1-1-cp36-cp36m-win_amd64.whl", hash = "sha256:fd7eef578f5b2200d066db1b50c4aa66410786201669fb76d5238b007918fb24"},
{file = "Pillow-8.3.1-1-cp37-cp37m-win_amd64.whl", hash = "sha256:75e09042a3b39e0ea61ce37e941221313d51a9c26b8e54e12b3ececccb71718a"},
{file = "Pillow-8.3.1-1-cp38-cp38-win_amd64.whl", hash = "sha256:c0e0550a404c69aab1e04ae89cca3e2a042b56ab043f7f729d984bf73ed2a093"},
{file = "Pillow-8.3.1-1-cp39-cp39-win_amd64.whl", hash = "sha256:479ab11cbd69612acefa8286481f65c5dece2002ffaa4f9db62682379ca3bb77"},
{file = "Pillow-8.3.1-1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:f156d6ecfc747ee111c167f8faf5f4953761b5e66e91a4e6767e548d0f80129c"},
{file = "Pillow-8.3.1-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:196560dba4da7a72c5e7085fccc5938ab4075fd37fe8b5468869724109812edd"},
{file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c9569049d04aaacd690573a0398dbd8e0bf0255684fee512b413c2142ab723"},
{file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c088a000dfdd88c184cc7271bfac8c5b82d9efa8637cd2b68183771e3cf56f04"},

View File

@@ -10,6 +10,7 @@ Sphinx = "^4.1.1"
sphinx-rtd-theme = "^0.5.2"
myst-parser = "^0.15.1"
networkx = "^2.6.1"
matplotlib = "^3.4.2"
[tool.poetry.dev-dependencies]
isort = "^5.9.2"

View File

@@ -0,0 +1,58 @@
"""Test file for HDK's hnumpy debugging functions"""
import pytest
from hdk.common.data_types.integers import Integer
from hdk.common.data_types.values import ClearValue, EncryptedValue
from hdk.common.debugging import draw_graph
from hdk.hnumpy import tracing
@pytest.mark.parametrize(
"lambda_f",
[
lambda x, y: x + y,
lambda x, y: x + x - y * y * y + x,
],
)
@pytest.mark.parametrize(
"x",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
# pytest.param(
# EncryptedValue(Integer(64, is_signed=True)),
# id="Encrypted int",
# ),
# pytest.param(
# ClearValue(Integer(64, is_signed=False)),
# id="Clear uint",
# ),
# pytest.param(
# ClearValue(Integer(64, is_signed=True)),
# id="Clear int",
# ),
],
)
@pytest.mark.parametrize(
"y",
[
pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"),
# pytest.param(
# EncryptedValue(Integer(64, is_signed=True)),
# id="Encrypted int",
# ),
pytest.param(
ClearValue(Integer(64, is_signed=False)),
id="Clear uint",
),
# pytest.param(
# ClearValue(Integer(64, is_signed=True)),
# id="Clear int",
# ),
],
)
def test_hnumpy_draw_graph(lambda_f, x, y):
"Test hnumpy draw_graph"
graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y})
draw_graph(graph, block_until_user_closes_graph=False)