diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index dd37fa9de..80fcae0f8 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -1,7 +1,7 @@ """Code to wrap and make manipulating networkx graphs easier.""" from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union import networkx as nx @@ -126,7 +126,7 @@ class OPGraph: """ # Replication of pred is managed e.g. x + x will yield the proper pred x twice idx_to_pred: Dict[int, IntermediateNode] = {} - for pred in self.graph.pred[node]: + for pred in self.graph.predecessors(node): edge_data = self.graph.get_edge_data(pred, node) idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values()) return [idx_to_pred[i] for i in range(len(idx_to_pred))] @@ -144,7 +144,7 @@ class OPGraph: """ idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {} - for pred in self.graph.pred[node]: + for pred in self.graph.predecessors(node): edge_data = self.graph.get_edge_data(pred, node) idx_to_inp.update( (data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values() @@ -190,7 +190,7 @@ class OPGraph: for node in nx.topological_sort(self.graph): if not isinstance(node, Input): curr_inputs = {} - for pred_node in self.graph.pred[node]: + for pred_node in self.graph.predecessors(node): edges = self.graph.get_edge_data(pred_node, node) curr_inputs.update( { @@ -296,7 +296,7 @@ class OPGraph: node.outputs[0] = deepcopy(node.inputs[0]) - successors = self.graph.succ[node] + successors = self.graph.successors(node) for succ in successors: edge_data = self.graph.get_edge_data(node, succ) for edge in edge_data.values(): @@ -306,14 +306,14 @@ class OPGraph: def prune_nodes(self): """Remove unreachable nodes from outputs.""" - current_nodes = set(self.output_nodes.values()) - useful_nodes: Set[IntermediateNode] = set() + current_nodes = {node: None for node in self.get_ordered_outputs()} + useful_nodes: Dict[IntermediateNode, None] = {} while current_nodes: - next_nodes: Set[IntermediateNode] = set() + next_nodes: Dict[IntermediateNode, None] = {} useful_nodes.update(current_nodes) for node in current_nodes: - next_nodes.update(self.graph.pred[node]) + next_nodes.update({node: None for node in self.graph.predecessors(node)}) current_nodes = next_nodes - useless_nodes = set(self.graph.nodes()) - useful_nodes + useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes] self.graph.remove_nodes_from(useless_nodes) diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 491855c68..9675e5623 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -95,18 +95,18 @@ def fuse_float_operations( def convert_float_subgraph_to_fused_node( op_graph: OPGraph, - float_subgraph_start_nodes: Set[IntermediateNode], + float_subgraph_start_nodes: Dict[IntermediateNode, None], terminal_node: IntermediateNode, - subgraph_all_nodes: Set[IntermediateNode], + subgraph_all_nodes: Dict[IntermediateNode, None], ) -> Optional[Tuple[GenericFunction, IntermediateNode]]: """Convert a float subgraph to an equivalent fused GenericFunction node. Args: op_graph (OPGraph): The OPGraph the float subgraph is part of. - float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph - in `op_graph`. + float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float + subgraph in `op_graph`. terminal_node (IntermediateNode): The node ending the float subgraph. - subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. + subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph. Returns: Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph @@ -151,11 +151,21 @@ def convert_float_subgraph_to_fused_node( nx_graph = op_graph.graph - nodes_after_input_set = subgraph_all_nodes.intersection( - nx_graph.succ[current_subgraph_variable_input] - ) + nodes_after_input_set = { + node: None + for node in subgraph_all_nodes + if node in nx_graph.succ[current_subgraph_variable_input] + } - float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes)) + # # Previous non-deterministic implementation : + # # For some reason creating a graph from a subgraph this way is not deterministic + # float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes)) + + # Create a copy of the graph, remove nodes that are not in all the subgraph nodes in order to + # get a subgraph deterministically + float_subgraph = nx.MultiDiGraph(nx_graph) + nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes] + float_subgraph.remove_nodes_from(nodes_to_remove) new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0) float_subgraph.add_node(new_subgraph_variable_input) @@ -226,20 +236,20 @@ def is_single_int_output_node(node: IntermediateNode) -> bool: def find_closest_single_int_output_nodes( nx_graph: nx.MultiDiGraph, start_nodes: List[IntermediateNode], - subgraph_all_nodes: Set[IntermediateNode], -) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: + subgraph_all_nodes: Dict[IntermediateNode, None], +) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: """Find in nx_graph the closest upstream single integer output nodes to some start nodes. Args: nx_graph (nx.MultiDiGraph): the networkx graph to search in. start_nodes (List[IntermediateNode]): the nodes from which to start the search. - subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes - visited during the search. + subgraph_all_nodes (Dict[IntermediateNode, None]): a set that will be updated with all the + nodes visited during the search. Returns: - Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an - ordered set containing the found single output nodes and the updated set of the visited - nodes during the search. + Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: returns the dict used as + an ordered set containing the found single output nodes and the updated set of the + visited nodes during the search. """ # Use dict as ordered set @@ -252,13 +262,13 @@ def find_closest_single_int_output_nodes( if node in visited_nodes: continue visited_nodes.add(node) - subgraph_all_nodes.add(node) - predecessors = nx_graph.pred[node] + subgraph_all_nodes.update({node: None}) + predecessors = nx_graph.predecessors(node) for pred in predecessors: if is_single_int_output_node(pred): # Limit of subgraph, record that and record the node as we won't visit it - closest_single_int_output_nodes[pred] = None - subgraph_all_nodes.add(pred) + closest_single_int_output_nodes.update({pred: None}) + subgraph_all_nodes.update({pred: None}) else: next_nodes.update({pred: None}) current_nodes = next_nodes @@ -269,21 +279,21 @@ def find_closest_single_int_output_nodes( def add_nodes_from_to( nx_graph: nx.MultiDiGraph, from_nodes: Iterable[IntermediateNode], - to_nodes: Set[IntermediateNode], - subgraph_all_nodes: Set[IntermediateNode], -) -> Set[IntermediateNode]: + to_nodes: Dict[IntermediateNode, None], + subgraph_all_nodes: Dict[IntermediateNode, None], +) -> Dict[IntermediateNode, None]: """Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set. Args: nx_graph (nx.MultiDiGraph): the graph to traverse. from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to subgraph_all_nodes. - to_nodes (Set[IntermediateNode]): the nodes we should stop at. - subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph, will be - updated and returned. + to_nodes (Dict[IntermediateNode, None]): the nodes we should stop at. + subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph, will + be updated and returned. Returns: - Set[IntermediateNode]: returns the updated subgraph_all_nodes. + Dict[IntermediateNode, None]: returns the updated subgraph_all_nodes. """ # Add the end nodes we won't visit @@ -297,10 +307,10 @@ def add_nodes_from_to( if node in visited_nodes: continue visited_nodes.add(node) - subgraph_all_nodes.add(node) - predecessors = nx_graph.pred[node] + subgraph_all_nodes.update({node: None}) + predecessors = nx_graph.predecessors(node) # Add nodes to explore next if they are not indicated as end nodes - next_nodes.update({pred: node for pred in predecessors if pred not in to_nodes}) + next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes}) current_nodes = next_nodes return subgraph_all_nodes @@ -309,19 +319,20 @@ def add_nodes_from_to( def find_float_subgraph_with_unique_terminal_node( nx_graph: nx.MultiDiGraph, processed_terminal_nodes: Set[IntermediateNode], -) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]: +) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]: """Find a subgraph of the graph with float computations. Args: nx_graph (nx.MultiDiGraph): The networkx graph to search in. - processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which + processed_terminal_nodes (Dict[IntermediateNode, None]): The set of terminal nodes for which subgraphs have already been searched, those will be skipped. Returns: - Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]: - None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple - containing the set of nodes beginning a float subgraph, the terminal node of the - subgraph and the set of all the nodes in the subgraph. + Optional[ + Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]: + None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a + tuple containing the set of nodes beginning a float subgraph, the terminal node of + the subgraph and the set of all the nodes in the subgraph. """ def is_float_to_single_int_node(node: IntermediateNode) -> bool: @@ -352,12 +363,13 @@ def find_float_subgraph_with_unique_terminal_node( equivalent_digraph_without_constants = nx.DiGraph(nx_graph) constant_graph_nodes = [ constant_node - for constant_node in equivalent_digraph_without_constants + for constant_node in equivalent_digraph_without_constants.nodes() if isinstance(constant_node, Constant) ] equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes) - subgraph_all_nodes: Set[IntermediateNode] = set() + # Use dict as ordered set + subgraph_all_nodes: Dict[IntermediateNode, None] = {} start_single_int_output_nodes_search_from = terminal_node @@ -397,7 +409,7 @@ def find_float_subgraph_with_unique_terminal_node( # if lca is not None, add the nodes from the current start nodes to the lca to # subgraph_all_nodes subgraph_all_nodes = add_nodes_from_to( - nx_graph, float_subgraph_start_nodes, {lca}, subgraph_all_nodes + nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes ) # if the lca is a valid starting node for fusing break @@ -410,12 +422,12 @@ def find_float_subgraph_with_unique_terminal_node( # integer output e.g.) start_single_int_output_nodes_search_from = lca - return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes + return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes def subgraph_nodes_and_values_allow_fusing( - float_subgraph_start_nodes: Set[IntermediateNode], - subgraph_all_nodes: Set[IntermediateNode], + float_subgraph_start_nodes: Dict[IntermediateNode, None], + subgraph_all_nodes: Dict[IntermediateNode, None], node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], ) -> bool: """Check if a subgraph's values are compatible with fusing. @@ -424,8 +436,9 @@ def subgraph_nodes_and_values_allow_fusing( can be applied per cell, hence shuffling or tensor shape changes make fusing impossible. Args: - float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph. - subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph. + float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float + subgraph. + subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph. node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill with potential nodes issues preventing fusing. @@ -544,14 +557,14 @@ def subgraph_nodes_and_values_allow_fusing( def subgraph_has_unique_variable_input( - float_subgraph_start_nodes: Set[IntermediateNode], + float_subgraph_start_nodes: Dict[IntermediateNode, None], terminal_node: IntermediateNode, node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]], ) -> bool: """Check that only one of the nodes starting the subgraph is variable. Args: - float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph. + float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph. terminal_node (IntermediateNode): The node ending the float subgraph. node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill with potential nodes issues preventing fusing. diff --git a/tests/common/optimization/test_float_fusing.py b/tests/common/optimization/test_float_fusing.py index 062e77890..2f3902ae8 100644 --- a/tests/common/optimization/test_float_fusing.py +++ b/tests/common/optimization/test_float_fusing.py @@ -1,12 +1,14 @@ """Test file for float subgraph fusing""" import random +from copy import deepcopy from inspect import signature import numpy import pytest from concrete.common.data_types.integers import Integer +from concrete.common.debugging import format_operation_graph from concrete.common.debugging.custom_assert import assert_not_reached from concrete.common.optimization.topological import fuse_float_operations from concrete.common.values import EncryptedScalar, EncryptedTensor @@ -386,9 +388,14 @@ def test_fuse_float_operations( function_to_trace, params, ) + copied_graph = deepcopy(op_graph) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) fused_num_nodes = len(op_graph.graph) + fuse_float_operations(copied_graph) + + # Check determinism + assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) if fused: assert fused_num_nodes < orig_num_nodes @@ -504,9 +511,14 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape): for param_name in params_names }, ) + copied_graph = deepcopy(op_graph) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) fused_num_nodes = len(op_graph.graph) + fuse_float_operations(copied_graph) + + # Check determinism + assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) assert fused_num_nodes < orig_num_nodes @@ -643,9 +655,14 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape): for param_name in params_names }, ) + copied_graph = deepcopy(op_graph) orig_num_nodes = len(op_graph.graph) fuse_float_operations(op_graph) fused_num_nodes = len(op_graph.graph) + fuse_float_operations(copied_graph) + + # Check determinism + assert format_operation_graph(copied_graph) == format_operation_graph(op_graph) assert fused_num_nodes < orig_num_nodes