mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(debug): manage all get_color cases and unsorted arg names
This commit is contained in:
@@ -13,6 +13,7 @@ IR_NODE_COLOR_MAPPING = {
|
||||
ir.Add: "red",
|
||||
ir.Sub: "yellow",
|
||||
ir.Mul: "green",
|
||||
ir.ArbitraryFunction: "orange",
|
||||
"ArbitraryFunction": "orange",
|
||||
"TLU": "grey",
|
||||
"output": "magenta",
|
||||
@@ -114,11 +115,12 @@ def draw_graph(
|
||||
|
||||
# Colors and labels
|
||||
def get_color(node):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING[type(node)]
|
||||
if node in set_of_nodes_which_are_outputs:
|
||||
return IR_NODE_COLOR_MAPPING["output"]
|
||||
if isinstance(node, ir.ArbitraryFunction):
|
||||
return IR_NODE_COLOR_MAPPING[node.op_name]
|
||||
return IR_NODE_COLOR_MAPPING[type(node)]
|
||||
value_to_return = IR_NODE_COLOR_MAPPING["output"]
|
||||
elif isinstance(node, ir.ArbitraryFunction):
|
||||
value_to_return = IR_NODE_COLOR_MAPPING.get(node.op_name, value_to_return)
|
||||
return value_to_return
|
||||
|
||||
color_map = [get_color(node) for node in graph.nodes()]
|
||||
|
||||
@@ -273,11 +275,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
list_of_arg_name += [(index["input_idx"], str(map_table[pred]))]
|
||||
|
||||
# Some checks, because the previous algorithm is not clear
|
||||
assert len(list_of_arg_name) == len({x[0] for x in list_of_arg_name})
|
||||
assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))
|
||||
list_of_arg_name.sort()
|
||||
assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))
|
||||
|
||||
# Then, just print the predecessors in the right order
|
||||
list_of_arg_name.sort()
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
new_line = f"%{i} = {what_to_print}"
|
||||
|
||||
Reference in New Issue
Block a user