diff --git a/concrete/numpy/tracing/tracer.py b/concrete/numpy/tracing/tracer.py index 16fc918c1..8a5fe56c7 100644 --- a/concrete/numpy/tracing/tracer.py +++ b/concrete/numpy/tracing/tracer.py @@ -54,6 +54,8 @@ class Tracer: computation graph corresponding to `function` """ + # pylint: disable=too-many-statements + signature = inspect.signature(function) missing_args = list(signature.parameters) @@ -102,7 +104,10 @@ class Tracer: output_tracers = tuple(sanitized_tracers) - def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph: + def create_graph_from_output_tracers( + arguments: Dict[str, Tracer], + output_tracers: Tuple[Tracer, ...], + ) -> nx.MultiDiGraph: graph = nx.MultiDiGraph() visited_tracers: Set[Tracer] = set() @@ -140,9 +145,12 @@ class Tracer: } assert_that(len(unique_edges) == len(graph.edges)) + for tracer in arguments.values(): + graph.add_node(tracer.computation) + return graph - graph = create_graph_from_output_tracers(output_tracers) + graph = create_graph_from_output_tracers(arguments, output_tracers) input_nodes = { input_indices[node]: node for node in graph.nodes() @@ -154,6 +162,8 @@ class Tracer: return Graph(graph, input_nodes, output_nodes, is_direct) + # pylint: enable=too-many-statements + def __init__(self, computation: Node, input_tracers: List["Tracer"]): self.computation = computation self.input_tracers = input_tracers diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index f93917dcd..87ece5235 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -334,3 +334,27 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance), ] assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1] + + +def test_circuit_run_with_unused_arg(helpers): + """ + Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments. + """ + + configuration = helpers.configuration() + + @compiler({"x": "encrypted", "y": "encrypted"}) + def f(x, y): # pylint: disable=unused-argument + return x + 10 + + inputset = [ + (np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100) + ] + circuit = f.compile(inputset, configuration) + + with pytest.raises(ValueError, match="Expected 2 inputs but got 1"): + circuit.encrypt_run_decrypt(10) + + assert circuit.encrypt_run_decrypt(10, 0) == 20 + assert circuit.encrypt_run_decrypt(10, 10) == 20 + assert circuit.encrypt_run_decrypt(10, 20) == 20