mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat: make get_printable_graph give correct info for np.dot
closes #204
This commit is contained in:
committed by
Benoit Chevallier
parent
6b6aa7ee4e
commit
150d33ba48
@@ -18,7 +18,7 @@ def output_data_type_to_string(node):
|
||||
str: a string representing the datatypes of the outputs of the node
|
||||
|
||||
"""
|
||||
return ", ".join([str(o.data_type) for o in node.outputs])
|
||||
return ", ".join([str(o) for o in node.outputs])
|
||||
|
||||
|
||||
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
@@ -43,6 +43,11 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
|
||||
for node in nx.topological_sort(graph):
|
||||
|
||||
# This code doesn't work with more than a single output. For more outputs,
|
||||
# we would need to change the way the destination are created: currently,
|
||||
# they only are done by incrementing i
|
||||
assert len(node.outputs) == 1
|
||||
|
||||
if isinstance(node, ir.Input):
|
||||
what_to_print = node.input_name
|
||||
elif isinstance(node, ir.Constant):
|
||||
@@ -74,6 +79,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
|
||||
# Then, just print the predecessors in the right order
|
||||
what_to_print += ", ".join([x[1] for x in list_of_arg_name]) + ")"
|
||||
|
||||
# This code doesn't work with more than a single output
|
||||
new_line = f"%{i} = {what_to_print}"
|
||||
|
||||
# Manage datatypes
|
||||
|
||||
Reference in New Issue
Block a user