mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
631 lines
20 KiB
Python
631 lines
20 KiB
Python
"""
|
|
Declaration of various functions and constants related to compilation.
|
|
"""
|
|
|
|
from copy import deepcopy
|
|
from typing import Dict, Iterable, List, Optional, Set, Tuple
|
|
|
|
import networkx as nx
|
|
|
|
from ..dtypes import Float, Integer
|
|
from ..representation import Graph, Node, Operation
|
|
from .artifacts import DebugArtifacts
|
|
|
|
|
|
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
    """
    Fuse appropriate subgraphs in a graph to a single Operation.Generic node.

    Fusing runs in two phases, each repeated until its finder reports no more candidates:
        1. float subgraphs ending with an integer output
        2. table lookups with multiple variable inputs sharing a single common ancestor

    Every fused subgraph is replaced by one generic node which takes over the successors
    (and output slots) of the subgraph's terminal node. Nodes left unused are pruned at the end.

    Args:
        graph (Graph):
            graph to search and update

        artifacts (Optional[DebugArtifacts], default = None):
            compilation artifacts to store information about the fusing process
    """

    nx_graph = graph.graph
    processed_terminal_nodes: Set[Node] = set()

    phases = (
        find_float_subgraph_with_unique_terminal_node,
        find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor,
    )

    for find_subgraph in phases:
        # each phase starts with a fresh set of already-examined terminal nodes
        processed_terminal_nodes.clear()

        while True:
            found = find_subgraph(nx_graph, processed_terminal_nodes)
            if found is None:
                break

            all_nodes, start_nodes, terminal_node = found

            # never examine the same terminal node twice within a phase,
            # whether or not fusing succeeds
            processed_terminal_nodes.add(terminal_node)

            conversion = convert_subgraph_to_subgraph_node(
                nx_graph,
                all_nodes,
                start_nodes,
                terminal_node,
            )
            if conversion is None:
                continue

            fused_node, node_before_subgraph = conversion
            nx_graph.add_node(fused_node)

            # every output slot that pointed at the terminal node now points at the fused node
            for output_idx, output_node in list(graph.output_nodes.items()):
                if output_node == terminal_node:
                    graph.output_nodes[output_idx] = fused_node

            # move every outgoing edge of the terminal node onto the fused node, keeping keys/data
            for successor in list(nx_graph.successors(terminal_node)):
                edges_by_key = deepcopy(nx_graph.get_edge_data(terminal_node, successor))
                for edge_key, edge_attributes in edges_by_key.items():
                    nx_graph.remove_edge(terminal_node, successor, key=edge_key)
                    nx_graph.add_edge(
                        fused_node,
                        successor,
                        key=edge_key,
                        **deepcopy(edge_attributes),
                    )

            # the fused node consumes the single variable input of the subgraph
            nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)

    graph.prune_useless_nodes()
    if artifacts is not None:
        artifacts.add_graph("after-fusing", graph)
