diff --git a/concrete/numpy/compilation/utils.py b/concrete/numpy/compilation/utils.py index c4801f99f..1c6455722 100644 --- a/concrete/numpy/compilation/utils.py +++ b/concrete/numpy/compilation/utils.py @@ -123,19 +123,6 @@ def find_float_subgraph_with_unique_terminal_node( except StopIteration: return None - # networkx does not implement lowest common ancestor search for multidigraph, but we only care - # about parent relationship here and not the meaning of edges, so we can convert our - # multidigraph to a digraph and use the lca search algorithm (if needed), we create the - # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause - # issues in our search, so we remove them. - equivalent_subgraph_without_constants = nx.DiGraph(nx_graph) - constant_nodes = [ - node - for node in equivalent_subgraph_without_constants.nodes() - if node.operation == Operation.Constant - ] - equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes) - all_nodes: Dict[Node, None] = {} start_single_int_output_nodes_search_from = terminal_node @@ -154,40 +141,7 @@ def find_float_subgraph_with_unique_terminal_node( # find a common ancestor as we need a single variable input node # lca == lowest common ancestor - # lca search only works for node pairs in networkx, so we progressively find the ancestors - # setting the lca by default to one of the nodes we are searching the lca for - lca = variable_start_nodes.pop() - while len(variable_start_nodes) > 0 and lca is not None: - node_to_find_new_lca = variable_start_nodes.pop() - if lca == node_to_find_new_lca: - continue - - ancestors_of_lca = nx.ancestors( - equivalent_subgraph_without_constants, - lca, - ) - ancestors_of_node_to_find_new_lca = nx.ancestors( - equivalent_subgraph_without_constants, - node_to_find_new_lca, - ) - - lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca - node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in 
ancestors_of_lca - - if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca: - variable_start_nodes += list( - pred - for pred in nx_graph.predecessors( - node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca - ) - if pred.operation != Operation.Constant - ) - lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca - continue - - lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor( - equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None - ) + lca = find_single_lca(nx_graph, variable_start_nodes) # if subgraph cannot be fused because there is no way to find a common ancestor, break if lca is None: @@ -254,19 +208,6 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc except StopIteration: return None - # networkx does not implement lowest common ancestor search for multidigraph, but we only care - # about parent relationship here and not the meaning of edges, so we can convert our - # multidigraph to a digraph and use the lca search algorithm (if needed), we create the - # equivalent digraph here as it will avoid recreating it in a loop. Constant nodes could cause - # issues in our search, so we remove them. 
- equivalent_subgraph_without_constants = nx.DiGraph(nx_graph) - constant_nodes = [ - node - for node in equivalent_subgraph_without_constants.nodes() - if node.operation == Operation.Constant - ] - equivalent_subgraph_without_constants.remove_nodes_from(constant_nodes) - all_nodes: Dict[Node, None] = {} while True: @@ -274,40 +215,7 @@ def find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc # find a common ancestor as we need a single variable input node # lca == lowest common ancestor - # lca search only works for node pairs in networkx, so we progressively find the ancestors - # setting the lca by default to one of the nodes we are searching the lca for - lca = variable_start_nodes.pop() - while len(variable_start_nodes) > 0 and lca is not None: - node_to_find_new_lca = variable_start_nodes.pop() - if lca == node_to_find_new_lca: - continue - - ancestors_of_lca = nx.ancestors( - equivalent_subgraph_without_constants, - lca, - ) - ancestors_of_node_to_find_new_lca = nx.ancestors( - equivalent_subgraph_without_constants, - node_to_find_new_lca, - ) - - lca_is_ancestor_of_node_to_find_new_lca = lca in ancestors_of_node_to_find_new_lca - node_to_find_new_lca_is_ancestor_of_lca = node_to_find_new_lca in ancestors_of_lca - - if lca_is_ancestor_of_node_to_find_new_lca or node_to_find_new_lca_is_ancestor_of_lca: - variable_start_nodes += list( - pred - for pred in nx_graph.predecessors( - node_to_find_new_lca if lca_is_ancestor_of_node_to_find_new_lca else lca - ) - if pred.operation != Operation.Constant - ) - lca = lca if lca_is_ancestor_of_node_to_find_new_lca else node_to_find_new_lca - continue - - lca = nx.algorithms.lowest_common_ancestors.lowest_common_ancestor( - equivalent_subgraph_without_constants, lca, node_to_find_new_lca, default=None - ) + lca = find_single_lca(nx_graph, variable_start_nodes) # if subgraph cannot be fused because there is no way to find a common ancestor, break if lca is None: @@ -332,6 +240,185 @@ def 
find_tlu_subgraph_with_multiple_variable_inputs_that_has_a_single_common_anc return all_nodes, start_nodes, terminal_node +def find_single_lca(nx_graph: nx.MultiDiGraph, nodes: List[Node]) -> Optional[Node]: + """ + Find the single lowest common ancestor of a list of nodes. + + Args: + nx_graph (nx.MultiDiGraph): + graph to search for single lca + + nodes (List[Node]): + nodes to find the single lca of + + Returns + Optional[Node]: + single lca if it exists, None otherwise + """ + + # find all ancestors of `nodes` + # nodes themselves need to be in this set because the single lca can be within `nodes` + all_ancestors = [set(list(nx.ancestors(nx_graph, node)) + [node]) for node in nodes] + + # find common ancestors among `nodes` + # if the single lca exists, it's in this set + common_ancestors = set( + node + for node in nx_graph.nodes() + if node.operation != Operation.Constant + and all(node in ancestors for ancestors in all_ancestors) + ) + + # iterate over every node in the graph in reversed topological order + # this is to ensure the result, if found, is the single "lowest" common ancestor + for candidate in reversed(list(nx.topological_sort(nx_graph))): + # check if node is a common ancestor of all `nodes` + if candidate not in common_ancestors: + # if not, it cannot be the single lca + continue + + # check if node is a single common ancestor of `nodes` + if is_single_common_ancestor(nx_graph, candidate, nodes): + # if so, it's the single lca of `nodes` + # so return it + return candidate + + # if none of the nodes in `common_ancestors` is the single lca + # there is no single lca of this set of nodes, so return None + return None + + +def is_single_common_ancestor( + nx_graph: nx.MultiDiGraph, + candidate: Node, + nodes: List[Node], +) -> bool: + """ + Determine if a node is the single common ancestor of a list of nodes. + + Note that this function doesn't care about the `lowest` property of `lca`. 
+ + Args: + nx_graph (nx.MultiDiGraph): + graph to perform the check + + candidate (Node): + node to determine single common ancestor status + + nodes (List[Node]): + nodes to determine single common ancestor status against + + Returns + bool: + True if `candidate` is a single common ancestor of `nodes`, False otherwise + """ + + # create a subgraph with `candidate` node + subgraph = nx.DiGraph() + subgraph.add_node(candidate) + + # iterate over `nodes` to add them to the subgraph + # along with every path from `candidate` to them + for node in nodes: + subgraph.add_node(node) + for path in nx.all_simple_paths(nx_graph, source=candidate, target=node): + nx.add_path(subgraph, path) + + # iterate over the nodes of the subgraph + for node in subgraph.nodes(): + # the condition below doesn't apply to `candidate` + # as its predecessors are not in the subgraph + if node == candidate: + continue + + # find number of predecessors in the subgraph and in the original graph + # except constant nodes in the original graph as + # - they are not in the subgraph + # - they don't affect fusability status + predecessor_count_in_subgraph = len(list(subgraph.predecessors(node))) + predecessor_count_in_nx_graph = len( + list( + pred for pred in nx_graph.predecessors(node) if pred.operation != Operation.Constant + ) + ) + + # see if the numbers of predecessors are different + if predecessor_count_in_subgraph != predecessor_count_in_nx_graph: + # if so, `candidate` cannot be a single common ancestor + # the reasoning is explained below + return False + + # if every node in the subgraph has the same number of predecessors + # as in the original graph, `candidate` is in fact a single common ancestor + return True + + # Here is why this function works. + # + # Legend: + # - /|\- = Edge + # - (...) = Intermediate Node + # - {...} = Candidate Node + # - [...] 
= Node of which single common ancestor is searched + # - {[...]} = Both Candidate Node and Node of which single common ancestor is searched + # + # Consider the following graph: + # + # (3) (x) (2) + # \ / \ / + # [{*}] (/) + # \ / + # [+] + # + # - Operation: (x * 3) + (x / 2) + # - Candidate: {*} + # - Nodes: [*] and [+] + # + # So we want to know if the multiplication node is a single common ancestor of + # multiplication and addition nodes. The result is no in this case for our purposes. + # + # Once you apply the subgraph creation above, you'll get the following graph: + # + # (*) + # | + # (+) + # + # In this subgraph, the addition node only has a single predecessor, + # which means there is a path leading to the addition node and that path doesn't include + # the multiplication node, so we conclude multiplication node is not a single common ancestor + # + # Now, consider the following graph: + # + # (3) {x} (2) + # \ / \ / + # [*] (/) + # \ / + # [+] + # + # - Operation: (x * 3) + (x / 2) + # - Candidate: {x} + # - Nodes: [*] and [+] + # + # So we want to know if the input node 'x' is the single common ancestor of + # multiplication and addition nodes. The result is yes in this case. + # + # Once you apply the subgraph creation above, you'll get the following graph: + # + # {x} + # / \ + # [*] (/) + # \ / + # [+] + # + # In this subgraph, every node except the candidate node + # will keep all of their non-constant predecessors, + # which means all of their non-constant predecessors originated + # from the `candidate`, so it's a single common ancestor. + # + # When you think about it, this implementation makes a lot of sense for our purposes. + # It basically determines if `nodes` "solely" depend on the `candidate`, + # which is the condition for fusing. 
+ + def find_closest_integer_output_nodes( nx_graph: nx.MultiDiGraph, start_nodes: List[Node], diff --git a/tests/execution/test_others.py b/tests/execution/test_others.py index 34f21baf7..e339cb968 100644 --- a/tests/execution/test_others.py +++ b/tests/execution/test_others.py @@ -141,6 +141,53 @@ def fusable_with_one_of_the_start_nodes_is_lca_generator(): # pylint: enable=invalid-name,too-many-locals,too-many-statements +def fusable_with_hard_to_find_lca(x): + """ + Fusable function that requires harder lca search. + """ + + a = x * 3 + b = x // 3 + c = a + b + return ((np.sin(a) ** 2) + (np.cos(c) ** 2)).round().astype(np.int64) + + +def fusable_with_hard_to_find_lca_used_twice(x): + """ + Fusable function that uses `fusable_with_hard_to_find_lca` twice. + """ + + a = x @ np.array([[3, 1], [4, 2]]) + b = x @ np.array([[1, 2], [3, 4]]) + + a = fusable_with_hard_to_find_lca(a) + b = fusable_with_hard_to_find_lca(b) + + return a + b + + +def fusable_additional_1(x): + """ + Another fusable function for additional safety. + """ + + a = x.astype(np.float64) * 3.0 + b = x + 1 + c = a.astype(np.int64) + return (a + b + c).astype(np.int64) + + +def fusable_additional_2(x): + """ + Another fusable function for additional safety. + """ + + a = x.astype(np.float64) / 3.0 + b = x + 1 + c = a * a + return (a + b + c).astype(np.int64) + + def deterministic_unary_function(x): """ An example deterministic unary function. 
@@ -440,6 +487,34 @@ def deterministic_unary_function(x): }, id="fusable_with_one_of_the_start_nodes_is_lca", ), + pytest.param( + fusable_with_hard_to_find_lca, + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="fusable_with_hard_to_find_lca", + ), + pytest.param( + fusable_with_hard_to_find_lca_used_twice, + { + "x": {"status": "encrypted", "range": [0, 4], "shape": (2, 2)}, + }, + id="fusable_with_hard_to_find_lca_used_twice", + ), + pytest.param( + fusable_additional_1, + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="fusable_additional_1", + ), + pytest.param( + fusable_additional_2, + { + "x": {"status": "encrypted", "range": [0, 10]}, + }, + id="fusable_additional_2", + ), pytest.param( lambda x: x + x.shape[0] + x.ndim + x.size, {