From 1ee8195af0213519b4612420cde7b8445a205ebc Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Wed, 4 Aug 2021 16:56:13 +0200 Subject: [PATCH] feat: supporting several outputs in print and draw of graphs refs #76 --- hdk/common/debugging/draw_graph.py | 31 +++++++++++++------- tests/hnumpy/test_debugging.py | 45 ++++++++++++++++++++---------- 2 files changed, 52 insertions(+), 24 deletions(-) diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index d29c42f07..9d404c94c 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -1,5 +1,5 @@ """functions to draw the different graphs we can generate in the package, eg to debug""" -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import matplotlib.pyplot as plt import networkx as nx @@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = { ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green", + "output": "magenta", } @@ -87,7 +88,7 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float def draw_graph( - graph: Union[OPGraph, nx.MultiDiGraph], + opgraph: OPGraph, block_until_user_closes_graph: bool = True, draw_edge_numbers: bool = True, ) -> None: @@ -95,7 +96,7 @@ def draw_graph( Draw a graph Args: - graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw + graph (OPGraph): The graph that we want to draw block_until_user_closes_graph (bool): if True, will wait the user to close the figure before continuing; False is useful for the CI tests draw_edge_numbers (bool): if True, add the edge number on the arrow @@ -111,14 +112,20 @@ def draw_graph( # FIXME: less variables # pylint: disable=too-many-locals - # Allow to pass either OPGraph or an nx graph, manage this here - graph = graph.graph if isinstance(graph, OPGraph) else graph + assert isinstance(opgraph, OPGraph) + set_of_nodes_which_are_outputs = set(opgraph.output_nodes.values()) + graph = opgraph.graph # Positions of the node pos = human_readable_layout(graph) # Colors and labels - color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()] + def get_color(node): + if node in set_of_nodes_which_are_outputs: + return IR_NODE_COLOR_MAPPING["output"] + return IR_NODE_COLOR_MAPPING[type(node)] + + color_map = [get_color(node) for node in graph.nodes()] # For most types, we just pick the operation as the label, but for Input, # we take the name of the variable, ie the argument name of the function @@ -212,18 +219,19 @@ def draw_graph( # pylint: enable=too-many-locals -def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str: +def get_printable_graph(opgraph: OPGraph) -> str: """Return a string representing a graph Args: - graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw + graph (OPGraph): The graph that we want to draw Returns: str: a string to print or save in a file """ - # Allow to pass either OPGraph or an nx graph, manage this here - graph = graph.graph if isinstance(graph, OPGraph) else graph + assert isinstance(opgraph, OPGraph) + list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values()) + graph = opgraph.graph returned_str = "" @@ -261,4 +269,7 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str: map_table[node] = i i += 1 + return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs]) + returned_str += f"\nreturn({return_part})" + return returned_str diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index e28b956f0..9187a91a1 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -11,23 +11,23 @@ from hdk.hnumpy import tracing @pytest.mark.parametrize( "lambda_f,ref_graph_str", [ - (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"), - (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"), - (lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"), + (lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)"), + (lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)"), + (lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)\nreturn(%1)"), ( lambda x, y: x + x - y * y * y + x, "\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)" - "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)", + "\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)", ), - (lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"), - (lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"), - (lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"), - (lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"), - (lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"), - (lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"), - (lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"), - (lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"), - (lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"), + (lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"), + (lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"), + (lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)\nreturn(%2)"), + (lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"), + (lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"), + (lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)\nreturn(%2)"), + (lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)\nreturn(%2)"), + (lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"), + (lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"), ( lambda x, y: x + 13 - y * (-21) * y + 44, "\n%0 = ConstantInput(44)" @@ -39,7 +39,24 @@ from hdk.hnumpy import tracing "\n%6 = Mul(3, 4)" "\n%7 = Mul(6, 3)" "\n%8 = Sub(5, 7)" - "\n%9 = Add(8, 0)", + "\n%9 = Add(8, 0)" + "\nreturn(%9)", + ), + # Multiple outputs + ( + lambda x, y: (x + 1, x + y + 2), + "\n%0 = x" + "\n%1 = ConstantInput(1)" + "\n%2 = ConstantInput(2)" + "\n%3 = y" + "\n%4 = Add(0, 1)" + "\n%5 = Add(0, 3)" + "\n%6 = Add(5, 2)" + "\nreturn(%4, %6)", + ), + ( + lambda x, y: (y, x), + "\n%0 = y\n%1 = x\nreturn(%0, %1)", ), ], )