feat(tracing): add output_idx information in edges

- rename output_index to output_idx in BaseTracer
- update tracing and fusing code to manage output_idx correctly
- update OPGraph evaluate and update_values_with_bounds to manage
output_idx
- update tests checking graph validity to have output_idx set properly
- support for actual multi-output nodes is in #81
This commit is contained in:
Arthur Meyre
2021-10-12 11:35:52 +02:00
parent 0cd33b6f67
commit a4da3b8210
12 changed files with 84 additions and 39 deletions

View File

@@ -89,6 +89,7 @@ def draw_graph(
}
nx.set_node_attributes(graph, attributes)
# TODO: #639 adapt drawing routine to manage output_idx
for edge in graph.edges(keys=True):
idx = graph.edges[edge]["input_idx"]
graph.edges[edge]["label"] = f" {idx} " # spaces are there intentionally for a better look

View File

@@ -61,6 +61,7 @@ def get_printable_graph(opgraph: OPGraph, show_data_types: bool = False) -> str:
for node in nx.topological_sort(graph):
# TODO: #640
# This code doesn't work with more than a single output. For more outputs,
# we would need to change the way the destinations are created: currently,
# they are only created by incrementing i

View File

@@ -45,7 +45,7 @@ class LookupTable:
return key.__class__(
inputs=[key],
traced_computation=traced_computation,
output_index=0,
output_idx=0,
)
# if not, it means table is indexed with a constant

View File

