mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix(tracing): do not process already visited tracers
- some topologies triggered a bug where a tracer was visited several times - this created duplicate edges which caused problems later in the code - simply skip already visited tracers - add a check to see that all edges are indeed unique
This commit is contained in:
@@ -6,7 +6,7 @@ from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
|
||||
from ..debugging.custom_assert import custom_assert
|
||||
from ..debugging.custom_assert import assert_true, custom_assert
|
||||
from ..representation.intermediate import Input
|
||||
from ..values import BaseValue
|
||||
from .base_tracer import BaseTracer
|
||||
@@ -108,6 +108,8 @@ def create_graph_from_output_tracers(
|
||||
# use dict as ordered set
|
||||
next_tracers: Dict[BaseTracer, None] = {}
|
||||
for tracer in current_tracers:
|
||||
if tracer in visited_tracers:
|
||||
continue
|
||||
current_ir_node = tracer.traced_computation
|
||||
graph.add_node(current_ir_node)
|
||||
|
||||
@@ -124,4 +126,12 @@ def create_graph_from_output_tracers(
|
||||
|
||||
custom_assert(is_directed_acyclic_graph(graph))
|
||||
|
||||
# Check each edge is unique
|
||||
unique_edges = set(
|
||||
(pred, succ, tuple((k, v) for k, v in edge_data.items()))
|
||||
for pred, succ, edge_data in graph.edges(data=True)
|
||||
)
|
||||
number_of_edges = len(graph.edges)
|
||||
assert_true(len(unique_edges) == number_of_edges)
|
||||
|
||||
return graph
|
||||
|
||||
@@ -38,6 +38,20 @@ def small_fused_table(x):
|
||||
return (10 * (numpy.cos(x + 1) + 1)).astype(numpy.uint32)
|
||||
|
||||
|
||||
def complicated_topology(x):
|
||||
"""Mix x in an intricated way."""
|
||||
intermediate = x
|
||||
x_p_1 = intermediate + 1
|
||||
x_p_2 = intermediate + 2
|
||||
x_p_3 = x_p_1 + x_p_2
|
||||
return (
|
||||
x_p_3.astype(numpy.int32),
|
||||
x_p_2.astype(numpy.int32),
|
||||
(x_p_2 + 3).astype(numpy.int32),
|
||||
x_p_3.astype(numpy.int32) + 67,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,input_ranges,list_of_arg_names",
|
||||
[
|
||||
@@ -55,6 +69,7 @@ def small_fused_table(x):
|
||||
["x", "y"],
|
||||
marks=pytest.mark.xfail(strict=True, raises=ValueError),
|
||||
),
|
||||
pytest.param(complicated_topology, ((0, 10),), ["x"]),
|
||||
],
|
||||
)
|
||||
def test_compile_function_multiple_outputs(function, input_ranges, list_of_arg_names):
|
||||
|
||||
Reference in New Issue
Block a user