fix(tracing): do not process already visited tracers

- some topologies triggered a bug where a tracer was visited several times
- this created duplicate edges which caused problems later in the code
- simply skip already visited tracers
- add a check to see that all edges are indeed unique
This commit is contained in:
Arthur Meyre
2021-10-12 14:40:02 +02:00
parent 31e14e7b66
commit fc9fc992c8
2 changed files with 26 additions and 1 deletions

View File

@@ -6,7 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
from ..debugging.custom_assert import custom_assert
from ..debugging.custom_assert import assert_true, custom_assert
from ..representation.intermediate import Input
from ..values import BaseValue
from .base_tracer import BaseTracer
@@ -108,6 +108,8 @@ def create_graph_from_output_tracers(
# use dict as ordered set
next_tracers: Dict[BaseTracer, None] = {}
for tracer in current_tracers:
if tracer in visited_tracers:
continue
current_ir_node = tracer.traced_computation
graph.add_node(current_ir_node)
@@ -124,4 +126,12 @@ def create_graph_from_output_tracers(
custom_assert(is_directed_acyclic_graph(graph))
# Check each edge is unique
unique_edges = set(
(pred, succ, tuple((k, v) for k, v in edge_data.items()))
for pred, succ, edge_data in graph.edges(data=True)
)
number_of_edges = len(graph.edges)
assert_true(len(unique_edges) == number_of_edges)
return graph

View File

@@ -38,6 +38,20 @@ def small_fused_table(x):
return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)
def complicated_topology(x):
"""Mix x in an intricated way."""
intermediate = x
x_p_1 = intermediate + 1
x_p_2 = intermediate + 2
x_p_3 = x_p_1 + x_p_2
return (
x_p_3.astype(numpy.int32),
x_p_2.astype(numpy.int32),
(x_p_2 + 3).astype(numpy.int32),
x_p_3.astype(numpy.int32) + 67,
)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
@@ -55,6 +69,7 @@ def small_fused_table(x):
["x", "y"],
marks=pytest.mark.xfail(strict=True, raises=ValueError),
),
pytest.param(complicated_topology, ((0, 10),), ["x"]),
],
)
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):