fix(debug): manage all get_color cases and unsorted arg names

This commit is contained in:
Arthur Meyre
2021-08-16 12:14:16 +02:00
parent 63eac35a43
commit 4976855c1d

View File

@@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = {
ir.Add: "red",
ir.Sub: "yellow",
ir.Mul: "green",
ir.ArbitraryFunction: "orange",
"ArbitraryFunction": "orange",
"TLU": "grey",
"output": "magenta",
@@ -114,11 +115,12 @@ def draw_graph(
# Colors and labels
def get_color(node):
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
if node in set_of_nodes_which_are_outputs:
return IR_NODE_COLOR_MAPPING["output"]
if isinstance(node, ir.ArbitraryFunction):
return IR_NODE_COLOR_MAPPING[node.op_name]
return IR_NODE_COLOR_MAPPING[type(node)]
value_to_return = IR_NODE_COLOR_MAPPING["output"]
elif isinstance(node, ir.ArbitraryFunction):
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
return value_to_return
color_map = [get_color(node) for node in graph.nodes()]
@@ -273,11 +275,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
# Some checks, because the previous algorithm is not clear
assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name})
assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))
list_of_arg_name.sort()
assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))
# Then, just print the predecessors in the right order
list_of_arg_name.sort()
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
new_line = f"%{i} = {what_to_print}"