mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor(optimization): refactor code to prepare new topologies fusing
refs #499
This commit is contained in:
@@ -211,6 +211,57 @@ def convert_float_subgraph_to_fused_node(
|
||||
)
|
||||
|
||||
|
||||
def is_single_int_output_node(node: IntermediateNode) -> bool:
|
||||
"""Check if a node has a single output and that output is an integer.
|
||||
|
||||
Args:
|
||||
node (IntermediateNode): the node to check.
|
||||
|
||||
Returns:
|
||||
bool: returns True if the node has a single integer output, False otherwise.
|
||||
"""
|
||||
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
|
||||
|
||||
|
||||
def find_closest_single_int_output_nodes(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
start_nodes: List[IntermediateNode],
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]:
|
||||
"""Find in nx_graph the closest upstream single integer output nodes to some start nodes.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): the networkx graph to search in.
|
||||
start_nodes (List[IntermediateNode]): the nodes from which to start the search.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes
|
||||
visited during the search.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an
|
||||
ordered set containing the found single output nodes and the updated set of the visited
|
||||
nodes during the search.
|
||||
"""
|
||||
|
||||
# Use dict as ordered set
|
||||
current_nodes = {start_node: None for start_node in start_nodes}
|
||||
closest_single_int_output_nodes: Dict[IntermediateNode, None] = {}
|
||||
while current_nodes:
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
for node in current_nodes:
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
for pred in predecessors:
|
||||
if is_single_int_output_node(pred):
|
||||
# Limit of subgraph, record that and record the node as we won't visit it
|
||||
closest_single_int_output_nodes[pred] = None
|
||||
subgraph_all_nodes.add(pred)
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return closest_single_int_output_nodes, subgraph_all_nodes
|
||||
|
||||
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
processed_terminal_nodes: Set[IntermediateNode],
|
||||
@@ -239,9 +290,6 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
and isinstance(node.outputs[0].dtype, Integer)
|
||||
)
|
||||
|
||||
def single_int_output_node(node: IntermediateNode) -> bool:
|
||||
return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer)
|
||||
|
||||
float_subgraphs_terminal_nodes = (
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
@@ -255,25 +303,15 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
# Use dict as ordered set
|
||||
current_nodes = {terminal_node: None}
|
||||
float_subgraph_start_nodes: Set[IntermediateNode] = set()
|
||||
subgraph_all_nodes: Set[IntermediateNode] = set()
|
||||
while current_nodes:
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
for node in current_nodes:
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
for pred in predecessors:
|
||||
if single_int_output_node(pred):
|
||||
# Limit of subgraph, record that and record the node as we won't visit it
|
||||
float_subgraph_start_nodes.add(pred)
|
||||
subgraph_all_nodes.add(pred)
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
|
||||
float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
|
||||
nx_graph,
|
||||
[terminal_node],
|
||||
subgraph_all_nodes,
|
||||
)
|
||||
|
||||
return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes
|
||||
|
||||
|
||||
def subgraph_nodes_and_values_allow_fusing(
|
||||
|
||||
Reference in New Issue
Block a user