fix: fixing issue in the graph generation

closes #130
This commit is contained in:
Benoit Chevallier-Mames
2021-08-12 11:59:15 +02:00
committed by Benoit Chevallier
parent 8fd0ae5c85
commit 4f6103d1d1
2 changed files with 51 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
"""Helper functions for tracing."""
import collections
from inspect import signature
from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type
from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
import networkx as nx
from networkx.algorithms.dag import is_directed_acyclic_graph
@@ -100,10 +100,12 @@ def create_graph_from_output_tracers(
graph = nx.MultiDiGraph()
visited_tracers: Set[BaseTracer] = set()
current_tracers = tuple(output_tracers)
# use dict as ordered set
current_tracers = {tracer: None for tracer in output_tracers}
while current_tracers:
next_tracers: Tuple[BaseTracer, ...] = tuple()
# use dict as ordered set
next_tracers: Dict[BaseTracer, None] = dict()
for tracer in current_tracers:
current_ir_node = tracer.traced_computation
graph.add_node(current_ir_node, content=current_ir_node)
@@ -113,7 +115,7 @@ def create_graph_from_output_tracers(
graph.add_node(input_ir_node, content=input_ir_node)
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
if input_tracer not in visited_tracers:
next_tracers += (input_tracer,)
next_tracers.update({input_tracer: None})
visited_tracers.add(tracer)

View File

@@ -8,6 +8,30 @@ from hdk.common.debugging import draw_graph, get_printable_graph
from hdk.hnumpy import tracing
def issue_130_a(x, y):
"""Test case derived from issue #130"""
# pylint: disable=unused-argument
intermediate = x + 1
return (intermediate, intermediate)
# pylint: enable=unused-argument
def issue_130_b(x, y):
"""Test case derived from issue #130"""
# pylint: disable=unused-argument
intermediate = x - 1
return (intermediate, intermediate)
# pylint: enable=unused-argument
def issue_130_c(x, y):
"""Test case derived from issue #130"""
# pylint: disable=unused-argument
intermediate = 1 - x
return (intermediate, intermediate)
# pylint: enable=unused-argument
@pytest.mark.parametrize(
"lambda_f,ref_graph_str",
[
@@ -62,6 +86,27 @@ from hdk.hnumpy import tracing
lambda x, y: (x, x + 1),
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%0, %2)",
),
(
lambda x, y: (x + 1, x + 1),
"\n%0 = x"
"\n%1 = ConstantInput(1)"
"\n%2 = ConstantInput(1)"
"\n%3 = Add(0, 1)"
"\n%4 = Add(0, 2)"
"\nreturn(%3, %4)",
),
(
issue_130_a,
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2, %2)",
),
(
issue_130_b,
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)",
),
(
issue_130_c,
"\n%0 = ConstantInput(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)",
),
],
)
@pytest.mark.parametrize(