mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
"""functions to print the different graphs we can generate in the package, eg to debug."""
|
|
|
|
from typing import Any, Dict
|
|
|
|
import networkx as nx
|
|
|
|
from ..operator_graph import OPGraph
|
|
from ..representation import intermediate as ir
|
|
|
|
|
|
def output_data_type_to_string(node):
    """Describe the datatypes of every output of a node as a single string.

    Args:
        node: a graph node, exposing its outputs through an `outputs` attribute

    Returns:
        str: the string form of each output of the node, separated by ", "

    """
    # Each output is rendered with str(), in the order the node stores them
    rendered_outputs = map(str, node.outputs)
    return ", ".join(rendered_outputs)
|
|
|
|
|
|
def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
    """Return a string representing a graph.

    Nodes are visited in topological order and each one produces a line of the form
    `%<id> = <description>`, where `<id>` is the position of the node in that order.
    Inputs are described by their name, constants by `Constant(<data>)` and every other
    node by `<name>(<ids of its predecessors, ordered by input index>)`. A final
    `return(...)` line lists the ids of the output nodes of the graph.

    Args:
        opgraph (OPGraph): The graph that we want to draw
        show_data_types (bool): Whether or not to show the data types of nodes, e.g.
            to see their width. When set, each node line is left-justified to 40
            characters and followed by `# <datatypes of the node outputs>`

    Returns:
        str: a string to print or save in a file

    """
    assert isinstance(opgraph, OPGraph)
    graph = opgraph.graph

    # Lines are accumulated in a list and joined once at the end, instead of growing a
    # string with += for every node
    lines = []

    # Id given to each node that has already been printed
    map_table: Dict[Any, int] = {}

    for i, node in enumerate(nx.topological_sort(graph)):

        # This code doesn't work with more than a single output. For more outputs,
        # we would need to change the way the destinations are created: currently,
        # they are only the position of the node in the topological order
        assert len(node.outputs) == 1

        if isinstance(node, ir.Input):
            what_to_print = node.input_name
        elif isinstance(node, ir.Constant):
            what_to_print = f"Constant({node.constant_data})"
        else:
            # Arbitrary functions carry their own name, other nodes use their class name
            base_name = (
                node.op_name
                if isinstance(node, ir.ArbitraryFunction)
                else node.__class__.__name__
            )

            # Find all the names of the current predecessors of the node. Remark that we
            # keep the input index next to the name of the predecessor, to print sources
            # in the right order, which is important for eg non commutative operations.
            # The topological order guarantees predecessors are already in map_table
            list_of_arg_name = [
                (edge_data["input_idx"], str(map_table[pred]))
                for pred, edges in graph.pred[node].items()
                for edge_data in edges.values()
            ]

            # Some checks, because the previous algorithm is not clear: every input
            # index must appear exactly once and the indices must be 0, 1, ..., n - 1
            assert len(list_of_arg_name) == len(set(x[0] for x in list_of_arg_name))
            list_of_arg_name.sort()
            assert [x[0] for x in list_of_arg_name] == list(range(len(list_of_arg_name)))

            # Then, just print the predecessors in the right order
            what_to_print = base_name + "(" + ", ".join(x[1] for x in list_of_arg_name) + ")"

        # This code doesn't work with more than a single output
        new_line = f"%{i} = {what_to_print}"

        # Manage datatypes
        if show_data_types:
            new_line = f"{new_line: <40s} # {output_data_type_to_string(node)}"

        lines.append(new_line)
        map_table[node] = i

    return_part = ", ".join(f"%{map_table[n]}" for n in opgraph.output_nodes.values())
    lines.append(f"return({return_part})")

    # Every line, including the final return one, is terminated by a newline
    return "".join(f"{line}\n" for line in lines)
|