From 4976855c1d344dedbf2bcec5100c0fd906e29c6e Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 16 Aug 2021 12:14:16 +0200 Subject: [PATCH] fix(debug): manage all get_color cases and unsorted arg names --- hdk/common/debugging/draw_graph.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/hdk/common/debugging/draw_graph.py b/hdk/common/debugging/draw_graph.py index b46daf631..fc63401b4 100644 --- a/hdk/common/debugging/draw_graph.py +++ b/hdk/common/debugging/draw_graph.py @@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = { ir.Add: "red", ir.Sub: "yellow", ir.Mul: "green", + ir.ArbitraryFunction: "orange", "ArbitraryFunction": "orange", "TLU": "grey", "output": "magenta", @@ -114,11 +115,12 @@ def draw_graph( # Colors and labels def get_color(node): + value_to_return = IR_NODE_COLOR_MAPPING[type(node)] if node in set_of_nodes_which_are_outputs: - return IR_NODE_COLOR_MAPPING["output"] - if isinstance(node, ir.ArbitraryFunction): - return IR_NODE_COLOR_MAPPING[node.op_name] - return IR_NODE_COLOR_MAPPING[type(node)] + value_to_return = IR_NODE_COLOR_MAPPING["output"] + elif isinstance(node, ir.ArbitraryFunction): + value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return) + return value_to_return color_map = [get_color(node) for node in graph.nodes()] @@ -273,11 +275,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: list_of_arg_name += [(index["input_idx"], str(map_table[pred]))] # Some checks, because the previous algorithm is not clear - assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name}) + assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name)) + list_of_arg_name.sort() assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name))) # Then, just print the predecessors in the right order - list_of_arg_name.sort() what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")" new_line = f"%{i} = {what_to_print}"