fix: properly determine lca during fusing

This commit is contained in:
Umut
2022-06-14 10:20:48 +02:00
parent 11819fcf2f
commit ce1712c67c
2 changed files with 256 additions and 94 deletions

View File

@@ -123,19 +123,6 @@ def find_float_subgraph_with_unique_terminal_node(
except StopIteration:
return None
# networkx does not implement lowest common ancestor search for multidigraph, but we only care
# about parent relationship here and not the meaning of edges, so we can convert our
# multidigraph to a digraph and use the lca search algorithm (if needed), we create the
# equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
# issues in our search, so we remove them.
equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
constant_nodes = [
node
for node in equivalent_subgraph_without_constants.nodes()
if node.operation == Operation.Constant
]
equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)
all_nodes: Dict[Node, None] = {}
start_single_int_output_nodes_search_from = terminal_node
@@ -154,40 +141,7 @@ def find_float_subgraph_with_unique_terminal_node(
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
# lca search only works for node pairs in networkx, so we progressively find the ancestors
# setting the lca by default to one of the nodes we are searching the lca for
lca = variable_start_nodes.pop()
while len(variable_start_nodes) > 0 and lca is not None:
node_to_find_new_lca = variable_start_nodes.pop()
if lca == node_to_find_new_lca:
continue
ancestors_of_lca = nx.ancestors(
equivalent_subgraph_without_constants,
lca,
)
ancestors_of_node_to_find_new_lca = nx.ancestors(
equivalent_subgraph_without_constants,
node_to_find_new_lca,
)
lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca
if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
variable_start_nodes += list(
pred
for pred in nx_graph.predecessors(
node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca
)
if pred.operation != Operation.Constant
)
lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
continue
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
)
lca = find_single_lca(nx_graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
@@ -254,19 +208,6 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
except StopIteration:
return None
# networkx does not implement lowest common ancestor search for multidigraph, but we only care
# about parent relationship here and not the meaning of edges, so we can convert our
# multidigraph to a digraph and use the lca search algorithm (if needed), we create the
# equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
# issues in our search, so we remove them.
equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
constant_nodes = [
node
for node in equivalent_subgraph_without_constants.nodes()
if node.operation == Operation.Constant
]
equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)
all_nodes: Dict[Node, None] = {}
while True:
@@ -274,40 +215,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
# find a common ancestor as we need a single variable input node
# lca == lowest common ancestor
# lca search only works for node pairs in networkx, so we progressively find the ancestors
# setting the lca by default to one of the nodes we are searching the lca for
lca = variable_start_nodes.pop()
while len(variable_start_nodes) > 0 and lca is not None:
node_to_find_new_lca = variable_start_nodes.pop()
if lca == node_to_find_new_lca:
continue
ancestors_of_lca = nx.ancestors(
equivalent_subgraph_without_constants,
lca,
)
ancestors_of_node_to_find_new_lca = nx.ancestors(
equivalent_subgraph_without_constants,
node_to_find_new_lca,
)
lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca
if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
variable_start_nodes += list(
pred
for pred in nx_graph.predecessors(
node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca
)
if pred.operation != Operation.Constant
)
lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
continue
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
)
lca = find_single_lca(nx_graph, variable_start_nodes)
# if subgraph cannot be fused because there is no way to find a common ancestor, break
if lca is None:
@@ -332,6 +240,185 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
return all_nodes, start_nodes, terminal_node
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
    """
    Search for the single lowest common ancestor shared by every node in `nodes`.

    A node is considered its own ancestor here, so the result may be one of `nodes`.
    Constant nodes are never returned.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph in which the search is performed

        nodes (List[Node]):
            nodes whose single lca is requested

    Returns:
        Optional[Node]:
            the single lca when there is one, None otherwise
    """

    # for each requested node, gather the node itself plus everything upstream of it
    # (the node is included because the single lca is allowed to be one of `nodes`)
    upstream_sets = []
    for node in nodes:
        upstream = set(nx.ancestors(nx_graph, node))
        upstream.add(node)
        upstream_sets.append(upstream)

    # a viable candidate must be non-constant and upstream of every requested node
    shared_upstream = {
        graph_node
        for graph_node in nx_graph.nodes()
        if graph_node.operation != Operation.Constant
        and all(graph_node in upstream for upstream in upstream_sets)
    }

    # walk the graph bottom-up (reversed topological order) so that the first
    # candidate passing the check is the "lowest" one
    bottom_up_order = reversed(list(nx.topological_sort(nx_graph)))

    # the first shared ancestor that `nodes` solely depend on is the answer,
    # and if no candidate qualifies there is no single lca at all
    return next(
        (
            candidate
            for candidate in bottom_up_order
            if candidate in shared_upstream
            and is_single_common_ancestor(nx_graph, candidate, nodes)
        ),
        None,
    )
