feat(optimization): support more fusing topologies

- corrected a docstring that misdescribed what was returned
- updated pyproject.toml to ignore warnings that are raised inside networkx
itself and that were blocking proper test execution (there is no way around
this: the warnings are triggered by networkx's own code)
- add a test case for the newly supported fusing topology

closes #499
This commit is contained in:
Arthur Meyre
2021-11-17 11:53:35 +01:00
parent bc145e21e1
commit ff03bc2220
4 changed files with 153 additions and 32 deletions

View File

@@ -1,7 +1,7 @@
"""File holding topological optimization/simplification code."""
from collections import defaultdict
from copy import deepcopy
from typing import DefaultDict, Dict, List, Optional, Set, Tuple, cast
from typing import DefaultDict, Dict, Iterable, List, Optional, Set, Tuple, cast
import networkx as nx
from loguru import logger
@@ -245,9 +245,13 @@ def find_closest_single_int_output_nodes(
# Use dict as ordered set
current_nodes = {start_node: None for start_node in start_nodes}
closest_single_int_output_nodes: Dict[IntermediateNode, None] = {}
visited_nodes: Set[IntermediateNode] = set()
while current_nodes:
next_nodes: Dict[IntermediateNode, None] = {}
for node in current_nodes:
if node in visited_nodes:
continue
visited_nodes.add(node)
subgraph_all_nodes.add(node)
predecessors = nx_graph.pred[node]
for pred in predecessors:
@@ -262,15 +266,52 @@ def find_closest_single_int_output_nodes(
return closest_single_int_output_nodes, subgraph_all_nodes
def add_nodes_from_to(
    nx_graph: nx.MultiDiGraph,
    from_nodes: Iterable[IntermediateNode],
    to_nodes: Set[IntermediateNode],
    subgraph_all_nodes: Set[IntermediateNode],
) -> Set[IntermediateNode]:
    """Add nodes from from_nodes to to_nodes to the subgraph_all_nodes set.

    The graph is walked backwards, following predecessors, starting at from_nodes. Predecessors
    that belong to to_nodes are never queued for exploration, so the walk stops there. Every
    explored node, as well as every node of to_nodes, ends up in subgraph_all_nodes. Nodes given
    in from_nodes are always explored, even if they also appear in to_nodes.

    Args:
        nx_graph (nx.MultiDiGraph): the graph to traverse.
        from_nodes (Iterable[IntermediateNode]): the nodes from which we will add nodes to
            subgraph_all_nodes.
        to_nodes (Set[IntermediateNode]): the nodes we should stop at.
        subgraph_all_nodes (Set[IntermediateNode]): All the nodes in the float subgraph, will be
            updated in place and returned.

    Returns:
        Set[IntermediateNode]: returns the updated subgraph_all_nodes (the same set object that
            was passed in).
    """

    # Add the end nodes we won't visit
    subgraph_all_nodes.update(to_nodes)

    # Use dict as ordered set: only the keys matter, values are always None
    current_nodes: Dict[IntermediateNode, None] = {from_node: None for from_node in from_nodes}
    visited_nodes: Set[IntermediateNode] = set()

    while current_nodes:
        next_nodes: Dict[IntermediateNode, None] = {}
        for node in current_nodes:
            # A node can be reached through several paths, only expand it once
            if node in visited_nodes:
                continue
            visited_nodes.add(node)
            subgraph_all_nodes.add(node)
            # Add nodes to explore next if they are not indicated as end nodes; predecessors
            # that were already expanded would be skipped anyway, so do not queue them again
            for pred in nx_graph.pred[node]:
                if pred not in to_nodes and pred not in visited_nodes:
                    next_nodes[pred] = None
        current_nodes = next_nodes

    return subgraph_all_nodes
def find_float_subgraph_with_unique_terminal_node(
nx_graph: nx.MultiDiGraph,
processed_terminal_nodes: Set[IntermediateNode],
) -> Optional[Tuple[Set[IntermediateNode], IntermediateNode, Set[IntermediateNode]]]:
"""Find a subgraph of the graph with float computations.
The subgraph has a single terminal node with a single Integer output and has a single variable
predecessor node with a single Integer output.
Args:
nx_graph (nx.MultiDiGraph): The networkx graph to search in.
processed_terminal_nodes (Set[IntermediateNode]): The set of terminal nodes for which
@@ -303,13 +344,71 @@ def find_float_subgraph_with_unique_terminal_node(
except StopIteration:
return None
# networkx does not implement lowest common ancestor (lca) search for multidigraphs. We only care
# about parent relationships here, not about the meaning of edges, so we can convert our
# multidigraph to a digraph and run the lca search on it (if needed). The equivalent digraph is
# created once here to avoid recreating it in the loop below. Constant nodes could cause
# issues in our search so we remove them.
equivalent_digraph_without_constants = nx.DiGraph(nx_graph)
constant_graph_nodes = [
constant_node
for constant_node in equivalent_digraph_without_constants
if isinstance(constant_node, Constant)
]
equivalent_digraph_without_constants.remove_nodes_from(constant_graph_nodes)
subgraph_all_nodes: Set[IntermediateNode] = set()
float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
nx_graph,
[terminal_node],
subgraph_all_nodes,
)
start_single_int_output_nodes_search_from = terminal_node
while True:
float_subgraph_start_nodes, subgraph_all_nodes = find_closest_single_int_output_nodes(
nx_graph,
[start_single_int_output_nodes_search_from],
subgraph_all_nodes,
)
variable_start_nodes = [
start_node
for start_node in float_subgraph_start_nodes
if not isinstance(start_node, Constant)
]
# We found a single input variable node
if len(variable_start_nodes) == 1:
break
# Otherwise find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
# lca search only works for node pairs in networkx, so we progressively find the ancestors
# setting the lca by default to one of the nodes we are searching the lca for
lca = variable_start_nodes.pop()
while len(variable_start_nodes) > 0 and lca is not None:
node_to_find_lca = variable_start_nodes.pop()
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
equivalent_digraph_without_constants, lca, node_to_find_lca, default=None
)
# The subgraph cannot be fused as there is no way to find a common ancestor
if lca is None:
break
# if lca is not None, add the nodes from the current start nodes to the lca to
# subgraph_all_nodes
subgraph_all_nodes = add_nodes_from_to(
nx_graph, float_subgraph_start_nodes, {lca}, subgraph_all_nodes
)
# if the lca is a valid starting node for fusing break
if is_single_int_output_node(lca):
# the lca is our new start node
float_subgraph_start_nodes = {lca: None}
break
# otherwise push the search a little bit further back, starting from the lca (e.g. there may be
# a node with an integer output just before it)
start_single_int_output_nodes_search_from = lca
return set(float_subgraph_start_nodes.keys()), terminal_node, subgraph_all_nodes

View File

@@ -56,25 +56,3 @@ def non_fusable(x, y):
```
From `add_int` you will find two `Add` nodes going from int to float (`x_1` and `y_1`) which we cannot represent with a single input table look-up.
## Possible improvements
This technique is not perfect because you could try to go further back in the graph to find a single variable input.
Firstly, it does not cover optimizing the graph, so you can end up with multiple operations, like additions with constants, or two look-up tables in a row, that can trivially be fused but that are not fused, as the optimization work is left to the compiler backend with MLIR and LLVM tooling. This first limitation does not impact the kind of programs that are compilable.
Secondly, the current approach fails to handle some programs that in practice could be compiled. The following example could be covered by pushing the search to find a single integer input:
<!--python-test:skip-->
```python
def theoretically_fusable(x):
x_1 = x + 1.5
x_2 = x + 3.4
add = x_1 + x_2
add_int = add.astype(numpy.int32)
return add_int
```
Here the whole function takes a single int and gives a single int (i.e. it is representable by a table look-up), but the current implementation of fusing is going to find `x_1` and `x_2` as the starting nodes of the float subgraph and say it cannot fuse it. There is room for improvement here: it is probably a graph coloring problem where we would have to list, for each of the graph's inputs, which nodes depend on it.
At some point having a proper optimization system with patterns and rules to rewrite them could become more interesting than having this ad-hoc system.

View File

@@ -53,7 +53,9 @@ build-backend = "poetry.core.masonry.api"
[tool.pytest.ini_options]
filterwarnings = [
"error"
"error",
"ignore:pandas not found, skipping conversion test.:ImportWarning",
"ignore:scipy not found, skipping conversion test.:ImportWarning",
]
[tool.semantic_release]

View File

@@ -27,6 +27,34 @@ def no_fuse_unhandled(x, y):
return intermediate.astype(numpy.int32)
def fusable_with_bigger_search(x, y):
    """Fusable case where both float branches are derived from the same `x + 1` value.

    Each branch casts the shared value to int32 and adds a float constant; the branches are then
    summed, cast back to int32 and added to y.
    """
    shifted = x + 1
    # First branch: int32 cast followed by a float addition
    left_branch = shifted.astype(numpy.int32) + 1.5
    # Second branch: same cast of the same value, different float constant
    right_branch = shifted.astype(numpy.int32) + 3.4
    float_total = left_branch + right_branch
    return float_total.astype(numpy.int32) + y
def fusable_with_bigger_search_needs_second_iteration(x, y):
    """Fusable case with a bigger search that triggers a second iteration in the fusing.

    The value shared by both branches is a float (`cos(x + 1 + 0.5)`), and the second branch goes
    through extra float additions before being cast to int32.
    """
    shifted = x + 1
    shifted = shifted + 0.5
    cosine = numpy.cos(shifted)
    # First branch: direct int32 cast of the cosine, then a float addition
    left_branch = cosine.astype(numpy.int32) + 1.5
    # Second branch: two chained additions, summed, before the int32 cast
    plus_one = cosine + 1
    plus_two = plus_one + 1
    right_branch = (plus_one + plus_two).astype(numpy.int32) + 3.4
    float_total = left_branch + right_branch
    return float_total.astype(numpy.int32) + y
def no_fuse_big_constant_3_10_10(x):
"""Pass an array x with size < 100 to trigger a no fuse condition."""
x = x.astype(numpy.float64)
@@ -177,6 +205,20 @@ return %7
""".strip(), # noqa: E501 # pylint: disable=line-too-long
id="no_fuse_unhandled",
),
pytest.param(
fusable_with_bigger_search,
True,
get_func_params_int32(fusable_with_bigger_search),
None,
id="fusable_with_bigger_search",
),
pytest.param(
    fusable_with_bigger_search_needs_second_iteration,
    True,
    get_func_params_int32(fusable_with_bigger_search_needs_second_iteration),
    None,
    # id matches the function under test so it does not duplicate the previous case's id
    id="fusable_with_bigger_search_needs_second_iteration",
),
pytest.param(
no_fuse_dot,
False,