diff --git a/concrete/common/tracing/tracing_helpers.py b/concrete/common/tracing/tracing_helpers.py index a2c894312..2e3eea6a8 100644 --- a/concrete/common/tracing/tracing_helpers.py +++ b/concrete/common/tracing/tracing_helpers.py @@ -6,7 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type import networkx as nx from networkx.algorithms.dag import is_directed_acyclic_graph -from ..debugging.custom_assert import custom_assert +from ..debugging.custom_assert import assert_true, custom_assert from ..representation.intermediate import Input from ..values import BaseValue from .base_tracer import BaseTracer @@ -108,6 +108,8 @@ def create_graph_from_output_tracers( # use dict as ordered set next_tracers: Dict[BaseTracer, None] = {} for tracer in current_tracers: + if tracer in visited_tracers: + continue current_ir_node = tracer.traced_computation graph.add_node(current_ir_node) @@ -124,4 +126,12 @@ def create_graph_from_output_tracers( custom_assert(is_directed_acyclic_graph(graph)) + # Check each edge is unique + unique_edges = set( + (pred, succ, tuple((k, v) for k, v in edge_data.items())) + for pred, succ, edge_data in graph.edges(data=True) + ) + number_of_edges = len(graph.edges) + assert_true(len(unique_edges) == number_of_edges) + return graph diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index eb155c776..97563618e 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -38,6 +38,20 @@ def small_fused_table(x): return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32) +def complicated_topology(x): + """Mix x in an intricated way.""" + intermediate = x + x_p_1 = intermediate + 1 + x_p_2 = intermediate + 2 + x_p_3 = x_p_1 + x_p_2 + return ( + x_p_3.astype(numpy.int32), + x_p_2.astype(numpy.int32), + (x_p_2 + 3).astype(numpy.int32), + x_p_3.astype(numpy.int32) + 67, + ) + + @pytest.mark.parametrize( "function,input_ranges,list_of_arg_names", [ @@ -55,6 +69,7 @@ def small_fused_table(x): ["x", "y"], marks=pytest.mark.xfail(strict=True, raises=ValueError), ), + pytest.param(complicated_topology, ((0, 10),), ["x"]), ], ) def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):