|
|
|
|
|
|
def find_float_subgraph_with_unique_terminal_node(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with float computations that end with an integer output.

    The terminal node is the first not-yet-processed node with at least one float input and an
    integer output. The search then walks upstream to the closest integer-output nodes. If more
    than one of them is variable, it tries to widen the subgraph up to their single lowest
    common ancestor.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for float subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
            (start nodes may not contain exactly one variable node, in which case
            the subgraph is not fusable)
    """

    terminal_nodes = (
        node
        for node in nx_graph.nodes()
        if (
            node not in processed_terminal_nodes
            and any(isinstance(value.dtype, Float) for value in node.inputs)
            and isinstance(node.output.dtype, Integer)
        )
    )
    terminal_node = next(terminal_nodes, None)
    if terminal_node is None:
        return None

    all_nodes: Dict[Node, None] = {}

    start_single_int_output_nodes_search_from = terminal_node
    while True:
        all_nodes, start_nodes = find_closest_integer_output_nodes(
            nx_graph,
            [start_single_int_output_nodes_search_from],
            all_nodes,
        )

        variable_start_nodes = [
            start_node for start_node in start_nodes if start_node.operation != Operation.Constant
        ]
        if len(variable_start_nodes) == 1:
            break

        # the float computation does not depend on any variable integer node
        # (e.g., it only uses constants), so there is no ancestor to look for;
        # asking for the lca of an empty set would yield an arbitrary node
        # (possibly a descendant), which could make this loop never terminate
        if not variable_start_nodes:
            break

        # find a common ancestor as we need a single variable input node
        # lca == lowest common ancestor
        lca = find_single_lca(nx_graph, variable_start_nodes)

        # if subgraph cannot be fused because there is no way to find a common ancestor, break
        if lca is None:
            break

        # add the nodes from the `start_nodes` to `lca`, to `all_nodes`
        all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)

        # if `lca` is a valid starting node for fusing break
        if isinstance(lca.output.dtype, Integer):
            # `lca` is the new start node
            start_nodes = {lca: None}
            break

        # otherwise, push a little further
        # (e.g., if there is a node just before, which has an integer output)
        # `lca` is a strict ancestor of the previous search origin, so this moves upstream
        start_single_int_output_nodes_search_from = lca

    return all_nodes, start_nodes, terminal_node
|
|
|
|
|
|
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with a tlu computation that has multiple variable inputs \
    where all variable inputs share a common ancestor.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for tlu subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
            (start nodes are empty when no integer single common ancestor exists,
            which makes the subgraph not fusable)
    """

    terminal_nodes = (
        node
        for node in nx_graph.nodes()
        if (
            node not in processed_terminal_nodes
            and node.converted_to_table_lookup
            and all(isinstance(value.dtype, Integer) for value in node.inputs)
            and isinstance(node.output.dtype, Integer)
            and len(
                [
                    pred
                    for pred in nx_graph.predecessors(node)
                    if pred.operation != Operation.Constant
                ]
            )
            > 1
        )
    )
    terminal_node = next(terminal_nodes, None)
    if terminal_node is None:
        return None

    all_nodes: Dict[Node, None] = {}
    start_nodes: Dict[Node, None] = {}

    predecessors = list(nx_graph.predecessors(terminal_node))

    # constant inputs have no ancestors and do not affect fusability,
    # so only the variable inputs need to share a single common ancestor
    variable_start_nodes = [pred for pred in predecessors if pred.operation != Operation.Constant]

    # find a common ancestor as we need a single variable input node
    # lca == lowest common ancestor
    lca = find_single_lca(nx_graph, variable_start_nodes)

    # if subgraph cannot be fused because there is no way to find a common ancestor, stop
    if lca is None:
        return all_nodes, start_nodes, terminal_node

    # add the nodes from the predecessors of the terminal node to `lca`, to `all_nodes`
    all_nodes = add_nodes_from_to(nx_graph, predecessors, {lca: None}, all_nodes)
    all_nodes[terminal_node] = None

    # `lca` is a valid starting node for fusing only if it has an integer output;
    # otherwise (e.g., an unfused float node) the subgraph cannot be fused, and
    # retrying would only repeat the exact same search forever
    if isinstance(lca.output.dtype, Integer):
        start_nodes = {lca: None}

    return all_nodes, start_nodes, terminal_node
|
|
|
|
|
|
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
    """
    Find the single lowest common ancestor of a list of nodes.

    Constant nodes in `nodes` are ignored. They have no ancestors and do not affect
    fusability, so keeping them would make every search fail.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search for single lca

        nodes (List[Node]):
            nodes to find the single lca of

    Returns:
        Optional[Node]:
            single lca if it exists, None otherwise
            (including when `nodes` contains no non-constant node)
    """

    variable_nodes = [node for node in nodes if node.operation != Operation.Constant]

    # the lca of an empty set is undefined; without this guard every non-constant
    # node of the graph would count as a "common ancestor" of nothing
    if not variable_nodes:
        return None

    # find common ancestors among `variable_nodes`
    # nodes themselves are included because the single lca can be within `nodes`
    # if the single lca exists, it's in this set
    common_ancestors = set.intersection(
        *(nx.ancestors(nx_graph, node) | {node} for node in variable_nodes)
    )
    common_ancestors = {
        node for node in common_ancestors if node.operation != Operation.Constant
    }

    # iterate over every node in the graph in reversed topological order
    # this is to ensure result, if found, is the single "lowest" common ancestor
    for candidate in reversed(list(nx.topological_sort(nx_graph))):
        # a node which is not a common ancestor cannot be the single lca
        if candidate not in common_ancestors:
            continue

        # the first single common ancestor found is the lowest one
        if is_single_common_ancestor(nx_graph, candidate, variable_nodes):
            return candidate

    # if none of the nodes in `common_ancestors` is the single lca
    # there is no single lca of this set of nodes, so return None
    return None
|
|
|
|
|
|
def is_single_common_ancestor(
    nx_graph: nx.MultiDiGraph,
    candidate: Node,
    nodes: List[Node],
) -> bool:
    """
    Determine if a node is the single common ancestor of a list of nodes.

    Note that this function doesn't care about `lowest` property of `lca`.

    Conceptually, it builds the "path subgraph" made of `candidate`, `nodes`, and every
    node and edge lying on a path from `candidate` to one of `nodes`. It then checks that
    every node of it (except `candidate`) keeps all of its non-constant predecessors.
    The path subgraph is derived from reachability sets instead of enumerating simple paths,
    because the number of simple paths can grow exponentially with the graph size.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to perform the check

        candidate (Node):
            node to determine single common ancestor status

        nodes (List[Node]):
            nodes to determine single common ancestor status against

    Returns
        bool:
            True if `candidate` is a single common ancestor of `nodes`, False otherwise
    """

    # `candidate` together with every node reachable from it
    reachable_from_candidate = nx.descendants(nx_graph, candidate) | {candidate}

    # nodes of the path subgraph:
    # in a DAG, the nodes lying on a path from `candidate` to `node` are exactly
    # the nodes reachable from `candidate` that are also `node` or one of its ancestors
    path_subgraph_nodes = {candidate}
    for node in nodes:
        path_subgraph_nodes.add(node)
        if node in reachable_from_candidate:
            path_subgraph_nodes.update(
                reachable_from_candidate & (nx.ancestors(nx_graph, node) | {node})
            )

    for node in path_subgraph_nodes:
        # the condition below doesn't apply to `candidate`
        # as its predecessors are not in the subgraph
        if node == candidate:
            continue

        # an edge `pred -> node` (with `node` in the path subgraph) lies on a path
        # from `candidate` exactly when `pred` is reachable from `candidate`
        predecessor_count_in_subgraph = sum(
            1 for pred in nx_graph.predecessors(node) if pred in reachable_from_candidate
        )

        # constant nodes of the original graph are not counted as
        # - they are not in the subgraph
        # - they don't affect fusability status
        predecessor_count_in_nx_graph = sum(
            1 for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant
        )

        # if the counts differ, `candidate` cannot be a single common ancestor
        # (the reasoning is explained below)
        if predecessor_count_in_subgraph != predecessor_count_in_nx_graph:
            return False

    # if every node in the subgraph has the same number of predecessors
    # as in the original graph, `candidate` is in fact a single common ancestor
    return True

    # Here is why this function works.
    #
    # Legend:
    # - /|\- = Edge
    # - (...) = Intermediate Node
    # - {...} = Candidate Node
    # - [...] = Node of which single common ancestor is searched
    # - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
    #
    # Consider the following graph:
    #
    #  (3)    (x)    (2)
    #    \    /  \    /
    #    [{*}]    (/)
    #        \    /
    #         [+]
    #
    # - Operation: (x * 3) + (x / 2)
    # - Candidate: {*}
    # - Nodes: [*] and [+]
    #
    # So we want to know if the multiplication node is a single common ancestor of
    # the multiplication and addition nodes. The answer is no in this case, for our purposes.
    #
    # The path subgraph (nodes and edges on paths from the candidate to the nodes) is:
    #
    #  {[*]}
    #    |
    #   [+]
    #
    # In this subgraph, the addition node has only a single predecessor while it has two in the
    # original graph, which means there is a path leading to the addition node that doesn't
    # include the multiplication node, so we conclude the multiplication node is not a single
    # common ancestor.
    #
    # Now, consider the following graph:
    #
    #  (3)    {x}    (2)
    #    \    /  \    /
    #     [*]     (/)
    #        \    /
    #         [+]
    #
    # - Operation: (x * 3) + (x / 2)
    # - Candidate: {x}
    # - Nodes: [*] and [+]
    #
    # So we want to know if the input node 'x' is the single common ancestor of
    # the multiplication and addition nodes. The answer is yes in this case.
    #
    # The path subgraph (nodes and edges on paths from the candidate to the nodes) is:
    #
    #      {x}
    #     /   \
    #   [*]   (/)
    #     \   /
    #      [+]
    #
    # In this subgraph, every node except the candidate node
    # keeps all of its non-constant predecessors,
    # which means all of their non-constant predecessors originated
    # from the `candidate`, so it's a single common ancestor.
    #
    # When you think about it, this implementation makes a lot of sense for our purposes.
    # It basically determines if `nodes` "solely" depend on the `candidate`,
    # which is the condition for fusing.
|
|
|
|
|
|
def find_closest_integer_output_nodes(
    nx_graph: nx.MultiDiGraph,
    start_nodes: List[Node],
    all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
    """
    Find the closest upstream integer output nodes to a set of start nodes in a graph.

    The search goes upstream level by level. A predecessor with an integer output is recorded
    and not expanded further. Any other predecessor is expanded in the next level. Every
    visited node and every recorded integer output node is added to `all_nodes` in place.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        start_nodes (List[Node]):
            nodes from which to start the search

        all_nodes (Dict[Node, None]):
            set of nodes to be extended with visited nodes during the search

    Returns:
        Tuple[Dict[Node, None], Dict[Node, None]]:
            tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
    """

    closest: Dict[Node, None] = {}
    seen: Set[Node] = set()

    frontier = dict.fromkeys(start_nodes)
    while frontier:
        upcoming: Dict[Node, None] = {}

        for node in frontier:
            if node in seen:
                continue
            seen.add(node)
            all_nodes[node] = None

            for pred in nx_graph.predecessors(node):
                if isinstance(pred.output.dtype, Integer):
                    # integer boundary reached on this branch: record it, do not go further
                    closest[pred] = None
                    all_nodes[pred] = None
                else:
                    upcoming[pred] = None

        frontier = upcoming

    return all_nodes, closest
|
|
|
|
|
|
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[Node],
    to_nodes: Dict[Node, None],
    all_nodes: Dict[Node, None],
) -> Dict[Node, None]:
    """
    Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.

    `to_nodes` are added first. The search then walks upstream from `from_nodes`, level by
    level, adding every visited node. It never expands past a node of `to_nodes`. `all_nodes`
    is extended in place and returned.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to traverse

        from_nodes (Iterable[Node]):
            nodes from which extending `all_nodes` start

        to_nodes (Dict[Node, None]):
            nodes to which extending `all_nodes` stop

        all_nodes (Dict[Node, None]):
            nodes to be extended

    Returns:
        Dict[Node, None]:
            extended `all_nodes`
    """

    all_nodes.update(to_nodes)

    seen: Set[Node] = set()
    frontier = dict.fromkeys(from_nodes)

    while frontier:
        upcoming: Dict[Node, None] = {}

        for node in frontier:
            if node in seen:
                continue
            seen.add(node)
            all_nodes[node] = None

            # stop nodes are kept but never expanded
            if node in to_nodes:
                continue

            for pred in nx_graph.predecessors(node):
                if pred not in to_nodes:
                    upcoming[pred] = None

        frontier = upcoming

    return all_nodes
|
|
|
|
|
|
def convert_subgraph_to_subgraph_node(
    nx_graph: nx.MultiDiGraph,
    all_nodes: Dict[Node, None],
    start_nodes: Dict[Node, None],
    terminal_node: Node,
) -> Optional[Tuple[Node, Node]]:
    """
    Convert a subgraph to Operation.Generic node.

    A copy of `nx_graph` restricted to `all_nodes` is made. A fresh input node replaces the
    single variable start node as the source of its edges, and the result is wrapped in a
    `Graph`. That graph is evaluated by a generic "subgraph" node.

    Args:
        nx_graph (nx.MultiDiGraph):
            original networkx graph

        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        start_nodes (Dict[Node, None]):
            start nodes of the subgraph

        terminal_node (Node):
            terminal node of the subgraph

    Returns:
        Optional[Tuple[Node, Node]]:
            None if the subgraph cannot be fused,
            subgraph node and its predecessor otherwise
    """

    # a fused node must have exactly one variable (non-constant) input
    non_constant_start_nodes = [
        node for node in start_nodes if node.operation != Operation.Constant
    ]
    if len(non_constant_start_nodes) != 1:
        return None
    (variable_input_node,) = non_constant_start_nodes

    if not subgraph_can_be_fused(all_nodes, variable_input_node):
        return None

    # copy of the original graph, restricted to the nodes of the subgraph
    nx_subgraph = nx.MultiDiGraph(nx_graph)
    nx_subgraph.remove_nodes_from(
        [node for node in nx_subgraph.nodes() if node not in all_nodes]
    )

    # a fresh input node stands in for the variable input inside the subgraph
    subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
    nx_subgraph.add_node(subgraph_variable_input_node)

    # redirect the edges leaving the variable input (inside the subgraph) to the new input node
    original_successors = nx_graph.succ[variable_input_node]
    successors_in_subgraph = [node for node in all_nodes if node in original_successors]
    for successor in successors_in_subgraph:
        edges_by_key = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
        for edge_key, edge_attributes in edges_by_key.items():
            nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
            nx_subgraph.add_edge(
                subgraph_variable_input_node,
                successor,
                key=edge_key,
                **deepcopy(edge_attributes),
            )

    subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})
    subgraph_node = Node.generic(
        "subgraph",
        subgraph_variable_input_node.inputs,
        terminal_node.output,
        lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
        kwargs={"subgraph": subgraph, "terminal_node": terminal_node},
    )

    return subgraph_node, variable_input_node
|
|
|
|
|
|
def subgraph_can_be_fused(
    all_nodes: Dict[Node, None],
    variable_input_node: Node,
) -> bool:
    """
    Determine if a subgraph can be fused.

    e.g.,

    shuffling or reshaping a tensor make fusing impossible as there should be a one-to-one mapping
    between each cell of the input and each cell of the output for table lookups

    The subgraph is rejected when any of these holds:
        - a constant node is larger than the variable input
        - a non-constant node (other than the variable input) is not fusable
        - a non-constant node (other than the variable input) has a different shape
          than the variable input

    Args:
        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        variable_input_node (Node):
            variable input node to the subgraph

    Returns:
        bool:
            True if subgraph can be fused,
            False otherwise
    """

    has_oversized_constant = any(
        node.operation == Operation.Constant
        and node.output.size > variable_input_node.output.size
        for node in all_nodes
    )
    if has_oversized_constant:
        return False

    return all(
        node == variable_input_node
        or (node.is_fusable and node.output.shape == variable_input_node.output.shape)
        for node in all_nodes
        if node.operation != Operation.Constant
    )
|