From b9c52947f6eac6342b2372b21e35d6306e2999ff Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 17 Nov 2021 11:53:35 +0100 Subject: [PATCH] refactor(optimization): refactor code to prepare new topologies fusing refs #499 --- concrete/common/optimization/topological.py | 78 +++++++++++++++------ 1 file changed, 58 insertions(+), 20 deletions(-) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 272039604..805e24b2a 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -211,6 +211,57 @@ def convert_float_subgraph_to_fused_node( ) +def is_single_int_output_node(node: IntermediateNode) -> bool: + """Check if a node has a single output and that output is an integer. + + Args: + node (IntermediateNode): the node to check. + + Returns: + bool: returns True if the node has a single integer output, False otherwise. + """ + return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer) + + +def find_closest_single_int_output_nodes( + nx_graph: nx.MultiDiGraph, + start_nodes: List[IntermediateNode], + subgraph_all_nodes: Set[IntermediateNode], +) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: + """Find in nx_graph the closest upstream single integer output nodes to some start nodes. + + Args: + nx_graph (nx.MultiDiGraph): the networkx graph to search in. + start_nodes (List[IntermediateNode]): the nodes from which to start the search. + subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes + visited during the search. + + Returns: + Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an + ordered set containing the found single output nodes and the updated set of the visited + nodes during the search. + """ + + # Use dict as ordered set + current_nodes = {start_node: None for start_node in start_nodes} + closest_single_int_output_nodes: Dict[IntermediateNode, None] = {} + while current_nodes: + next_nodes: Dict[IntermediateNode, None] = {} + for node in current_nodes: + subgraph_all_nodes.add(node) + predecessors = nx_graph.pred[node] + for pred in predecessors: + if is_single_int_output_node(pred): + # Limit of subgraph, record that and record the node as we won't visit it + closest_single_int_output_nodes[pred] = None + subgraph_all_nodes.add(pred) + else: + next_nodes.update({pred: None}) + current_nodes = next_nodes + + return closest_single_int_output_nodes, subgraph_all_nodes + + def find_float_subgraph_with_unique_terminal_node( nx_graph: nx.MultiDiGraph, processed_terminal_nodes: Set[IntermediateNode], @@ -239,9 +290,6 @@ def find_float_subgraph_with_unique_terminal_node( and isinstance(node.outputs[0].dtype, Integer) ) - def single_int_output_node(node: IntermediateNode) -> bool: - return len(node.outputs) == 1 and isinstance(node.outputs[0].dtype, Integer) - float_subgraphs_terminal_nodes = ( node for node in nx_graph.nodes() @@ -255,25 +303,15 @@ def find_float_subgraph_with_unique_terminal_node( except StopIteration: return None - # Use dict as ordered set - current_nodes = {terminal_node: None} - float_subgraph_start_nodes: Set[IntermediateNode] = set() subgraph_all_nodes: Set[IntermediateNode] = set() - while current_nodes: - next_nodes: Dict[IntermediateNode, None] = {} - for node in current_nodes: - subgraph_all_nodes.add(node) - predecessors = nx_graph.pred[node] - for pred in predecessors: - if single_int_output_node(pred): - # Limit of subgraph, record that and record the node as we won't visit it - float_subgraph_start_nodes.add(pred) - subgraph_all_nodes.add(pred) - else: - next_nodes.update({pred: None}) - current_nodes = next_nodes - return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes + float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes( + nx_graph, + [terminal_node], + subgraph_all_nodes, + ) + + return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes def subgraph_nodes_and_values_allow_fusing(