mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: properly determine lca during fusing
This commit is contained in:
@@ -123,19 +123,6 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
# networkx does not implement lowest common ancestor search for multidigraph, but we only care
|
||||
# about parent relationship here and not the meaning of edges, so we can convert our
|
||||
# multidigraph to a digraph and use the lca search algorithm (if needed), we create the
|
||||
# equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
|
||||
# issues in our search, so we remove them.
|
||||
equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
|
||||
constant_nodes = [
|
||||
node
|
||||
for node in equivalent_subgraph_without_constants.nodes()
|
||||
if node.operation == Operation.Constant
|
||||
]
|
||||
equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)
|
||||
|
||||
all_nodes: Dict[Node, None] = {}
|
||||
|
||||
start_single_int_output_nodes_search_from = terminal_node
|
||||
@@ -154,40 +141,7 @@ def find_float_subgraph_with_unique_terminal_node(
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
# lca search only works for node pairs in networkx, so we progressively find the ancestors
|
||||
# setting the lca by default to one of the nodes we are searching the lca for
|
||||
lca = variable_start_nodes.pop()
|
||||
while len(variable_start_nodes) > 0 and lca is not None:
|
||||
node_to_find_new_lca = variable_start_nodes.pop()
|
||||
if lca == node_to_find_new_lca:
|
||||
continue
|
||||
|
||||
ancestors_of_lca = nx.ancestors(
|
||||
equivalent_subgraph_without_constants,
|
||||
lca,
|
||||
)
|
||||
ancestors_of_node_to_find_new_lca = nx.ancestors(
|
||||
equivalent_subgraph_without_constants,
|
||||
node_to_find_new_lca,
|
||||
)
|
||||
|
||||
lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
|
||||
node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca
|
||||
|
||||
if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
|
||||
variable_start_nodes += list(
|
||||
pred
|
||||
for pred in nx_graph.predecessors(
|
||||
node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca
|
||||
)
|
||||
if pred.operation != Operation.Constant
|
||||
)
|
||||
lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
|
||||
continue
|
||||
|
||||
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
|
||||
equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
|
||||
)
|
||||
lca = find_single_lca(nx_graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
@@ -254,19 +208,6 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
except StopIteration:
|
||||
return None
|
||||
|
||||
# networkx does not implement lowest common ancestor search for multidigraph, but we only care
|
||||
# about parent relationship here and not the meaning of edges, so we can convert our
|
||||
# multidigraph to a digraph and use the lca search algorithm (if needed), we create the
|
||||
# equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause
|
||||
# issues in our search, so we remove them.
|
||||
equivalent_subgraph_without_constants = nx.DiGraph(nx_graph)
|
||||
constant_nodes = [
|
||||
node
|
||||
for node in equivalent_subgraph_without_constants.nodes()
|
||||
if node.operation == Operation.Constant
|
||||
]
|
||||
equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes)
|
||||
|
||||
all_nodes: Dict[Node, None] = {}
|
||||
|
||||
while True:
|
||||
@@ -274,40 +215,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
|
||||
# find a common ancestor as we need a single variable input node
|
||||
# lca == lowest common ancestor
|
||||
# lca search only works for node pairs in networkx, so we progressively find the ancestors
|
||||
# setting the lca by default to one of the nodes we are searching the lca for
|
||||
lca = variable_start_nodes.pop()
|
||||
while len(variable_start_nodes) > 0 and lca is not None:
|
||||
node_to_find_new_lca = variable_start_nodes.pop()
|
||||
if lca == node_to_find_new_lca:
|
||||
continue
|
||||
|
||||
ancestors_of_lca = nx.ancestors(
|
||||
equivalent_subgraph_without_constants,
|
||||
lca,
|
||||
)
|
||||
ancestors_of_node_to_find_new_lca = nx.ancestors(
|
||||
equivalent_subgraph_without_constants,
|
||||
node_to_find_new_lca,
|
||||
)
|
||||
|
||||
lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca
|
||||
node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca
|
||||
|
||||
if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca:
|
||||
variable_start_nodes += list(
|
||||
pred
|
||||
for pred in nx_graph.predecessors(
|
||||
node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca
|
||||
)
|
||||
if pred.operation != Operation.Constant
|
||||
)
|
||||
lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca
|
||||
continue
|
||||
|
||||
lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor(
|
||||
equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None
|
||||
)
|
||||
lca = find_single_lca(nx_graph, variable_start_nodes)
|
||||
|
||||
# if subgraph cannot be fused because there is no way to find a common ancestor, break
|
||||
if lca is None:
|
||||
@@ -332,6 +240,185 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc
|
||||
return all_nodes, start_nodes, terminal_node
|
||||
|
||||
|
||||
def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]:
|
||||
"""
|
||||
Find the single lowest common ancestor of a list of nodes.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph to search for single lca
|
||||
|
||||
nodes (List[Node]):
|
||||
nodes to find the single lca of
|
||||
|
||||
Returns
|
||||
Optional[Node]:
|
||||
single lca if it exists, None otherwise
|
||||
"""
|
||||
|
||||
# find all ancestors of `nodes`
|
||||
# nodes themselves need to be in this set because the single lca can be within `nodes`
|
||||
all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes]
|
||||
|
||||
# find common ancestors among `nodes`
|
||||
# if the single lca exists, it's in this set
|
||||
common_ancestors = set(
|
||||
node
|
||||
for node in nx_graph.nodes()
|
||||
if node.operation != Operation.Constant
|
||||
and all(node in ancestors for ancestors in all_ancestors)
|
||||
)
|
||||
|
||||
# iterate over every node in the graph reversed topological order
|
||||
# this is to ensure result, if found, is the single "lowest" common ancestor
|
||||
for candidate in reversed(list(nx.topological_sort(nx_graph))):
|
||||
# check if node is a common ancestor of all `nodes`
|
||||
if candidate not in common_ancestors:
|
||||
# if not, it cannot be the single lca
|
||||
continue
|
||||
|
||||
# check if node is a single common ancestor of `nodes`
|
||||
if is_single_common_ancestor(nx_graph, candidate, nodes):
|
||||
# if so, it's the single lca of `nodes`
|
||||
# so return it
|
||||
return candidate
|
||||
|
||||
# if none of the nodes in `common_ancestors` is the single lca
|
||||
# there is no single lca of this set of nodes, so return None
|
||||
return None
|
||||
|
||||
|
||||
def is_single_common_ancestor(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
candidate: Node,
|
||||
nodes: List[Node],
|
||||
) -> bool:
|
||||
"""
|
||||
Determine if a node is the single common ancestor of a list of nodes.
|
||||
|
||||
Note that this function doesn't care about `lowest` property of `lca`.
|
||||
|
||||
Args:
|
||||
nx_graph (nx.MultiDiGraph):
|
||||
graph to perform the check
|
||||
|
||||
candidate (Node):
|
||||
node to determine single common ancestor status
|
||||
|
||||
nodes (List[Node]):
|
||||
nodes to determine single common ancestor status against
|
||||
|
||||
Returns
|
||||
bool:
|
||||
True if `candidate` is a single common ancestor of `nodes`, False otherwise
|
||||
"""
|
||||
|
||||
# create a subgraph with `candidate` node
|
||||
subgraph = nx.DiGraph()
|
||||
subgraph.add_node(candidate)
|
||||
|
||||
# iterate over `nodes` to add them to the subgraph
|
||||
# along with every path from `candidate` to them
|
||||
for node in nodes:
|
||||
subgraph.add_node(node)
|
||||
for path in nx.all_simple_paths(nx_graph, source=candidate, target=node):
|
||||
nx.add_path(subgraph, path)
|
||||
|
||||
# iterate over the nodes of the subgraph
|
||||
for node in subgraph.nodes():
|
||||
# the condition below doesn't apply to `candidate`
|
||||
# as its predecessors are not in the subgraph
|
||||
if node == candidate:
|
||||
continue
|
||||
|
||||
# find number of predecessors in the subgraph and in the original graph
|
||||
# except constant nodes in the original graph as
|
||||
# - they are not in the subgraph
|
||||
# - they don't affect fusability status
|
||||
predecessor_count_in_subgraph = len(list(subgraph.predecessors(node)))
|
||||
predecessor_count_in_nx_graph = len(
|
||||
list(
|
||||
pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant
|
||||
)
|
||||
)
|
||||
|
||||
# see if number of predecessors are different
|
||||
if predecessor_count_in_subgraph != predecessor_count_in_nx_graph:
|
||||
# if so, `candidate` cannot be a single common ancestor
|
||||
# reasoning for is explained below
|
||||
return False
|
||||
|
||||
# if every node in the subgraph has the same number of predecessors
|
||||
# as in the original graph `candidate` is in fact a single common ancestor
|
||||
return True
|
||||
|
||||
# Here is why this function works.
|
||||
#
|
||||
# Legend:
|
||||
# - /|\- = Edge
|
||||
# - (...) = Intermediate Node
|
||||
# - {...} = Candidate Node
|
||||
# - [...] = Node of which single common ancestor is searched
|
||||
# - {[...]} = Both Candidate Node and Node of which single common ancestor is searched
|
||||
#
|
||||
# Consider the folowing graph:
|
||||
#
|
||||
# (3) (x) (2)
|
||||
# \ / \ /
|
||||
# [{*}] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# - Operation: (x * 3) + (x / 2)
|
||||
# - Candidate: {*}
|
||||
# - Nodes: [*] and [+]
|
||||
#
|
||||
# So we want to know if multiplication node is a single common ancestor of
|
||||
# multiplication and addition nodes. The result is no in this case for our purposes.
|
||||
#
|
||||
# Once you apply the subgraph creation above, you'll get the following graph:
|
||||
#
|
||||
# (*)
|
||||
# |
|
||||
# (+)
|
||||
#
|
||||
# In this subgraph, addition node only have a single predecessor,
|
||||
# which means there is path leading to the addition node and that path doesn't include
|
||||
# the multiplication node, so we conclude multiplication node is not a single common ancestor
|
||||
#
|
||||
# Now, consider the folowing graph:
|
||||
#
|
||||
# (3) {x} (2)
|
||||
# \ / \ /
|
||||
# [*] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# - Operation: (x * 3) + (x / 2)
|
||||
# - Candidate: {x}
|
||||
# - Nodes: [*] and [+]
|
||||
#
|
||||
# So we want to know if the input node 'x' is the single common ancestor of
|
||||
# multiplication and addition nodes. The result is yes in this case.
|
||||
#
|
||||
# Once you apply the subgraph creation above, you'll get the following graph:
|
||||
#
|
||||
# {x}
|
||||
# / \
|
||||
# [*] (/)
|
||||
# \ /
|
||||
# [+]
|
||||
#
|
||||
# In this subgraph, every node except the candidate node
|
||||
# will keep all of their non-constant predecessors,
|
||||
# which means all of their non-constant predecessors originated
|
||||
# from the `candidate`, so it's a single common anscestor.
|
||||
#
|
||||
# When you think about it, this implementation makes a lot of sense for our purposes
|
||||
# It basically determines if `nodes` "solely" depend on the `candidate`,
|
||||
# which is the condition for fusing.
|
||||
|
||||
|
||||
def find_closest_integer_output_nodes(
|
||||
nx_graph: nx.MultiDiGraph,
|
||||
start_nodes: List[Node],
|
||||
|
||||
@@ -141,6 +141,53 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator():
|
||||
# pylint: enable=invalid-name,too-many-locals,too-many-statements
|
||||
|
||||
|
||||
def fusable_with_hard_to_find_lca(x):
|
||||
"""
|
||||
Fusable function that requires harder lca search.
|
||||
"""
|
||||
|
||||
a = x * 3
|
||||
b = x // 3
|
||||
c = a + b
|
||||
return ((np.sin(a) ** 2) + (np.cos(c) ** 2)).round().astype(np.int64)
|
||||
|
||||
|
||||
def fusable_with_hard_to_find_lca_used_twice(x):
|
||||
"""
|
||||
Fusable function that uses `fusable_with_hard_to_find_lca` twice.
|
||||
"""
|
||||
|
||||
a = x @ np.array([[3, 1], [4, 2]])
|
||||
b = x @ np.array([[1, 2], [3, 4]])
|
||||
|
||||
a = fusable_with_hard_to_find_lca(a)
|
||||
b = fusable_with_hard_to_find_lca(b)
|
||||
|
||||
return a + b
|
||||
|
||||
|
||||
def fusable_additional_1(x):
|
||||
"""
|
||||
Another fusable function for additional safety.
|
||||
"""
|
||||
|
||||
a = x.astype(np.float64) * 3.0
|
||||
b = x + 1
|
||||
c = a.astype(np.int64)
|
||||
return (a + b + c).astype(np.int64)
|
||||
|
||||
|
||||
def fusable_additional_2(x):
|
||||
"""
|
||||
Another fusable function for additional safety.
|
||||
"""
|
||||
|
||||
a = x.astype(np.float64) / 3.0
|
||||
b = x + 1
|
||||
c = a * a
|
||||
return (a + b + c).astype(np.int64)
|
||||
|
||||
|
||||
def deterministic_unary_function(x):
|
||||
"""
|
||||
An example deterministic unary function.
|
||||
@@ -440,6 +487,34 @@ def deterministic_unary_function(x):
|
||||
},
|
||||
id="fusable_with_one_of_the_start_nodes_is_lca",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_with_hard_to_find_lca,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="fusable_with_hard_to_find_lca",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_with_hard_to_find_lca_used_twice,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 4], "shape": (2, 2)},
|
||||
},
|
||||
id="fusable_with_hard_to_find_lca_used_twice",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_additional_1,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="fusable_additional_1",
|
||||
),
|
||||
pytest.param(
|
||||
fusable_additional_2,
|
||||
{
|
||||
"x": {"status": "encrypted", "range": [0, 10]},
|
||||
},
|
||||
id="fusable_additional_2",
|
||||
),
|
||||
pytest.param(
|
||||
lambda x: x + x.shape[0] + x.ndim + x.size,
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user