refactor: make float fusing deterministic

- consider having an OrderedSet class instead of using dicts for it

closes #1438
This commit is contained in:
Arthur Meyre
2022-02-24 14:57:06 +01:00
committed by Umut
parent 8af6d83ed6
commit 19c78e6dad
3 changed files with 86 additions and 56 deletions

View File

@@ -1,7 +1,7 @@
"""Code to wrap and make manipulating networkx graphs easier."""
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import networkx as nx
@@ -126,7 +126,7 @@ class OPGraph:
"""
# Replication of pred is managed e.g. x + x will yield the proper pred x twice
idx_to_pred: Dict[int, IntermediateNode] = {}
for pred in self.graph.pred[node]:
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
@@ -144,7 +144,7 @@ class OPGraph:
"""
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}
for pred in self.graph.pred[node]:
for pred in self.graph.predecessors(node):
edge_data = self.graph.get_edge_data(pred, node)
idx_to_inp.update(
(data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values()
@@ -190,7 +190,7 @@ class OPGraph:
for node in nx.topological_sort(self.graph):
if not isinstance(node, Input):
curr_inputs = {}
for pred_node in self.graph.pred[node]:
for pred_node in self.graph.predecessors(node):
edges = self.graph.get_edge_data(pred_node, node)
curr_inputs.update(
{
@@ -296,7 +296,7 @@ class OPGraph:
node.outputs[0] = deepcopy(node.inputs[0])
successors = self.graph.succ[node]
successors = self.graph.successors(node)
for succ in successors:
edge_data = self.graph.get_edge_data(node, succ)
for edge in edge_data.values():
@@ -306,14 +306,14 @@ class OPGraph:
def prune_nodes(self):
"""Remove unreachable nodes from outputs."""
current_nodes = set(self.output_nodes.values())
useful_nodes: Set[IntermediateNode] = set()
current_nodes = {node: None for node in self.get_ordered_outputs()}
useful_nodes: Dict[IntermediateNode, None] = {}
while current_nodes:
next_nodes: Set[IntermediateNode] = set()
next_nodes: Dict[IntermediateNode, None] = {}
useful_nodes.update(current_nodes)
for node in current_nodes:
next_nodes.update(self.graph.pred[node])
next_nodes.update({node: None for node in self.graph.predecessors(node)})
current_nodes = next_nodes
useless_nodes = set(self.graph.nodes()) - useful_nodes
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
self.graph.remove_nodes_from(useless_nodes)

View File

@@ -95,18 +95,18 @@ def fuse_float_operations(
def convert_float_subgraph_to_fused_node(
op_graph: OPGraph,
float_subgraph_start_nodes: Set[IntermediateNode],
float_subgraph_start_nodes: Dict[IntermediateNode, None],
terminal_node: IntermediateNode,
subgraph_all_nodes: Set[IntermediateNode],
subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Optional[Tuple[GenericFunction, IntermediateNode]]:
"""Convert a float subgraph to an equivalent fused GenericFunction node.
Args:
op_graph (OPGraph): The OPGraph the float subgraph is part of.
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph
in `op_graph`.
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
subgraph in `op_graph`.
terminal_node (IntermediateNode): The node ending the float subgraph.
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
Returns:
Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph
@@ -151,11 +151,21 @@ def convert_float_subgraph_to_fused_node(
nx_graph = op_graph.graph
nodes_after_input_set = subgraph_all_nodes.intersection(
nx_graph.succ[current_subgraph_variable_input]
)
nodes_after_input_set = {
node: None
for node in subgraph_all_nodes
if node in nx_graph.succ[current_subgraph_variable_input]
}
float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
# # Previous non-deterministic implementation :
# # For some reason creating a graph from a subgraph this way is not deterministic
# float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
# Copy the whole graph, then remove every node that is not in subgraph_all_nodes, in order to
# get the subgraph deterministically
float_subgraph = nx.MultiDiGraph(nx_graph)
nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes]
float_subgraph.remove_nodes_from(nodes_to_remove)
new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0)
float_subgraph.add_node(new_subgraph_variable_input)
@@ -226,20 +236,20 @@ def is_single_int_output_node(node: IntermediateNode) -> bool:
def find_closest_single_int_output_nodes(
nx_graph: nx.MultiDiGraph,
start_nodes: List[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]:
subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]:
"""Find in nx_graph the closest upstream single integer output nodes to some start nodes.
Args:
nx_graph (nx.MultiDiGraph): the networkx graph to search in.
start_nodes (List[IntermediateNode]): the nodes from which to start the search.
subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes
visited during the search.
subgraph_all_nodes (Dict[IntermediateNode, None]): a dict used as an ordered set that will be
updated with all the nodes visited during the search.
Returns:
Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an
ordered set containing the found single output nodes and the updated set of the visited
nodes during the search.
Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: returns the dict used as
an ordered set containing the found single output nodes and the updated set of the
visited nodes during the search.
"""
# Use dict as ordered set
@@ -252,13 +262,13 @@ def find_closest_single_int_output_nodes(
if node in visited_nodes:
continue
visited_nodes.add(node)
subgraph_all_nodes.add(node)
predecessors = nx_graph.pred[node]
subgraph_all_nodes.update({node: None})
predecessors = nx_graph.predecessors(node)
for pred in predecessors:
if is_single_int_output_node(pred):
# Limit of subgraph, record that and record the node as we won't visit it
closest_single_int_output_nodes[pred] = None
subgraph_all_nodes.add(pred)
closest_single_int_output_nodes.update({pred: None})
subgraph_all_nodes.update({pred: None})
else:
next_nodes.update({pred: None})
current_nodes = next_nodes
@@ -269,21 +279,21 @@ def find_closest_single_int_output_nodes(
def add_nodes_from_to(
nx_graph: nx.MultiDiGraph,
from_nodes: Iterable[IntermediateNode],
to_nodes: Set[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
) -> Set[IntermediateNode]:
to_nodes: Dict[IntermediateNode, None],
subgraph_all_nodes: Dict[IntermediateNode, None],
) -> Dict[IntermediateNode, None]:
"""Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set.
Args:
nx_graph (nx.MultiDiGraph): the graph to traverse.
from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to
subgraph_all_nodes.
to_nodes (Set[IntermediateNode]): the nodes we should stop at.
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph, will be
updated and returned.
to_nodes (Dict[IntermediateNode, None]): the nodes we should stop at.
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph, will
be updated and returned.
Returns:
Set[IntermediateNode]: returns the updated subgraph_all_nodes.
Dict[IntermediateNode, None]: returns the updated subgraph_all_nodes.
"""
# Add the end nodes we won't visit
@@ -297,10 +307,10 @@ def add_nodes_from_to(
if node in visited_nodes:
continue
visited_nodes.add(node)
subgraph_all_nodes.add(node)
predecessors = nx_graph.pred[node]
subgraph_all_nodes.update({node: None})
predecessors = nx_graph.predecessors(node)
# Add nodes to explore next if they are not indicated as end nodes
next_nodes.update({pred: node for pred in predecessors if pred not in to_nodes})
next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes})
current_nodes = next_nodes
return subgraph_all_nodes
@@ -309,19 +319,20 @@ def add_nodes_from_to(
def find_float_subgraph_with_unique_terminal_node(
nx_graph: nx.MultiDiGraph,
processed_terminal_nodes: Set[IntermediateNode],
) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
"""Find a subgraph of the graph with float computations.
Args:
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
subgraphs have already been searched, those will be skipped.
Returns:
Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple
containing the set of nodes beginning a float subgraph, the terminal node of the
subgraph and the set of all the nodes in the subgraph.
Optional[
Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a
tuple containing the set of nodes beginning a float subgraph, the terminal node of
the subgraph and the set of all the nodes in the subgraph.
"""
def is_float_to_single_int_node(node: IntermediateNode) -> bool:
@@ -352,12 +363,13 @@ def find_float_subgraph_with_unique_terminal_node(
equivalent_digraph_without_constants = nx.DiGraph(nx_graph)
constant_graph_nodes = [
constant_node
for constant_node in equivalent_digraph_without_constants
for constant_node in equivalent_digraph_without_constants.nodes()
if isinstance(constant_node, Constant)
]
equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes)
subgraph_all_nodes: Set[IntermediateNode] = set()
# Use dict as ordered set
subgraph_all_nodes: Dict[IntermediateNode, None] = {}
start_single_int_output_nodes_search_from = terminal_node
@@ -397,7 +409,7 @@ def find_float_subgraph_with_unique_terminal_node(
# if lca is not None, add the nodes from the current start nodes to the lca to
# subgraph_all_nodes
subgraph_all_nodes = add_nodes_from_to(
nx_graph, float_subgraph_start_nodes, {lca}, subgraph_all_nodes
nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes
)
# if the lca is a valid starting node for fusing break
@@ -410,12 +422,12 @@ def find_float_subgraph_with_unique_terminal_node(
# integer output e.g.)
start_single_int_output_nodes_search_from = lca
return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
def subgraph_nodes_and_values_allow_fusing(
float_subgraph_start_nodes: Set[IntermediateNode],
subgraph_all_nodes: Set[IntermediateNode],
float_subgraph_start_nodes: Dict[IntermediateNode, None],
subgraph_all_nodes: Dict[IntermediateNode, None],
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
"""Check if a subgraph's values are compatible with fusing.
@@ -424,8 +436,9 @@ def subgraph_nodes_and_values_allow_fusing(
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
Args:
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
subgraph.
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
with potential nodes issues preventing fusing.
@@ -544,14 +557,14 @@ def subgraph_nodes_and_values_allow_fusing(
def subgraph_has_unique_variable_input(
float_subgraph_start_nodes: Set[IntermediateNode],
float_subgraph_start_nodes: Dict[IntermediateNode, None],
terminal_node: IntermediateNode,
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
) -> bool:
"""Check that only one of the nodes starting the subgraph is variable.
Args:
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph.
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph.
terminal_node (IntermediateNode): The node ending the float subgraph.
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
with potential nodes issues preventing fusing.

View File

@@ -1,12 +1,14 @@
"""Test file for float subgraph fusing"""
import random
from copy import deepcopy
from inspect import signature
import numpy
import pytest
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import format_operation_graph
from concrete.common.debugging.custom_assert import assert_not_reached
from concrete.common.optimization.topological import fuse_float_operations
from concrete.common.values import EncryptedScalar, EncryptedTensor
@@ -386,9 +388,14 @@ def test_fuse_float_operations(
function_to_trace,
params,
)
copied_graph = deepcopy(op_graph)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)
fuse_float_operations(copied_graph)
# Check determinism
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
if fused:
assert fused_num_nodes < orig_num_nodes
@@ -504,9 +511,14 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
for param_name in params_names
},
)
copied_graph = deepcopy(op_graph)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)
fuse_float_operations(copied_graph)
# Check determinism
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
assert fused_num_nodes < orig_num_nodes
@@ -643,9 +655,14 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
for param_name in params_names
},
)
copied_graph = deepcopy(op_graph)
orig_num_nodes = len(op_graph.graph)
fuse_float_operations(op_graph)
fused_num_nodes = len(op_graph.graph)
fuse_float_operations(copied_graph)
# Check determinism
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
assert fused_num_nodes < orig_num_nodes