diff --git a/hdk/common/__init__.py b/hdk/common/__init__.py index cd563aaf9..48c4420ae 100644 --- a/hdk/common/__init__.py +++ b/hdk/common/__init__.py @@ -1,2 +1,2 @@ """HDK's module for shared data structures and code""" -from . import data_types, representation +from . import data_types, debugging, representation diff --git a/hdk/common/debugging/__init__.py b/hdk/common/debugging/__init__.py new file mode 100644 index 000000000..b18d821bf --- /dev/null +++ b/hdk/common/debugging/__init__.py @@ -0,0 +1,2 @@ +"""HDK's module for debugging""" +from .draw_graph import draw_graph diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py new file mode 100644 index 000000000..46474d0b7 --- /dev/null +++ b/hdk/common/debugging/draw_graph.py @@ -0,0 +1,196 @@ +"""functions to draw the different graphs we can generate in the package, eg to debug""" +from typing import Dict, List + +import matplotlib.pyplot as plt +import networkx as nx + +from hdk.common.representation import intermediate as ir + +IR_NODE_COLOR_MAPPING = {ir.Input: "blue", ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green"} + + +def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float = 1.0) -> Dict: + """ + Returns a pos to be used later with eg nx.draw_networkx_nodes, so that nodes + are ordered by depth from input along the x axis and have a uniform + distribution along the y axis + + Args: + graph (nx.Graph): The graph that we want to draw + x_delta (float): Parameter used to set the increment in x + y_delta (float): Parameter used to set the increment in y + + Returns: + pos (Dict): the argument to use with eg nx.draw_networkx_nodes + + """ + + # FIXME: less variables + # pylint: disable=too-many-locals + + nodes_depth = {node: 0 for node in graph.nodes()} + input_nodes = [node for node in graph.nodes() if len(list(graph.predecessors(node))) == 0] + + # Init a layout so that unreachable nodes have a pos, avoids potential crashes wiht networkx + # use a cheap layout + pos = nx.random_layout(graph) + + curr_x = 0.0 + curr_y = -(len(input_nodes) - 1) / 2 * y_delta + + for in_node in input_nodes: + pos[in_node] = (curr_x, curr_y) + curr_y += y_delta + + curr_x += x_delta + + curr_nodes = input_nodes + + current_depth = 0 + while len(curr_nodes) > 0: + current_depth += 1 + next_nodes_set = set() + for node in curr_nodes: + next_nodes_set.update(graph.successors(node)) + + curr_nodes = list(next_nodes_set) + for node in curr_nodes: + nodes_depth[node] = current_depth + + nodes_by_depth: Dict[int, List[int]] = {} + for node, depth in nodes_depth.items(): + nodes_for_depth = nodes_by_depth.get(depth, []) + nodes_for_depth.append(node) + nodes_by_depth[depth] = nodes_for_depth + + depths = sorted(nodes_by_depth.keys()) + + for depth in depths: + nodes_at_depth = nodes_by_depth[depth] + + curr_y = -(len(nodes_at_depth) - 1) / 2 * y_delta + for node in nodes_at_depth: + pos[node] = (curr_x, curr_y) + curr_y += y_delta + + curr_x += x_delta + + # pylint: enable=too-many-locals + return pos + + +def draw_graph( + graph: nx.DiGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True +) -> None: + """ + Draw a graph + + Args: + graph (nx.DiGraph): The graph that we want to draw + block_until_user_closes_graph (bool): if True, will wait the user to + close the figure before continuing; False is useful for the CI tests + draw_edge_numbers (bool): if True, add the edge number on the arrow + linking nodes, eg to differentiate the x and y in a Sub coding + (x - y). This option is not that useful for commutative ops, and + may make the picture a bit too dense, so could be deactivated + + Returns: + None + + """ + + # FIXME: less variables + # pylint: disable=too-many-locals + + # Positions of the node + pos = human_readable_layout(graph) + + # Colors and labels + color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()] + + # For most types, we just pick the operation as the label, but for Input, + # we take the name of the variable, ie the argument name of the function + # to compile + label_dict = { + node: node.input_name if isinstance(node, ir.Input) else node.__class__.__name__ + for node in graph.nodes() + } + + # Draw nodes + nx.draw_networkx_nodes( + graph, + pos, + node_color=color_map, + node_size=1000, + alpha=1, + ) + + # Draw labels + nx.draw_networkx_labels(graph, pos, labels=label_dict) + + current_axes = plt.gca() + + # And draw edges in a way which works when we have two "equivalent edges", + # ie from the same node A to the same node B, like to represent y = x + x + already_done = set() + + for e in graph.edges: + + # If we already drew the different edges from e[0] to e[1], continue + if (e[0], e[1]) in already_done: + continue + + already_done.add((e[0], e[1])) + + edges = graph.get_edge_data(e[0], e[1]) + + # Draw the different edges from e[0] to e[1], continue + for which, edge in enumerate(edges.values()): + edge_index = edge["input_idx"] + + # Draw the edge + current_axes.annotate( + "", + xy=pos[e[0]], + xycoords="data", + xytext=pos[e[1]], + textcoords="data", + arrowprops=dict( + arrowstyle="<-", + color="0.5", + shrinkA=5, + shrinkB=5, + patchA=None, + patchB=None, + connectionstyle="arc3,rad=rrr".replace("rrr", str(0.3 * which)), + ), + ) + + if draw_edge_numbers: + # Print the number of the node on the edge. This is a bit artisanal, + # since it seems not possible to add the text directly on the + # previously drawn arrow. So, more or less, we try to put a text at + # a position which is close to pos[e[1]] and which varies a bit with + # 'which' + a, b = pos[e[0]] + c, d = pos[e[1]] + const_0 = 1 + const_1 = 2 + + current_axes.annotate( + str(edge_index), + xycoords="data", + xy=( + (const_0 * a + const_1 * c) / (const_0 + const_1), + (const_0 * b + const_1 * d + 0.1 * which) / (const_0 + const_1), + ), + textcoords="data", + ) + + plt.axis("off") + + # block_until_user_closes_graph is used as True for real users and False + # for CI + plt.show(block=block_until_user_closes_graph) + + # pylint: enable=too-many-locals diff --git a/poetry.lock b/poetry.lock index c0eb2c77b..2337759c8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -821,7 +821,7 @@ testing = ["pytest (>=4.6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytes [metadata] lock-version = "1.1" python-versions = ">=3.7,<3.10" -content-hash = "2dbebb6e1d3b5f35cd48891bb9ed49b9d57d1f5659419e27150c5a3786d4b054" +content-hash = "e0f2368e94fc45bae93f9695b3702a12f4d2c013caaff22012e3abfc4f7af6ab" [metadata.files] alabaster = [ @@ -1199,6 +1199,11 @@ pathspec = [ {file = "pathspec-0.8.1.tar.gz", hash = "sha256:86379d6b86d75816baba717e64b1a3a3469deb93bb76d613c9ce79edc5cb68fd"}, ] pillow = [ + {file = "Pillow-8.3.1-1-cp36-cp36m-win_amd64.whl", hash = "sha256:fd7eef578f5b2200d066db1b50c4aa66410786201669fb76d5238b007918fb24"}, + {file = "Pillow-8.3.1-1-cp37-cp37m-win_amd64.whl", hash = "sha256:75e09042a3b39e0ea61ce37e941221313d51a9c26b8e54e12b3ececccb71718a"}, + {file = "Pillow-8.3.1-1-cp38-cp38-win_amd64.whl", hash = "sha256:c0e0550a404c69aab1e04ae89cca3e2a042b56ab043f7f729d984bf73ed2a093"}, + {file = "Pillow-8.3.1-1-cp39-cp39-win_amd64.whl", hash = "sha256:479ab11cbd69612acefa8286481f65c5dece2002ffaa4f9db62682379ca3bb77"}, + {file = "Pillow-8.3.1-1-pp37-pypy37_pp73-win_amd64.whl", hash = "sha256:f156d6ecfc747ee111c167f8faf5f4953761b5e66e91a4e6767e548d0f80129c"}, {file = "Pillow-8.3.1-cp36-cp36m-macosx_10_10_x86_64.whl", hash = "sha256:196560dba4da7a72c5e7085fccc5938ab4075fd37fe8b5468869724109812edd"}, {file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29c9569049d04aaacd690573a0398dbd8e0bf0255684fee512b413c2142ab723"}, {file = "Pillow-8.3.1-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:c088a000dfdd88c184cc7271bfac8c5b82d9efa8637cd2b68183771e3cf56f04"}, diff --git a/pyproject.toml b/pyproject.toml index 13c80cb78..0b1603c32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,7 @@ Sphinx = "^4.1.1" sphinx-rtd-theme = "^0.5.2" myst-parser = "^0.15.1" networkx = "^2.6.1" +matplotlib = "^3.4.2" [tool.poetry.dev-dependencies] isort = "^5.9.2" diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py new file mode 100644 index 000000000..2ade31562 --- /dev/null +++ b/tests/hnumpy/test_debugging.py @@ -0,0 +1,58 @@ +"""Test file for HDK's hnumpy debugging functions""" + +import pytest + +from hdk.common.data_types.integers import Integer +from hdk.common.data_types.values import ClearValue, EncryptedValue +from hdk.common.debugging import draw_graph +from hdk.hnumpy import tracing + + +@pytest.mark.parametrize( + "lambda_f", + [ + lambda x, y: x + y, + lambda x, y: x + x - y * y * y + x, + ], +) +@pytest.mark.parametrize( + "x", + [ + pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), + # pytest.param( + # EncryptedValue(Integer(64, is_signed=True)), + # id="Encrypted int", + # ), + # pytest.param( + # ClearValue(Integer(64, is_signed=False)), + # id="Clear uint", + # ), + # pytest.param( + # ClearValue(Integer(64, is_signed=True)), + # id="Clear int", + # ), + ], +) +@pytest.mark.parametrize( + "y", + [ + pytest.param(EncryptedValue(Integer(64, is_signed=False)), id="Encrypted uint"), + # pytest.param( + # EncryptedValue(Integer(64, is_signed=True)), + # id="Encrypted int", + # ), + pytest.param( + ClearValue(Integer(64, is_signed=False)), + id="Clear uint", + ), + # pytest.param( + # ClearValue(Integer(64, is_signed=True)), + # id="Clear int", + # ), + ], +) +def test_hnumpy_draw_graph(lambda_f, x, y): + "Test hnumpy draw_graph" + graph = tracing.trace_numpy_function(lambda_f, {"x": x, "y": y}) + + draw_graph(graph, block_until_user_closes_graph=False)