@@ -126,13 +126,44 @@ class OPGraph:
"""
node_results: Dict[IntermediateNode, Any] = {}
def get_result_of_node_at_index(node: IntermediateNode, output_idx: int) -> Any:
    """Fetch one output of an already evaluated node.

    The value is looked up in node_results, so the node must have been evaluated before
    this helper is called.

    Args:
        node (IntermediateNode): the node whose output is requested.
        output_idx (int): position of the requested output among the node's outputs.

    Returns:
        Any: the requested output value stored for node.
    """
    result = node_results[node]

    if not isinstance(result, tuple):
        # A non-tuple result is the node's single output, so 0 is the only index that
        # makes sense here; anything else is reported through assert_true.
        assert_true(
            output_idx == 0,
            f"Unable to get output at index {output_idx} for node {node}.\n"
            f"Node result: {result}",
        )
        return result

    # A tuple result holds one entry per output: pick the requested one.
    # TODO: #81 remove no cover once we have nodes with multiple outputs
    return result[output_idx]  # pragma: no cover
for node in nx.topological_sort(self.graph):
if not isinstance(node, Input):
curr_inputs = {}
for pred_node in self.graph.pred[node]:
edges = self.graph.get_edge_data(pred_node, node)
curr_inputs.update(
{edge["input_idx"]: node_results[pred_node] for edge in edges.values()}
{
edge["input_idx"]: get_result_of_node_at_index(
pred_node,
output_idx=edge["output_idx"],
)
for edge in edges.values()
}
)
node_results[node] = node.evaluate(curr_inputs)
else:
@@ -225,16 +256,12 @@ class OPGraph:
node.outputs[0] = deepcopy(node.inputs[0])
# TODO: #57 manage multiple outputs from a node, probably requires an output_idx when
# adding an edge
assert_true(len(node.outputs) == 1)
successors = self.graph.succ[node]
for succ in successors:
edge_data = self.graph.get_edge_data(node, succ)
for edge in edge_data.values():
input_idx = edge["input_idx"]
succ.inputs[input_idx] = deepcopy(node.outputs[0])
input_idx, output_idx = edge["input_idx"], edge["output_idx"]
succ.inputs[input_idx] = deepcopy(node.outputs[output_idx])
def prune_nodes(self):
"""Remove unreachable nodes from outputs."""

View File

@@ -73,10 +73,14 @@ def fuse_float_operations(
succ_edge_data = deepcopy(nx_graph.get_edge_data(terminal_node, succ))
for edge_key, edge_data in succ_edge_data.items():
nx_graph.remove_edge(terminal_node, succ, key=edge_key)
nx_graph.add_edge(fused_node, succ, key=edge_key, **edge_data)
# fused_node is always a UnivariateFunction so output_idx == 0 always
new_edge_data = deepcopy(edge_data)
new_edge_data["output_idx"] = 0
nx_graph.add_edge(fused_node, succ, key=edge_key, **new_edge_data)
# Connect the node feeding the subgraph contained in fused_node
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0)
# node_before_subgraph has a single integer output currently so output_idx == 0
nx_graph.add_edge(node_before_subgraph, fused_node, input_idx=0, output_idx=0)
op_graph.prune_nodes()
if compilation_artifacts is not None:
@@ -122,6 +126,7 @@ def convert_float_subgraph_to_fused_node(
assert_true(len(variable_input_nodes) == 1)
current_subgraph_variable_input = variable_input_nodes[0]
assert_true(len(current_subgraph_variable_input.outputs) == 1)
new_input_value = deepcopy(current_subgraph_variable_input.outputs[0])
nx_graph = op_graph.graph
@@ -147,11 +152,14 @@ def convert_float_subgraph_to_fused_node(
float_subgraph.remove_edge(
current_subgraph_variable_input, node_after_input, key=edge_key
)
# new_subgraph_variable_input is always an Input so output_idx == 0 always
new_edge_data = deepcopy(edge_data)
new_edge_data["output_idx"] = 0
float_subgraph.add_edge(
new_subgraph_variable_input,
node_after_input,
key=edge_key,
**edge_data,
**new_edge_data,
)
float_op_subgraph = OPGraph.from_graph(
@@ -160,6 +168,8 @@ def convert_float_subgraph_to_fused_node(
[terminal_node],
)
assert_true(len(terminal_node.outputs) == 1)
# Create fused_node
fused_node = UnivariateFunction(
deepcopy(new_subgraph_variable_input.inputs[0]),

View File

@@ -19,6 +19,7 @@ class BaseTracer(ABC):
inputs: List["BaseTracer"]
traced_computation: IntermediateNode
output_idx: int
output: BaseValue
_mix_values_func: Callable[..., BaseValue]
@@ -26,11 +27,12 @@ class BaseTracer(ABC):
self,
inputs: Iterable["BaseTracer"],
traced_computation: IntermediateNode,
output_index: int,
output_idx: int,
) -> None:
self.inputs = list(inputs)
self.traced_computation = traced_computation
self.output = traced_computation.outputs[output_index]
self.output_idx = output_idx
self.output = traced_computation.outputs[output_idx]
@abstractmethod
def _supports_other_operand(self, other: Any) -> bool:
@@ -96,8 +98,8 @@ class BaseTracer(ABC):
)
output_tracers = tuple(
self.__class__(sanitized_inputs, traced_computation, output_index)
for output_index in range(len(traced_computation.outputs))
self.__class__(sanitized_inputs, traced_computation, output_idx)
for output_idx in range(len(traced_computation.outputs))
)
return output_tracers

View File

@@ -115,8 +115,14 @@ def create_graph_from_output_tracers(
for input_idx, input_tracer in enumerate(tracer.inputs):
input_ir_node = input_tracer.traced_computation
output_idx = input_tracer.output_idx
graph.add_node(input_ir_node)
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
graph.add_edge(
input_ir_node,
current_ir_node,
input_idx=input_idx,
output_idx=output_idx,
)
if input_tracer not in visited_tracers:
next_tracers.update({input_tracer: None})

View File

@@ -93,9 +93,7 @@ class NPTracer(BaseTracer):
output_dtype=output_dtype,
op_name=f"astype({normalized_numpy_dtype})",
)
output_tracer = self.__class__(
[self], traced_computation=traced_computation, output_index=0
)
output_tracer = self.__class__([self], traced_computation=traced_computation, output_idx=0)
return output_tracer
@staticmethod
@@ -164,7 +162,7 @@ class NPTracer(BaseTracer):
output_tracer = cls(
input_tracers,
traced_computation=traced_computation,
output_index=0,
output_idx=0,
)
return output_tracer
@@ -229,7 +227,7 @@ class NPTracer(BaseTracer):
output_tracer = cls(
(input_tracers[in_which_input_is_variable],),
traced_computation=traced_computation,
output_index=0,
output_idx=0,
)
return output_tracer
@@ -253,7 +251,7 @@ class NPTracer(BaseTracer):
output_tracer = self.__class__(
args,
traced_computation=traced_computation,
output_index=0,
output_idx=0,
)
return output_tracer

View File

@@ -66,7 +66,7 @@ def test_lookup_table_encrypted_lookup(test_helpers):
# pylint: enable=protected-access
ref_graph.add_node(output_arbitrary_function)
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0)
ref_graph.add_edge(input_x, output_arbitrary_function, input_idx=0, output_idx=0)
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)
@@ -112,10 +112,10 @@ def test_lookup_table_encrypted_and_plain_lookup(test_helpers):
output_add = ir.Add((intermediate_arbitrary_function.outputs[0], constant_3.outputs[0]))
ref_graph.add_node(output_add)
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0)
ref_graph.add_edge(input_x, intermediate_arbitrary_function, input_idx=0, output_idx=0)
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0)
ref_graph.add_edge(constant_3, output_add, input_idx=1)
ref_graph.add_edge(intermediate_arbitrary_function, output_add, input_idx=0, output_idx=0)
ref_graph.add_edge(constant_3, output_add, input_idx=1, output_idx=0)
# TODO: discuss if this check is enough as == is not overloaded properly for UnivariateFunction
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)

View File

@@ -194,7 +194,7 @@ class TestHelpers:
def digraphs_are_equivalent(reference: nx.MultiDiGraph, to_compare: nx.MultiDiGraph):
"""Check that two digraphs are equivalent without modifications"""
# edge_match is a copy of node_match
edge_matcher = iso.categorical_multiedge_match("input_idx", None)
edge_matcher = iso.categorical_multiedge_match(["input_idx", "output_idx"], [None, None])
node_matcher = iso.generic_node_match(
"_test_content", None, TestHelpers.nodes_are_equivalent
)

View File

@@ -28,27 +28,27 @@ def test_digraphs_are_equivalent(test_helpers):
t_1 = TestNode("Mul")
t_2 = TestNode("TLU")
g_1.add_edge(t_0, t_2, input_idx=0)
g_1.add_edge(t_1, t_2, input_idx=1)
g_1.add_edge(t_0, t_2, input_idx=0, output_idx=0)
g_1.add_edge(t_1, t_2, input_idx=1, output_idx=0)
t0p = TestNode("Add")
t1p = TestNode("Mul")
t2p = TestNode("TLU")
g_2.add_edge(t1p, t2p, input_idx=1)
g_2.add_edge(t0p, t2p, input_idx=0)
g_2.add_edge(t1p, t2p, input_idx=1, output_idx=0)
g_2.add_edge(t0p, t2p, input_idx=0, output_idx=0)
bad_g2 = nx.MultiDiGraph()
bad_t0 = TestNode("Not Add")
bad_g2.add_edge(bad_t0, t_2, input_idx=0)
bad_g2.add_edge(t_1, t_2, input_idx=1)
bad_g2.add_edge(bad_t0, t_2, input_idx=0, output_idx=0)
bad_g2.add_edge(t_1, t_2, input_idx=1, output_idx=0)
bad_g3 = nx.MultiDiGraph()
bad_g3.add_edge(t_0, t_2, input_idx=1)
bad_g3.add_edge(t_1, t_2, input_idx=0)
bad_g3.add_edge(t_0, t_2, input_idx=1, output_idx=0)
bad_g3.add_edge(t_1, t_2, input_idx=0, output_idx=0)
assert test_helpers.digraphs_are_equivalent(g_1, g_2), "Graphs should be equivalent"
assert not test_helpers.digraphs_are_equivalent(g_1, bad_g2), "Graphs should not be equivalent"

View File

@@ -159,11 +159,11 @@ def test_numpy_tracing_binary_op(operation, x, y, test_helpers):
ref_graph.add_node(add_node_z)
ref_graph.add_node(returned_final_node)
ref_graph.add_edge(input_x, add_node_z, input_idx=0)
ref_graph.add_edge(input_x, add_node_z, input_idx=1)
ref_graph.add_edge(input_x, add_node_z, input_idx=0, output_idx=0)
ref_graph.add_edge(input_x, add_node_z, input_idx=1, output_idx=0)
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0)
ref_graph.add_edge(input_y, returned_final_node, input_idx=1)
ref_graph.add_edge(add_node_z, returned_final_node, input_idx=0, output_idx=0)
ref_graph.add_edge(input_y, returned_final_node, input_idx=1, output_idx=0)
assert test_helpers.digraphs_are_equivalent(ref_graph, op_graph.graph)