mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
refactor: make float fusing deterministic
- consider having an OrderedSet class instead of using dicts for it; closes #1438
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Code to wrap and make manipulating networkx graphs easier."""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Callable, Dict, Iterable, List, Set, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -126,7 +126,7 @@ class OPGraph:
|
||||
"""
|
||||
# Replication of pred is managed e.g. x + x will yield the proper pred x twice
|
||||
idx_to_pred: Dict[int, IntermediateNode] = {}
|
||||
for pred in self.graph.pred[node]:
|
||||
for pred in self.graph.predecessors(node):
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_pred.update((data["input_idx"], pred) for data in edge_data.values())
|
||||
return [idx_to_pred[i] for i in range(len(idx_to_pred))]
|
||||
@@ -144,7 +144,7 @@ class OPGraph:
|
||||
"""
|
||||
|
||||
idx_to_inp: Dict[int, Tuple[IntermediateNode, int]] = {}
|
||||
for pred in self.graph.pred[node]:
|
||||
for pred in self.graph.predecessors(node):
|
||||
edge_data = self.graph.get_edge_data(pred, node)
|
||||
idx_to_inp.update(
|
||||
(data["input_idx"], (pred, data["output_idx"])) for data in edge_data.values()
|
||||
@@ -190,7 +190,7 @@ class OPGraph:
|
||||
for node in nx.topological_sort(self.graph):
|
||||
if not isinstance(node, Input):
|
||||
curr_inputs = {}
|
||||
for pred_node in self.graph.pred[node]:
|
||||
for pred_node in self.graph.predecessors(node):
|
||||
edges = self.graph.get_edge_data(pred_node, node)
|
||||
curr_inputs.update(
|
||||
{
|
||||
@@ -296,7 +296,7 @@ class OPGraph:
|
||||
|
||||
node.outputs[0] = deepcopy(node.inputs[0])
|
||||
|
||||
successors = self.graph.succ[node]
|
||||
successors = self.graph.successors(node)
|
||||
for succ in successors:
|
||||
edge_data = self.graph.get_edge_data(node, succ)
|
||||
for edge in edge_data.values():
|
||||
@@ -306,14 +306,14 @@ class OPGraph:
|
||||
def prune_nodes(self):
|
||||
"""Remove unreachable nodes from outputs."""
|
||||
|
||||
current_nodes = set(self.output_nodes.values())
|
||||
useful_nodes: Set[IntermediateNode] = set()
|
||||
current_nodes = {node: None for node in self.get_ordered_outputs()}
|
||||
useful_nodes: Dict[IntermediateNode, None] = {}
|
||||
while current_nodes:
|
||||
next_nodes: Set[IntermediateNode] = set()
|
||||
next_nodes: Dict[IntermediateNode, None] = {}
|
||||
useful_nodes.update(current_nodes)
|
||||
for node in current_nodes:
|
||||
next_nodes.update(self.graph.pred[node])
|
||||
next_nodes.update({node: None for node in self.graph.predecessors(node)})
|
||||
current_nodes = next_nodes
|
||||
|
||||
useless_nodes = set(self.graph.nodes()) - useful_nodes
|
||||
useless_nodes = [node for node in self.graph.nodes() if node not in useful_nodes]
|
||||
self.graph.remove_nodes_from(useless_nodes)
|
||||
|
||||
@@ -95,18 +95,18 @@ def fuse_float_operations(
|
||||
|
||||
def convert_float_subgraph_to_fused_node(
|
||||
op_graph: OPGraph,
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
terminal_node: IntermediateNode,
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Optional[Tuple[GenericFunction, IntermediateNode]]:
|
||||
"""Convert a float subgraph to an equivalent fused GenericFunction node.
|
||||
|
||||
Args:
|
||||
op_graph (OPGraph): The OPGraph the float subgraph is part of.
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph
|
||||
in `op_graph`.
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
|
||||
subgraph in `op_graph`.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[GenericFunction, IntermediateNode]]: None if the float subgraph
|
||||
@@ -151,11 +151,21 @@ def convert_float_subgraph_to_fused_node(
|
||||
|
||||
nx_graph = op_graph.graph
|
||||
|
||||
nodes_after_input_set = subgraph_all_nodes.intersection(
|
||||
nx_graph.succ[current_subgraph_variable_input]
|
||||
)
|
||||
nodes_after_input_set = {
|
||||
node: None
|
||||
for node in subgraph_all_nodes
|
||||
if node in nx_graph.succ[current_subgraph_variable_input]
|
||||
}
|
||||
|
||||
float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
|
||||
# # Previous non-deterministic implementation :
|
||||
# # For some reason creating a graph from a subgraph this way is not deterministic
|
||||
# float_subgraph = nx.MultiDiGraph(nx_graph.subgraph(subgraph_all_nodes))
|
||||
|
||||
# Create a copy of the graph, remove nodes that are not in all the subgraph nodes in order to
|
||||
# get a subgraph deterministically
|
||||
float_subgraph = nx.MultiDiGraph(nx_graph)
|
||||
nodes_to_remove = [node for node in float_subgraph.nodes() if node not in subgraph_all_nodes]
|
||||
float_subgraph.remove_nodes_from(nodes_to_remove)
|
||||
|
||||
new_subgraph_variable_input = Input(new_input_value, "float_subgraph_input", 0)
|
||||
float_subgraph.add_node(new_subgraph_variable_input)
|
||||
@@ -226,20 +236,20 @@ def is_single_int_output_node(node: IntermediateNode) -> bool:
|
||||
def find_closest_single_int_output_nodes(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
start_nodes: List[IntermediateNode],
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
) -> Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]:
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]:
|
||||
"""Find in nx_graph the closest upstream single integer output nodes to some start nodes.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): the networkx graph to search in.
|
||||
start_nodes (List[IntermediateNode]): the nodes from which to start the search.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): a set that will be updated with all the nodes
|
||||
visited during the search.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): a set that will be updated with all the
|
||||
nodes visited during the search.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[IntermediateNode, None], Set[IntermediateNode]]: returns the dict used as an
|
||||
ordered set containing the found single output nodes and the updated set of the visited
|
||||
nodes during the search.
|
||||
Tuple[Dict[IntermediateNode, None], Dict[IntermediateNode, None]]: returns the dict used as
|
||||
an ordered set containing the found single output nodes and the updated set of the
|
||||
visited nodes during the search.
|
||||
"""
|
||||
|
||||
# Use dict as ordered set
|
||||
@@ -252,13 +262,13 @@ def find_closest_single_int_output_nodes(
|
||||
if node in visited_nodes:
|
||||
continue
|
||||
visited_nodes.add(node)
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
subgraph_all_nodes.update({node: None})
|
||||
predecessors = nx_graph.predecessors(node)
|
||||
for pred in predecessors:
|
||||
if is_single_int_output_node(pred):
|
||||
# Limit of subgraph, record that and record the node as we won't visit it
|
||||
closest_single_int_output_nodes[pred] = None
|
||||
subgraph_all_nodes.add(pred)
|
||||
closest_single_int_output_nodes.update({pred: None})
|
||||
subgraph_all_nodes.update({pred: None})
|
||||
else:
|
||||
next_nodes.update({pred: None})
|
||||
current_nodes = next_nodes
|
||||
@@ -269,21 +279,21 @@ def find_closest_single_int_output_nodes(
|
||||
def add_nodes_from_to(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
from_nodes: Iterable[IntermediateNode],
|
||||
to_nodes: Set[IntermediateNode],
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
) -> Set[IntermediateNode]:
|
||||
to_nodes: Dict[IntermediateNode, None],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
) -> Dict[IntermediateNode, None]:
|
||||
"""Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): the graph to traverse.
|
||||
from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to
|
||||
subgraph_all_nodes.
|
||||
to_nodes (Set[IntermediateNode]): the nodes we should stop at.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph, will be
|
||||
updated and returned.
|
||||
to_nodes (Dict[IntermediateNode, None]): the nodes we should stop at.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph, will
|
||||
be updated and returned.
|
||||
|
||||
Returns:
|
||||
Set[IntermediateNode]: returns the updated subgraph_all_nodes.
|
||||
Dict[IntermediateNode, None]: returns the updated subgraph_all_nodes.
|
||||
"""
|
||||
|
||||
# Add the end nodes we won't visit
|
||||
@@ -297,10 +307,10 @@ def add_nodes_from_to(
|
||||
if node in visited_nodes:
|
||||
continue
|
||||
visited_nodes.add(node)
|
||||
subgraph_all_nodes.add(node)
|
||||
predecessors = nx_graph.pred[node]
|
||||
subgraph_all_nodes.update({node: None})
|
||||
predecessors = nx_graph.predecessors(node)
|
||||
# Add nodes to explore next if they are not indicated as end nodes
|
||||
next_nodes.update({pred: node for pred in predecessors if pred not in to_nodes})
|
||||
next_nodes.update({pred: None for pred in predecessors if pred not in to_nodes})
|
||||
current_nodes = next_nodes
|
||||
|
||||
return subgraph_all_nodes
|
||||
@@ -309,19 +319,20 @@ def add_nodes_from_to(
|
||||
def find_float_subgraph_with_unique_terminal_node(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
processed_terminal_nodes: Set[IntermediateNode],
|
||||
) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
|
||||
) -> Optional[Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
|
||||
"""Find a subgraph of the graph with float computations.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
|
||||
processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
|
||||
processed_terminal_nodes (Dict[IntermediateNode, None]): The set of terminal nodes for which
|
||||
subgraphs have already been searched, those will be skipped.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
|
||||
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a tuple
|
||||
containing the set of nodes beginning a float subgraph, the terminal node of the
|
||||
subgraph and the set of all the nodes in the subgraph.
|
||||
Optional[
|
||||
Tuple[Dict[IntermediateNode, None], IntermediateNode, Dict[IntermediateNode, None]]]:
|
||||
None if there are no float subgraphs to process in `nx_graph`. Otherwise returns a
|
||||
tuple containing the set of nodes beginning a float subgraph, the terminal node of
|
||||
the subgraph and the set of all the nodes in the subgraph.
|
||||
"""
|
||||
|
||||
def is_float_to_single_int_node(node: IntermediateNode) -> bool:
|
||||
@@ -352,12 +363,13 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
equivalent_digraph_without_constants = nx.DiGraph(nx_graph)
|
||||
constant_graph_nodes = [
|
||||
constant_node
|
||||
for constant_node in equivalent_digraph_without_constants
|
||||
for constant_node in equivalent_digraph_without_constants.nodes()
|
||||
if isinstance(constant_node, Constant)
|
||||
]
|
||||
equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes)
|
||||
|
||||
subgraph_all_nodes: Set[IntermediateNode] = set()
|
||||
# Use dict as ordered set
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None] = {}
|
||||
|
||||
start_single_int_output_nodes_search_from = terminal_node
|
||||
|
||||
@@ -397,7 +409,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
# if lca is not None, add the nodes from the current start nodes to the lca to
|
||||
# subgraph_all_nodes
|
||||
subgraph_all_nodes = add_nodes_from_to(
|
||||
nx_graph, float_subgraph_start_nodes, {lca}, subgraph_all_nodes
|
||||
nx_graph, float_subgraph_start_nodes, {lca: None}, subgraph_all_nodes
|
||||
)
|
||||
|
||||
# if the lca is a valid starting node for fusing break
|
||||
@@ -410,12 +422,12 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
# integer output e.g.)
|
||||
start_single_int_output_nodes_search_from = lca
|
||||
|
||||
return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes
|
||||
return float_subgraph_start_nodes, terminal_node, subgraph_all_nodes
|
||||
|
||||
|
||||
def subgraph_nodes_and_values_allow_fusing(
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
subgraph_all_nodes: Set[IntermediateNode],
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
subgraph_all_nodes: Dict[IntermediateNode, None],
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check if a subgraph's values are compatible with fusing.
|
||||
@@ -424,8 +436,9 @@ def subgraph_nodes_and_values_allow_fusing(
|
||||
can be applied per cell, hence shuffling or tensor shape changes make fusing impossible.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the float subgraph.
|
||||
subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph.
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the float
|
||||
subgraph.
|
||||
subgraph_all_nodes (Dict[IntermediateNode, None]): All the nodes in the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
@@ -544,14 +557,14 @@ def subgraph_nodes_and_values_allow_fusing(
|
||||
|
||||
|
||||
def subgraph_has_unique_variable_input(
|
||||
float_subgraph_start_nodes: Set[IntermediateNode],
|
||||
float_subgraph_start_nodes: Dict[IntermediateNode, None],
|
||||
terminal_node: IntermediateNode,
|
||||
node_with_issues_for_fusing: DefaultDict[IntermediateNode, List[str]],
|
||||
) -> bool:
|
||||
"""Check that only one of the nodes starting the subgraph is variable.
|
||||
|
||||
Args:
|
||||
float_subgraph_start_nodes (Set[IntermediateNode]): The nodes starting the subgraph.
|
||||
float_subgraph_start_nodes (Dict[IntermediateNode, None]): The nodes starting the subgraph.
|
||||
terminal_node (IntermediateNode): The node ending the float subgraph.
|
||||
node_with_issues_for_fusing (DefaultDict[IntermediateNode, List[str]]): Dictionary to fill
|
||||
with potential nodes issues preventing fusing.
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Test file for float subgraph fusing"""
|
||||
|
||||
import random
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
|
||||
import numpy
|
||||
import pytest
|
||||
|
||||
from concrete.common.data_types.integers import Integer
|
||||
from concrete.common.debugging import format_operation_graph
|
||||
from concrete.common.debugging.custom_assert import assert_not_reached
|
||||
from concrete.common.optimization.topological import fuse_float_operations
|
||||
from concrete.common.values import EncryptedScalar, EncryptedTensor
|
||||
@@ -386,9 +388,14 @@ def test_fuse_float_operations(
|
||||
function_to_trace,
|
||||
params,
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
if fused:
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
@@ -504,9 +511,14 @@ def subtest_fuse_float_unary_operations_correctness(fun, tensor_shape):
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
@@ -643,9 +655,14 @@ def subtest_fuse_float_binary_operations_correctness(fun, tensor_shape):
|
||||
for param_name in params_names
|
||||
},
|
||||
)
|
||||
copied_graph = deepcopy(op_graph)
|
||||
orig_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(op_graph)
|
||||
fused_num_nodes = len(op_graph.graph)
|
||||
fuse_float_operations(copied_graph)
|
||||
|
||||
# Check determinism
|
||||
assert format_operation_graph(copied_graph) == format_operation_graph(op_graph)
|
||||
|
||||
assert fused_num_nodes < orig_num_nodes
|
||||
|
||||
|
||||
Reference in New Issue
Block a user