Files
concrete/concrete/numpy/compilation/utils.py
2022-06-16 10:25:17 +03:00

631 lines
20 KiB
Python

"""
Declaration of various functions and constants related to compilation.
"""
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple
import networkx as nx
from ..dtypes import Float, Integer
from ..representation import Graph, Node, Operation
from .artifacts import DebugArtifacts
def fuse(graph: Graph, artifacts: Optional[DebugArtifacts] = None):
    """
    Fuse appropriate subgraphs in a graph to a single Operation.Generic node.

    Two passes are made over the same underlying networkx graph: the first one
    looks for float subgraphs ending in an integer output, the second one looks
    for table lookups with multiple variable inputs sharing a common ancestor.
    The graph is modified in place and pruned once both passes are done.

    Args:
        graph (Graph):
            graph to search and update

        artifacts (Optional[DebugArtifacts], default = None):
            compilation artifacts to store information about the fusing process
    """

    nx_graph = graph.graph

    subgraph_finders = (
        find_float_subgraph_with_unique_terminal_node,
        find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor,
    )

    for find_next_subgraph in subgraph_finders:
        # each pass keeps its own record of terminal nodes it has already tried,
        # so that a subgraph which could not be fused is not offered again
        attempted_terminal_nodes: Set[Node] = set()

        while True:
            candidate = find_next_subgraph(nx_graph, attempted_terminal_nodes)
            if candidate is None:
                # nothing left to fuse in this pass
                break

            all_nodes, start_nodes, terminal_node = candidate
            attempted_terminal_nodes.add(terminal_node)

            conversion = convert_subgraph_to_subgraph_node(
                nx_graph,
                all_nodes,
                start_nodes,
                terminal_node,
            )
            if conversion is None:
                # the subgraph was found but it cannot be fused, try the next one
                continue

            fused_node, node_before_subgraph = conversion
            nx_graph.add_node(fused_node)

            # every graph output which pointed to the terminal node
            # needs to point to the fused node from now on
            indices_by_output_node: Dict[Node, List[int]] = {}
            for output_idx, output_node in graph.output_nodes.items():
                indices_by_output_node.setdefault(output_node, []).append(output_idx)
            for output_idx in indices_by_output_node.get(terminal_node, []):
                graph.output_nodes[output_idx] = fused_node

            # move the outgoing edges of the terminal node to the fused node
            # (successors and edge data are snapshotted as the graph is mutated in the loop)
            for successor in list(nx_graph.successors(terminal_node)):
                outgoing_edges = deepcopy(nx_graph.get_edge_data(terminal_node, successor))
                for edge_key, edge_attributes in outgoing_edges.items():
                    nx_graph.remove_edge(terminal_node, successor, key=edge_key)
                    nx_graph.add_edge(fused_node, successor, key=edge_key, **edge_attributes)

            # the fused node has a single input, which is the node just before the subgraph
            nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)

    graph.prune_useless_nodes()

    if artifacts is not None:
        artifacts.add_graph("after-fusing", graph)
