From 4f6103d1d1439a6addf299dd909fede49d478002 Mon Sep 17 00:00:00 2001 From: Benoit Chevallier-Mames Date: Thu, 12 Aug 2021 11:59:15 +0200 Subject: [PATCH] fix: fixing issue in the graph generation closes #130 --- hdk/common/tracing/tracing_helpers.py | 10 +++--- tests/hnumpy/test_debugging.py | 45 +++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 4 deletions(-) diff --git a/hdk/common/tracing/tracing_helpers.py b/hdk/common/tracing/tracing_helpers.py index 087d042fb..f60a33e72 100644 --- a/hdk/common/tracing/tracing_helpers.py +++ b/hdk/common/tracing/tracing_helpers.py @@ -1,7 +1,7 @@ """Helper functions for tracing.""" import collections from inspect import signature -from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type +from typing import Callable, Dict, Iterable, OrderedDict, Set, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph @@ -100,10 +100,12 @@ def create_graph_from_output_tracers( graph = nx.MultiDiGraph() visited_tracers: Set[BaseTracer] = set() - current_tracers = tuple(output_tracers) + # use dict as ordered set + current_tracers = {tracer: None for tracer in output_tracers} while current_tracers: - next_tracers: Tuple[BaseTracer, ...] = tuple() + # use dict as ordered set + next_tracers: Dict[BaseTracer, None] = dict() for tracer in current_tracers: current_ir_node = tracer.traced_computation graph.add_node(current_ir_node, content=current_ir_node) @@ -113,7 +115,7 @@ def create_graph_from_output_tracers( graph.add_node(input_ir_node, content=input_ir_node) graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx) if input_tracer not in visited_tracers: - next_tracers += (input_tracer,) + next_tracers.update({input_tracer: None}) visited_tracers.add(tracer) diff --git a/tests/hnumpy/test_debugging.py b/tests/hnumpy/test_debugging.py index 750f56bbd..6a51a4374 100644 --- a/tests/hnumpy/test_debugging.py +++ b/tests/hnumpy/test_debugging.py @@ -8,6 +8,30 @@ from hdk.common.debugging import draw_graph, get_printable_graph from hdk.hnumpy import tracing +def issue_130_a(x, y): + """Test case derived from issue #130""" + # pylint: disable=unused-argument + intermediate = x + 1 + return (intermediate, intermediate) + # pylint: enable=unused-argument + + +def issue_130_b(x, y): + """Test case derived from issue #130""" + # pylint: disable=unused-argument + intermediate = x - 1 + return (intermediate, intermediate) + # pylint: enable=unused-argument + + +def issue_130_c(x, y): + """Test case derived from issue #130""" + # pylint: disable=unused-argument + intermediate = 1 - x + return (intermediate, intermediate) + # pylint: enable=unused-argument + + @pytest.mark.parametrize( "lambda_f,ref_graph_str", [ @@ -62,6 +86,27 @@ from hdk.hnumpy import tracing lambda x, y: (x, x + 1), "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%0, %2)", ), + ( + lambda x, y: (x + 1, x + 1), + "\n%0 = x" + "\n%1 = ConstantInput(1)" + "\n%2 = ConstantInput(1)" + "\n%3 = Add(0, 1)" + "\n%4 = Add(0, 2)" + "\nreturn(%3, %4)", + ), + ( + issue_130_a, + "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2, %2)", + ), + ( + issue_130_b, + "\n%0 = x\n%1 = ConstantInput(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)", + ), + ( + issue_130_c, + "\n%0 = ConstantInput(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)", + ), ], ) @pytest.mark.parametrize(