From a4da3b82101f53a8423161bf309e58908dbb8b5a Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Tue, 12 Oct 2021 11:35:52 +0200 Subject: [PATCH] feat(tracing): add output_idx information in edges - renamed output_index to output_idx in BaseTracer - update tracing and fusing code to manage output_idx correctly - update OPGraph evaluate and update_values_with_bounds to manage output_idx - update tests checking graph validity to have output_idx set properly - the support of actual multi-output nodes is in #81 --- concrete/common/debugging/drawing.py | 1 + concrete/common/debugging/printing.py | 1 + concrete/common/extensions/table.py | 2 +- concrete/common/operator_graph.py | 41 +++++++++++++++++---- concrete/common/optimization/topological.py | 16 ++++++-- concrete/common/tracing/base_tracer.py | 10 +++-- concrete/common/tracing/tracing_helpers.py | 8 +++- concrete/numpy/tracing.py | 10 ++--- tests/common/extensions/test_table.py | 8 ++-- tests/conftest.py | 2 +- tests/helpers/test_conftest.py | 16 ++++---- tests/numpy/test_tracing.py | 8 ++-- 12 files changed, 84 insertions(+), 39 deletions(-) diff --git a/concrete/common/debugging/drawing.py b/concrete/common/debugging/drawing.py index 372cbb79b..f9f1e02f3 100644 --- a/concrete/common/debugging/drawing.py +++ b/concrete/common/debugging/drawing.py @@ -89,6 +89,7 @@ def draw_graph( } nx.set_node_attributes(graph, attributes) + # TODO: #639 adapt drawing routine to manage output_idx for edge in graph.edges(keys=True): idx = graph.edges[edge]["input_idx"] graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look diff --git a/concrete/common/debugging/printing.py b/concrete/common/debugging/printing.py index 5ecfbc95e..688a4faef 100644 --- a/concrete/common/debugging/printing.py +++ b/concrete/common/debugging/printing.py @@ -61,6 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str: for node in nx.topological_sort(graph): + # TODO: #640 # This code doesn't work with more than a single output. For more outputs, # we would need to change the way the destination are created: currently, # they only are done by incrementing i diff --git a/concrete/common/extensions/table.py b/concrete/common/extensions/table.py index 8e882bd52..0455d3cc1 100644 --- a/concrete/common/extensions/table.py +++ b/concrete/common/extensions/table.py @@ -45,7 +45,7 @@ class LookupTable: return key.__class__( inputs=[key], traced_computation=traced_computation, - output_index=0, + output_idx=0, ) # if not, it means table is indexed with a constant diff --git a/concrete/common/operator_graph.py b/concrete/common/operator_graph.py index 4c3498cb9..92f338984 100644 --- a/concrete/common/operator_graph.py +++ b/concrete/common/operator_graph.py @@ -126,13 +126,44 @@ class OPGraph: """ node_results: Dict[IntermediateNode, Any] = {} + def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any: + """Get the output result at index output_idx for a node. + + Args: + node (IntermediateNode): the node from which we want the output. + output_idx (int): which output we want. + + Returns: + Any: the output value of the evaluation of node. + """ + result = node_results[node] + # TODO: #81 remove no cover once we have nodes with multiple outputs + if isinstance(result, tuple): # pragma: no cover + # If the node has multiple outputs (i.e. the result is a tuple), return the + # requested output + return result[output_idx] + # If the result is not a tuple, then the result is the node's only output. Check that + # the requested index is 0 (as it's the only valid value) and return the result itself. + assert_true( + output_idx == 0, + f"Unable to get output at index {output_idx} for node {node}.\n" + f"Node result: {result}", + ) + return result + for node in nx.topological_sort(self.graph): if not isinstance(node, Input): curr_inputs = {} for pred_node in self.graph.pred[node]: edges = self.graph.get_edge_data(pred_node, node) curr_inputs.update( - {edge["input_idx"]: node_results[pred_node] for edge in edges.values()} + { + edge["input_idx"]: get_result_of_node_at_index( + pred_node, + output_idx=edge["output_idx"], + ) + for edge in edges.values() + } ) node_results[node] = node.evaluate(curr_inputs) else: @@ -225,16 +256,12 @@ class OPGraph: node.outputs[0] = deepcopy(node.inputs[0]) - # TODO: #57 manage multiple outputs from a node, probably requires an output_idx when - # adding an edge - assert_true(len(node.outputs) == 1) - successors = self.graph.succ[node] for succ in successors: edge_data = self.graph.get_edge_data(node, succ) for edge in edge_data.values(): - input_idx = edge["input_idx"] - succ.inputs[input_idx] = deepcopy(node.outputs[0]) + input_idx, output_idx = edge["input_idx"], edge["output_idx"] + succ.inputs[input_idx] = deepcopy(node.outputs[output_idx]) def prune_nodes(self): """Remove unreachable nodes from outputs.""" diff --git a/concrete/common/optimization/topological.py b/concrete/common/optimization/topological.py index 16d2d1df0..ca6cdb012 100644 --- a/concrete/common/optimization/topological.py +++ b/concrete/common/optimization/topological.py @@ -73,10 +73,14 @@ def fuse_float_operations( succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ)) for edge_key, edge_data in succ_edge_data.items(): nx_graph.remove_edge(terminal_node, succ, key=edge_key) - nx_graph.add_edge(fused_node, succ, key=edge_key, **edge_data) + # fused_node is always a UnivariateFunction so output_idx == 0 always + new_edge_data = deepcopy(edge_data) + new_edge_data["output_idx"] = 0 + nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data) # Connect the node feeding the subgraph contained in fused_node - nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0) + # node_before_subgraph has a single integer output currently so output_idx == 0 + nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0) op_graph.prune_nodes() if compilation_artifacts is not None: @@ -122,6 +126,7 @@ def convert_float_subgraph_to_fused_node( assert_true(len(variable_input_nodes) == 1) current_subgraph_variable_input = variable_input_nodes[0] + assert_true(len(current_subgraph_variable_input.outputs) == 1) new_input_value = deepcopy(current_subgraph_variable_input.outputs[0]) nx_graph = op_graph.graph @@ -147,11 +152,14 @@ def convert_float_subgraph_to_fused_node( float_subgraph.remove_edge( current_subgraph_variable_input, node_after_input, key=edge_key ) + # new_subgraph_variable_input is always an Input so output_idx == 0 always + new_edge_data = deepcopy(edge_data) + new_edge_data["output_idx"] = 0 float_subgraph.add_edge( new_subgraph_variable_input, node_after_input, key=edge_key, - **edge_data, + **new_edge_data, ) float_op_subgraph = OPGraph.from_graph( @@ -160,6 +168,8 @@ def convert_float_subgraph_to_fused_node( [terminal_node], ) + assert_true(len(terminal_node.outputs) == 1) + # Create fused_node fused_node = UnivariateFunction( deepcopy(new_subgraph_variable_input.inputs[0]), diff --git a/concrete/common/tracing/base_tracer.py b/concrete/common/tracing/base_tracer.py index 086b24139..4a6f450e4 100644 --- a/concrete/common/tracing/base_tracer.py +++ b/concrete/common/tracing/base_tracer.py @@ -19,6 +19,7 @@ class BaseTracer(ABC): inputs: List["BaseTracer"] traced_computation: IntermediateNode + output_idx: int output: BaseValue _mix_values_func: Callable[..., BaseValue] @@ -26,11 +27,12 @@ class BaseTracer(ABC): self, inputs: Iterable["BaseTracer"], traced_computation: IntermediateNode, - output_index: int, + output_idx: int, ) -> None: self.inputs = list(inputs) self.traced_computation = traced_computation - self.output = traced_computation.outputs[output_index] + self.output_idx = output_idx + self.output = traced_computation.outputs[output_idx] @abstractmethod def _supports_other_operand(self, other: Any) -> bool: @@ -96,8 +98,8 @@ class BaseTracer(ABC): ) output_tracers = tuple( - self.__class__(sanitized_inputs, traced_computation, output_index) - for output_index in range(len(traced_computation.outputs)) + self.__class__(sanitized_inputs, traced_computation, output_idx) + for output_idx in range(len(traced_computation.outputs)) ) return output_tracers diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index 24b2f673f..8d114ed35 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -115,8 +115,14 @@ def create_graph_from_output_tracers( for input_idx, input_tracer in enumerate(tracer.inputs): input_ir_node = input_tracer.traced_computation + output_idx = input_tracer.output_idx graph.add_node(input_ir_node) - graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx) + graph.add_edge( + input_ir_node, + current_ir_node, + input_idx=input_idx, + output_idx=output_idx, + ) if input_tracer not in visited_tracers: next_tracers.update({input_tracer: None}) diff --git a/concrete/numpy/tracing.py b/concrete/numpy/tracing.py index 12ffb5f27..00e1edfa6 100644 --- a/concrete/numpy/tracing.py +++ b/concrete/numpy/tracing.py @@ -93,9 +93,7 @@ class NPTracer(BaseTracer): output_dtype=output_dtype, op_name=f"astype({normalized_numpy_dtype})", ) - output_tracer = self.__class__( - [self], traced_computation=traced_computation, output_index=0 - ) + output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0) return output_tracer @staticmethod @@ -164,7 +162,7 @@ class NPTracer(BaseTracer): output_tracer = cls( input_tracers, traced_computation=traced_computation, - output_index=0, + output_idx=0, ) return output_tracer @@ -229,7 +227,7 @@ class NPTracer(BaseTracer): output_tracer = cls( (input_tracers[in_which_input_is_variable],), traced_computation=traced_computation, - output_index=0, + output_idx=0, ) return output_tracer @@ -253,7 +251,7 @@ class NPTracer(BaseTracer): output_tracer = self.__class__( args, traced_computation=traced_computation, - output_index=0, + output_idx=0, ) return output_tracer diff --git a/tests/common/extensions/test_table.py b/tests/common/extensions/test_table.py index d6a95999a..ac90049a1 100644 --- a/tests/common/extensions/test_table.py +++ b/tests/common/extensions/test_table.py @@ -66,7 +66,7 @@ def test_lookup_table_encrypted_lookup(test_helpers): # pylint: enable=protected-access ref_graph.add_node(output_arbitrary_function) - ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0) + ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0) # TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) @@ -112,10 +112,10 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers): output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0])) ref_graph.add_node(output_add) - ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0) + ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0) - ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0) - ref_graph.add_edge(constant_3, output_add, input_idx=1) + ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0) + ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0) # TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph) diff --git a/tests/conftest.py b/tests/conftest.py index 73acdbb9c..d3e38870b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -194,7 +194,7 @@ class TestHelpers: def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph): """Check that two digraphs are equivalent without modifications""" # edge_match is a copy of node_match - edge_matcher = iso.categorical_multiedge_match("input_idx", None) + edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None]) node_matcher = iso.generic_node_match( "_test_content", None, TestHelpers.nodes_are_equivalent ) diff --git a/tests/helpers/test_conftest.py b/tests/helpers/test_conftest.py index 9f5a50a29..65a5d3a09 100644 --- a/tests/helpers/test_conftest.py +++ b/tests/helpers/test_conftest.py @@ -28,27 +28,27 @@ def test_digraphs_are_equivalent(test_helpers): t_1 = TestNode("Mul") t_2 = TestNode("TLU") - g_1.add_edge(t_0, t_2, input_idx=0) - g_1.add_edge(t_1, t_2, input_idx=1) + g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0) + g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0) t0p = TestNode("Add") t1p = TestNode("Mul") t2p = TestNode("TLU") - g_2.add_edge(t1p, t2p, input_idx=1) - g_2.add_edge(t0p, t2p, input_idx=0) + g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0) + g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0) bad_g2 = nx.MultiDiGraph() bad_t0 = TestNode("Not Add") - bad_g2.add_edge(bad_t0, t_2, input_idx=0) - bad_g2.add_edge(t_1, t_2, input_idx=1) + bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0) + bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0) bad_g3 = nx.MultiDiGraph() - bad_g3.add_edge(t_0, t_2, input_idx=1) - bad_g3.add_edge(t_1, t_2, input_idx=0) + bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0) + bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0) assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent" assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent" diff --git a/tests/numpy/test_tracing.py b/tests/numpy/test_tracing.py index c7fa35a3e..7c953a44c 100644 --- a/tests/numpy/test_tracing.py +++ b/tests/numpy/test_tracing.py @@ -159,11 +159,11 @@ def test_numpy_tracing_binary_op(operation, x, y, test_helpers): ref_graph.add_node(add_node_z) ref_graph.add_node(returned_final_node) - ref_graph.add_edge(input_x, add_node_z, input_idx=0) - ref_graph.add_edge(input_x, add_node_z, input_idx=1) + ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0) + ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0) - ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0) - ref_graph.add_edge(input_y, returned_final_node, input_idx=1) + ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0) + ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0) assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)