def find_float_subgraph_with_unique_terminal_node(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with float computations that end with an integer output.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for float subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
    """

    # a terminal node takes at least one float input and produces an integer output
    terminal_node = next(
        (
            node
            for node in nx_graph.nodes()
            if (
                node not in processed_terminal_nodes
                and any(isinstance(node_input.dtype, Float) for node_input in node.inputs)
                and isinstance(node.output.dtype, Integer)
            )
        ),
        None,
    )
    if terminal_node is None:
        return None

    all_nodes: Dict[Node, None] = {}
    search_origin = terminal_node

    while True:
        # walk upstream from the search origin until integer output nodes are reached
        all_nodes, start_nodes = find_closest_integer_output_nodes(
            nx_graph,
            [search_origin],
            all_nodes,
        )

        variable_start_nodes = [
            start_node for start_node in start_nodes if start_node.operation != Operation.Constant
        ]
        if len(variable_start_nodes) == 1:
            # the subgraph already has a single variable input
            return all_nodes, start_nodes, terminal_node

        # a single variable input node is required
        # so look for the single lowest common ancestor (lca) of the variable start nodes
        lca = find_single_lca(nx_graph, variable_start_nodes)
        if lca is None:
            # no common ancestor, the result is returned as is
            # (the caller decides whether it can be fused or not)
            return all_nodes, start_nodes, terminal_node

        # the nodes between `start_nodes` and `lca` are part of the subgraph as well
        all_nodes = add_nodes_from_to(nx_graph, start_nodes, {lca: None}, all_nodes)

        if isinstance(lca.output.dtype, Integer):
            # `lca` has an integer output so it becomes the only start node
            return all_nodes, {lca: None}, terminal_node

        # `lca` doesn't have an integer output, so the search continues upstream from it
        # (e.g., there might be a node just before it, which has an integer output)
        search_origin = lca
def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_ancestor(
    nx_graph: nx.MultiDiGraph,
    processed_terminal_nodes: Set[Node],
) -> Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
    """
    Find a subgraph with a tlu computation that has multiple variable inputs,
    where all variable inputs share a common ancestor.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        processed_terminal_nodes (Set[Node]):
            set of terminal nodes which have already been searched for tlu subgraphs

    Returns:
        Optional[Tuple[Dict[Node, None], Dict[Node, None], Node]]:
            None if there are no such subgraphs,
            tuple containing all nodes in the subgraph, start nodes of the subgraph,
            and terminal node of the subgraph otherwise
            (start nodes are empty if the subgraph that was found cannot be fused)
    """

    def is_candidate_terminal_node(node: Node) -> bool:
        if node in processed_terminal_nodes:
            return False
        if not node.converted_to_table_lookup:
            return False
        if not all(isinstance(node_input.dtype, Integer) for node_input in node.inputs):
            return False
        if not isinstance(node.output.dtype, Integer):
            return False

        # only table lookups with more than one variable (i.e., non-constant) predecessor
        # are of interest, the ones with a single variable input don't need fusing
        variable_predecessors = [
            pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant
        ]
        return len(variable_predecessors) > 1

    terminal_node = next(
        (node for node in nx_graph.nodes() if is_candidate_terminal_node(node)),
        None,
    )
    if terminal_node is None:
        return None

    # NOTE(review): constant predecessors are passed to `find_single_lca` as well,
    # exactly as before this change; whether they should be filtered out is left unchanged
    predecessors = list(nx_graph.predecessors(terminal_node))

    # find a common ancestor as we need a single variable input node
    # lca == lowest common ancestor
    lca = find_single_lca(nx_graph, predecessors)

    # if there is no way to find a common ancestor, the subgraph cannot be fused
    # which is signalled to the caller with empty `start_nodes`
    if lca is None:
        return {}, {}, terminal_node

    # add the nodes from the predecessors of the terminal node to `lca`, to `all_nodes`
    all_nodes = add_nodes_from_to(nx_graph, predecessors, {lca: None}, {})
    all_nodes[terminal_node] = None

    # `lca` is a valid start node for fusing only if it has an integer output
    #
    # this used to be evaluated in a `while True` loop, which didn't change anything
    # between iterations, so an `lca` without an integer output resulted in an infinite loop;
    # such subgraphs are now reported as not fusable (i.e., with empty `start_nodes`)
    start_nodes: Dict[Node, None] = {lca: None} if isinstance(lca.output.dtype, Integer) else {}

    return all_nodes, start_nodes, terminal_node
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
    """
    Find the single lowest common ancestor of a list of nodes.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search for single lca

        nodes (List[Node]):
            nodes to find the single lca of

    Returns:
        Optional[Node]:
            single lca if it exists, None otherwise
    """

    # for each node, collect its ancestors together with the node itself
    # (the node itself is included because the single lca can be one of `nodes`)
    ancestors_per_node = [nx.ancestors(nx_graph, node) | {node} for node in nodes]

    # walk the graph in reversed topological order
    # so that the first match is the "lowest" one
    for candidate in reversed(list(nx.topological_sort(nx_graph))):
        # constants are never considered as a common ancestor
        if candidate.operation == Operation.Constant:
            continue

        # the candidate has to be an ancestor of (or be) every node in `nodes`
        if not all(candidate in ancestors for ancestors in ancestors_per_node):
            continue

        # and `nodes` must depend solely on it
        if is_single_common_ancestor(nx_graph, candidate, nodes):
            return candidate

    # none of the common ancestors is a single common ancestor
    return None
def is_single_common_ancestor(
    nx_graph: nx.MultiDiGraph,
    candidate: Node,
    nodes: List[Node],
) -> bool:
    """
    Determine if a node is the single common ancestor of a list of nodes.

    Note that this function doesn't care about `lowest` property of `lca`.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to perform the check

        candidate (Node):
            node to determine single common ancestor status

        nodes (List[Node]):
            nodes to determine single common ancestor status against

    Returns:
        bool:
            True if `candidate` is a single common ancestor of `nodes`, False otherwise
    """

    # build a graph which only contains `candidate`, `nodes`,
    # and every simple path going from `candidate` to one of `nodes`
    paths_from_candidate = nx.DiGraph()
    paths_from_candidate.add_node(candidate)
    for target in nodes:
        paths_from_candidate.add_node(target)
        for path in nx.all_simple_paths(nx_graph, source=candidate, target=target):
            nx.add_path(paths_from_candidate, path)

    for member in paths_from_candidate.nodes():
        # predecessors of `candidate` are not part of the path graph
        # so the comparison below is meaningless for it
        if member == candidate:
            continue

        # constants are ignored in the original graph as
        # - they are not in the path graph
        # - they don't affect fusability status
        kept_predecessor_count = sum(1 for _ in paths_from_candidate.predecessors(member))
        variable_predecessor_count = sum(
            1 for pred in nx_graph.predecessors(member) if pred.operation != Operation.Constant
        )

        # a mismatch means some variable predecessor is not reachable from `candidate`
        # so `candidate` cannot be a single common ancestor (explained in detail below)
        if kept_predecessor_count != variable_predecessor_count:
            return False

    # every node kept all of its variable predecessors
    # so `candidate` is in fact a single common ancestor
    return True

    # Here is why this function works.
    #
    # Legend:
    # - /|\-    = Edge
    # - (...)   = Intermediate Node
    # - {...}   = Candidate Node
    # - [...]   = Node of which single common ancestor is searched
    # - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
    #
    # Consider the following graph:
    #
    #   (3)   (x)   (2)
    #     \   / \   /
    #     [{*}]  (/)
    #        \   /
    #         [+]
    #
    # - Operation: (x * 3) + (x / 2)
    # - Candidate: {*}
    # - Nodes: [*] and [+]
    #
    # So we want to know if the multiplication node is a single common ancestor of
    # the multiplication and addition nodes. The result is no in this case for our purposes.
    #
    # Once you apply the subgraph creation above, you'll get the following graph:
    #
    #   (*)
    #    |
    #   (+)
    #
    # In this subgraph, the addition node only has a single predecessor,
    # which means there is a path leading to the addition node that doesn't include
    # the multiplication node, so we conclude that the multiplication node
    # is not a single common ancestor.
    #
    # Now, consider the following graph:
    #
    #   (3)   {x}   (2)
    #     \   / \   /
    #      [*]   (/)
    #        \   /
    #         [+]
    #
    # - Operation: (x * 3) + (x / 2)
    # - Candidate: {x}
    # - Nodes: [*] and [+]
    #
    # So we want to know if the input node 'x' is the single common ancestor of
    # the multiplication and addition nodes. The result is yes in this case.
    #
    # Once you apply the subgraph creation above, you'll get the following graph:
    #
    #         {x}
    #         / \
    #      [*]   (/)
    #        \   /
    #         [+]
    #
    # In this subgraph, every node except the candidate node
    # keeps all of its non-constant predecessors,
    # which means all of their non-constant predecessors originate
    # from the `candidate`, so it's a single common ancestor.
    #
    # In other words, this implementation determines
    # if `nodes` "solely" depend on the `candidate`,
    # which is the condition for fusing.
def find_closest_integer_output_nodes(
    nx_graph: nx.MultiDiGraph,
    start_nodes: List[Node],
    all_nodes: Dict[Node, None],
) -> Tuple[Dict[Node, None], Dict[Node, None]]:
    """
    Find the closest upstream integer output nodes to a set of start nodes in a graph.

    Note that `all_nodes` is extended in place, and the same dictionary is returned.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to search

        start_nodes (List[Node]):
            nodes from which to start the search

        all_nodes (Dict[Node, None]):
            set of nodes to be extended with visited nodes during the search

    Returns:
        Tuple[Dict[Node, None], Dict[Node, None]]:
            tuple containing extended `all_nodes` and integer output nodes closest to `start_nodes`
    """

    integer_frontier: Dict[Node, None] = {}
    seen: Set[Node] = set()

    # dictionaries are used as ordered sets throughout the search
    level: Dict[Node, None] = {start_node: None for start_node in start_nodes}
    while level:
        upstream_level: Dict[Node, None] = {}

        for node in level:
            if node in seen:
                continue
            seen.add(node)

            all_nodes[node] = None
            for pred in nx_graph.predecessors(node):
                if isinstance(pred.output.dtype, Integer):
                    # the search stops at integer output nodes
                    integer_frontier[pred] = None
                    all_nodes[pred] = None
                else:
                    # and it keeps going upstream otherwise
                    upstream_level[pred] = None

        level = upstream_level

    return all_nodes, integer_frontier
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[Node],
    to_nodes: Dict[Node, None],
    all_nodes: Dict[Node, None],
) -> Dict[Node, None]:
    """
    Add nodes from `from_nodes` to `to_nodes`, to `all_nodes`.

    Note that `all_nodes` is extended in place, and the same dictionary is returned.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph to traverse

        from_nodes (Iterable[Node]):
            nodes from which extending `all_nodes` start

        to_nodes (Dict[Node, None]):
            nodes to which extending `all_nodes` stop

        all_nodes (Dict[Node, None]):
            nodes to be extended

    Returns:
        Dict[Node, None]:
            extended `all_nodes`
    """

    # the nodes at which the traversal stops are part of the result
    all_nodes.update(to_nodes)

    seen: Set[Node] = set()
    level: Dict[Node, None] = {from_node: None for from_node in from_nodes}

    while level:
        upstream_level: Dict[Node, None] = {}

        for node in level:
            if node in seen:
                continue
            seen.add(node)

            all_nodes[node] = None
            if node in to_nodes:
                # the traversal doesn't go above `to_nodes`
                continue

            for pred in nx_graph.predecessors(node):
                if pred not in to_nodes:
                    upstream_level[pred] = None

        level = upstream_level

    return all_nodes
def convert_subgraph_to_subgraph_node(
    nx_graph: nx.MultiDiGraph,
    all_nodes: Dict[Node, None],
    start_nodes: Dict[Node, None],
    terminal_node: Node,
) -> Optional[Tuple[Node, Node]]:
    """
    Convert a subgraph to Operation.Generic node.

    Args:
        nx_graph (nx.MultiDiGraph):
            original networkx graph

        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        start_nodes (Dict[Node, None]):
            start nodes of the subgraph

        terminal_node (Node):
            terminal node of the subgraph

    Returns:
        Optional[Tuple[Node, Node]]:
            None if the subgraph cannot be fused,
            subgraph node and its predecessor otherwise
    """

    # fusing requires exactly one variable (i.e., non-constant) input
    variable_input_nodes = [node for node in start_nodes if node.operation != Operation.Constant]
    if len(variable_input_nodes) != 1:
        return None
    (variable_input_node,) = variable_input_nodes

    if not subgraph_can_be_fused(all_nodes, variable_input_node):
        return None

    # copy the original graph and keep only the nodes of the subgraph in the copy
    nx_subgraph = nx.MultiDiGraph(nx_graph)
    nx_subgraph.remove_nodes_from(
        [node for node in nx_subgraph.nodes() if node not in all_nodes],
    )

    # the subgraph gets its own input node, which mirrors the output of the variable input
    subgraph_variable_input_node = Node.input("input", deepcopy(variable_input_node.output))
    nx_subgraph.add_node(subgraph_variable_input_node)

    # within the subgraph, successors of the variable input node
    # should be fed by the new input node of the subgraph instead
    for successor in all_nodes:
        if successor not in nx_graph.succ[variable_input_node]:
            continue

        # edge data is snapshotted as the edges are removed while iterating
        edges = deepcopy(nx_subgraph.get_edge_data(variable_input_node, successor))
        for edge_key, edge_attributes in edges.items():
            nx_subgraph.remove_edge(variable_input_node, successor, key=edge_key)
            nx_subgraph.add_edge(
                subgraph_variable_input_node,
                successor,
                key=edge_key,
                **edge_attributes,
            )

    subgraph = Graph(nx_subgraph, {0: subgraph_variable_input_node}, {0: terminal_node})

    # the fused node evaluates the whole subgraph and picks the value of the terminal node
    subgraph_node = Node.generic(
        "subgraph",
        subgraph_variable_input_node.inputs,
        terminal_node.output,
        lambda x, subgraph, terminal_node: subgraph.evaluate(x)[terminal_node],
        kwargs={
            "subgraph": subgraph,
            "terminal_node": terminal_node,
        },
    )

    return subgraph_node, variable_input_node
def subgraph_can_be_fused(
    all_nodes: Dict[Node, None],
    variable_input_node: Node,
) -> bool:
    """
    Determine if a subgraph can be fused.

    e.g.,

    shuffling or reshaping a tensor make fusing impossible as there should be a one-to-one mapping
    between each cell of the input and each cell of the output for table lookups

    Args:
        all_nodes (Dict[Node, None]):
            all nodes in the subgraph

        variable_input_node (Node):
            variable input node to the subgraph

    Returns:
        bool:
            True if subgraph can be fused,
            False otherwise
    """

    # a constant which is bigger than the variable input rules fusing out
    has_constant_bigger_than_variable_input = any(
        node.operation == Operation.Constant
        and node.output.size > variable_input_node.output.size
        for node in all_nodes
    )
    if has_constant_bigger_than_variable_input:
        return False

    # every other non-constant node has to be fusable
    # and it has to keep the shape of the variable input
    return all(
        node.is_fusable and node.output.shape == variable_input_node.output.shape
        for node in all_nodes
        if node.operation != Operation.Constant and node != variable_input_node
    )