mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
committed by
Benoit Chevallier
parent
8fd0ae5c85
commit
4f6103d1d1
@@ -1,7 +1,7 @@
|
||||
"""Helper functions for tracing."""
|
||||
import collections
|
||||
from inspect import signature
|
||||
from typing import Callable, Dict, Iterable, OrderedDict, Set, Tuple, Type
|
||||
from typing import Callable, Dict, Iterable, OrderedDict, Set, Type
|
||||
|
||||
import networkx as nx
|
||||
from networkx.algorithms.dag import is_directed_acyclic_graph
|
||||
@@ -100,10 +100,12 @@ def create_graph_from_output_tracers(
|
||||
graph = nx.MultiDiGraph()
|
||||
|
||||
visited_tracers: Set[BaseTracer] = set()
|
||||
current_tracers = tuple(output_tracers)
|
||||
# use dict as ordered set
|
||||
current_tracers = {tracer: None for tracer in output_tracers}
|
||||
|
||||
while current_tracers:
|
||||
next_tracers: Tuple[BaseTracer, ...] = tuple()
|
||||
# use dict as ordered set
|
||||
next_tracers: Dict[BaseTracer, None] = dict()
|
||||
for tracer in current_tracers:
|
||||
current_ir_node = tracer.traced_computation
|
||||
graph.add_node(current_ir_node, content=current_ir_node)
|
||||
@@ -113,7 +115,7 @@ def create_graph_from_output_tracers(
|
||||
graph.add_node(input_ir_node, content=input_ir_node)
|
||||
graph.add_edge(input_ir_node, current_ir_node, input_idx=input_idx)
|
||||
if input_tracer not in visited_tracers:
|
||||
next_tracers += (input_tracer,)
|
||||
next_tracers.update({input_tracer: None})
|
||||
|
||||
visited_tracers.add(tracer)
|
||||
|
||||
|
||||
@@ -8,6 +8,30 @@ from hdk.common.debugging import draw_graph, get_printable_graph
|
||||
from hdk.hnumpy import tracing
|
||||
|
||||
|
||||
def issue_130_a(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = x + 1
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
def issue_130_b(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = x - 1
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
def issue_130_c(x, y):
|
||||
"""Test case derived from issue #130"""
|
||||
# pylint: disable=unused-argument
|
||||
intermediate = 1 - x
|
||||
return (intermediate, intermediate)
|
||||
# pylint: enable=unused-argument
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lambda_f,ref_graph_str",
|
||||
[
|
||||
@@ -62,6 +86,27 @@ from hdk.hnumpy import tracing
|
||||
lambda x, y: (x, x + 1),
|
||||
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%0, %2)",
|
||||
),
|
||||
(
|
||||
lambda x, y: (x + 1, x + 1),
|
||||
"\n%0 = x"
|
||||
"\n%1 = ConstantInput(1)"
|
||||
"\n%2 = ConstantInput(1)"
|
||||
"\n%3 = Add(0, 1)"
|
||||
"\n%4 = Add(0, 2)"
|
||||
"\nreturn(%3, %4)",
|
||||
),
|
||||
(
|
||||
issue_130_a,
|
||||
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Add(0, 1)\nreturn(%2, %2)",
|
||||
),
|
||||
(
|
||||
issue_130_b,
|
||||
"\n%0 = x\n%1 = ConstantInput(1)\n%2 = Sub(0, 1)\nreturn(%2, %2)",
|
||||
),
|
||||
(
|
||||
issue_130_c,
|
||||
"\n%0 = ConstantInput(1)\n%1 = x\n%2 = Sub(0, 1)\nreturn(%2, %2)",
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user