feat: supporting several outputs in print and draw of graphs

refs #76
This commit is contained in:
Benoit Chevallier-Mames
2021-08-04 16:56:13 +02:00
committed by Benoit Chevallier
parent 078c8dc8f1
commit 1ee8195af0
2 changed files with 52 additions and 24 deletions

View File

@@ -1,5 +1,5 @@
"""functions to draw the different graphs we can generate in the package, eg to debug"""
from typing import Any, Dict, List, Union
from typing import Any, Dict, List
import matplotlib.pyplot as plt
import networkx as nx
@@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = {
ir.Add: "red",
ir.Sub: "yellow",
ir.Mul: "green",
"output": "magenta",
}
@@ -87,7 +88,7 @@ def human_readable_layout(graph: nx.Graph, x_delta: float = 1.0, y_delta: float
def draw_graph(
graph: Union[OPGraph, nx.MultiDiGraph],
opgraph: OPGraph,
block_until_user_closes_graph: bool = True,
draw_edge_numbers: bool = True,
) -> None:
@@ -95,7 +96,7 @@ def draw_graph(
Draw a graph
Args:
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
graph (OPGraph): The graph that we want to draw
block_until_user_closes_graph (bool): if True, will wait the user to
close the figure before continuing; False is useful for the CI tests
draw_edge_numbers (bool): if True, add the edge number on the arrow
@@ -111,14 +112,20 @@ def draw_graph(
# FIXME: less variables
# pylint: disable=too-many-locals
# Allow to pass either OPGraph or an nx graph, manage this here
graph = graph.graph if isinstance(graph, OPGraph) else graph
assert isinstance(opgraph, OPGraph)
set_of_nodes_which_are_outputs = set(opgraph.output_nodes.values())
graph = opgraph.graph
# Positions of the node
pos = human_readable_layout(graph)
# Colors and labels
color_map = [IR_NODE_COLOR_MAPPING[type(node)] for node in graph.nodes()]
def get_color(node):
if node in set_of_nodes_which_are_outputs:
return IR_NODE_COLOR_MAPPING["output"]
return IR_NODE_COLOR_MAPPING[type(node)]
color_map = [get_color(node) for node in graph.nodes()]
# For most types, we just pick the operation as the label, but for Input,
# we take the name of the variable, ie the argument name of the function
@@ -212,18 +219,19 @@ def draw_graph(
# pylint: enable=too-many-locals
def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
def get_printable_graph(opgraph: OPGraph) -> str:
"""Return a string representing a graph
Args:
graph (Union[OPGraph, nx.MultiDiGraph]): The graph that we want to draw
graph (OPGraph): The graph that we want to draw
Returns:
str: a string to print or save in a file
"""
# Allow to pass either OPGraph or an nx graph, manage this here
graph = graph.graph if isinstance(graph, OPGraph) else graph
assert isinstance(opgraph, OPGraph)
list_of_nodes_which_are_outputs = list(opgraph.output_nodes.values())
graph = opgraph.graph
returned_str = ""
@@ -261,4 +269,7 @@ def get_printable_graph(graph: Union[OPGraph, nx.MultiDiGraph]) -> str:
map_table[node] = i
i += 1
return_part = ", ".join(["%" + str(map_table[n]) for n in list_of_nodes_which_are_outputs])
returned_str += f"\nreturn({return_part})"
return returned_str

View File

@@ -11,23 +11,23 @@ from hdk.hnumpy import tracing
@pytest.mark.parametrize(
"lambda_f,ref_graph_str",
[
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)"),
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)"),
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)"),
(lambda x, y: x + y, "\n%0 = x\n%1 = y\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: x - y, "\n%0 = x\n%1 = y\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: x + x, "\n%0 = x\n%1 = Add(0, 0)\nreturn(%1)"),
(
lambda x, y: x + x - y * y * y + x,
"\n%0 = x\n%1 = y\n%2 = Add(0, 0)\n%3 = Mul(1, 1)"
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)",
"\n%4 = Mul(3, 1)\n%5 = Sub(2, 4)\n%6 = Add(5, 0)\nreturn(%6)",
),
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)"),
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)"),
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)"),
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)"),
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)"),
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)"),
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)"),
(lambda x, y: x + 1, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: 1 + x, "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: (-1) + x, "\n%0 = x\n%1 = ConstantInput(-1)\n%2 = Add(0, 1)\nreturn(%2)"),
(lambda x, y: 3 * x, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x * 3, "\n%0 = x\n%1 = ConstantInput(3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x * (-3), "\n%0 = x\n%1 = ConstantInput(-3)\n%2 = Mul(0, 1)\nreturn(%2)"),
(lambda x, y: x - 11, "\n%0 = x\n%1 = ConstantInput(11)\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: 11 - x, "\n%0 = ConstantInput(11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
(lambda x, y: (-11) - x, "\n%0 = ConstantInput(-11)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2)"),
(
lambda x, y: x + 13 - y * (-21) * y + 44,
"\n%0 = ConstantInput(44)"
@@ -39,7 +39,24 @@ from hdk.hnumpy import tracing
"\n%6 = Mul(3, 4)"
"\n%7 = Mul(6, 3)"
"\n%8 = Sub(5, 7)"
"\n%9 = Add(8, 0)",
"\n%9 = Add(8, 0)"
"\nreturn(%9)",
),
# Multiple outputs
(
lambda x, y: (x + 1, x + y + 2),
"\n%0 = x"
"\n%1 = ConstantInput(1)"
"\n%2 = ConstantInput(2)"
"\n%3 = y"
"\n%4 = Add(0, 1)"
"\n%5 = Add(0, 3)"
"\n%6 = Add(5, 2)"
"\nreturn(%4, %6)",
),
(
lambda x, y: (y, x),
"\n%0 = y\n%1 = x\nreturn(%0, %1)",
),
],
)