def is_single_common_ancestor(
    nx_graph: nx.MultiDiGraph,
    candidate: Node,
    nodes: List[Node],
) -> bool:
    """
    Check whether `candidate` is a single common ancestor of all of `nodes`.

    The check says nothing about the ancestor being the `lowest` one.

    The idea is to collect every path going from `candidate` to each of `nodes`
    and then verify that no node on those paths has a non-constant predecessor
    lying outside of them. If such a predecessor exists, some value reaches
    `nodes` without passing through `candidate`.

    Args:
        nx_graph (nx.MultiDiGraph):
            graph in which the check is performed

        candidate (Node):
            node whose single common ancestor status is checked

        nodes (List[Node]):
            nodes the candidate is checked against

    Returns:
        bool:
            True if `candidate` is a single common ancestor of `nodes`, False otherwise
    """

    # start from a graph that only contains the candidate
    reachable = nx.DiGraph()
    reachable.add_node(candidate)

    # then grow it with each target and every simple path from the candidate to that target
    for target in nodes:
        reachable.add_node(target)
        for route in nx.all_simple_paths(nx_graph, source=candidate, target=target):
            nx.add_path(reachable, route)

    for member in reachable.nodes():
        # predecessors of the candidate are not part of `reachable`,
        # so the comparison below is meaningless for it
        if member == candidate:
            continue

        # predecessors that were kept while collecting paths from the candidate
        kept_predecessors = sum(1 for _ in reachable.predecessors(member))

        # predecessors in the original graph, leaving constants out because
        # they never end up in `reachable` and they don't affect fusability
        required_predecessors = sum(
            1
            for predecessor in nx_graph.predecessors(member)
            if predecessor.operation != Operation.Constant
        )

        # a mismatch means some non-constant predecessor
        # is not on any path coming from the candidate
        if kept_predecessors != required_predecessors:
            return False

    # every node kept all of its non-constant predecessors,
    # so everything feeding `nodes` originates from the candidate
    return True
# Here is why this function works.
#
# Legend:
# - /|\- = Edge
# - (...) = Intermediate Node
# - {...} = Candidate Node
# - [...] = Node of which single common ancestor is searched
# - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
#
# Consider the following graph:
#
# (3) (x) (2)
# \ / \ /
# [{*}] (/)
# \ /
# [+]
#
# - Operation: (x * 3) + (x / 2)
# - Candidate: {*}
# - Nodes: [*] and [+]
#
# So we want to know if multiplication node is a single common ancestor of
# multiplication and addition nodes. The result is no in this case for our purposes.
#
# Once you apply the subgraph creation above, you'll get the following graph:
#
# (*)
# |
# (+)
#
# In this subgraph, the addition node has only a single predecessor,
# which means there is a path leading to the addition node that doesn't include
# the multiplication node, so we conclude the multiplication node is not a single common ancestor.
#
# Now, consider the following graph:
#
# (3) {x} (2)
# \ / \ /
# [*] (/)
# \ /
# [+]
#
# - Operation: (x * 3) + (x / 2)
# - Candidate: {x}
# - Nodes: [*] and [+]
#
# So we want to know if the input node 'x' is the single common ancestor of
# multiplication and addition nodes. The result is yes in this case.
#
# Once you apply the subgraph creation above, you'll get the following graph:
#
# {x}
# / \
# [*] (/)
# \ /
# [+]
#
# In this subgraph, every node except the candidate node
# keeps all of its non-constant predecessors,
# which means all of its non-constant predecessors originated
# from the `candidate`, so it's a single common ancestor.
#
# When you think about it, this implementation makes a lot of sense for our purposes
# It basically determines if `nodes` "solely" depend on the `candidate`,
# which is the condition for fusing.
def find_closest_integer_output_nodes(
nx_graph: nx.MultiDiGraph,
start_nodes: List[Node],

View File

@@ -141,6 +141,53 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
# pylint: enable=invalid-name,too-many-locals,too-many-statements
def fusable_with_hard_to_find_lca(x):
    """
    Fusable function whose lca is harder to find.

    Computes round(sin(x * 3) ** 2 + cos(x * 3 + x // 3) ** 2) as int64.
    The order of operations is kept as is on purpose.
    """

    tripled = x * 3
    floor_third = x // 3
    combined = tripled + floor_third

    sin_squared = np.sin(tripled) ** 2
    cos_squared = np.cos(combined) ** 2

    return (sin_squared + cos_squared).round().astype(np.int64)
def fusable_with_hard_to_find_lca_used_twice(x):
    """
    Fusable function applying `fusable_with_hard_to_find_lca` on two branches.

    Each branch is a matrix product of `x` with its own constant 2x2 matrix,
    and the two results are added together.
    """

    first_branch = x @ np.array([[3, 1], [4, 2]])
    second_branch = x @ np.array([[1, 2], [3, 4]])

    first_result = fusable_with_hard_to_find_lca(first_branch)
    second_result = fusable_with_hard_to_find_lca(second_branch)

    return first_result + second_result
def fusable_additional_1(x):
    """
    Extra fusable function for additional safety.

    Adds together a float scaling of `x`, `x + 1` and the int64 cast of the scaling,
    then casts the sum to int64.
    """

    scaled = x.astype(np.float64) * 3.0
    shifted = x + 1
    truncated = scaled.astype(np.int64)

    total = scaled + shifted + truncated
    return total.astype(np.int64)
def fusable_additional_2(x):
    """
    Extra fusable function for additional safety.

    Adds together a float division of `x`, `x + 1` and the square of the division,
    then casts the sum to int64.
    """

    divided = x.astype(np.float64) / 3.0
    shifted = x + 1
    squared = divided * divided

    total = divided + shifted + squared
    return total.astype(np.int64)
def deterministic_unary_function(x):
"""
An example deterministic unary function.
@@ -440,6 +487,34 @@ def deterministic_unary_function(x):
},
id="fusable_with_one_of_the_start_nodes_is_lca",
),
pytest.param(
fusable_with_hard_to_find_lca,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="fusable_with_hard_to_find_lca",
),
pytest.param(
fusable_with_hard_to_find_lca_used_twice,
{
"x": {"status": "encrypted", "range": [0, 4], "shape": (2, 2)},
},
id="fusable_with_hard_to_find_lca_used_twice",
),
pytest.param(
fusable_additional_1,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="fusable_additional_1",
),
pytest.param(
fusable_additional_2,
{
"x": {"status": "encrypted", "range": [0, 10]},
},
id="fusable_additional_2",
),
pytest.param(
lambda x: x + x.shape[0] + x.ndim + x.size,
{