mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: allow unused arguments to be present to simplify development
This commit is contained in:
@@ -54,6 +54,8 @@ class Tracer:
|
||||
computation graph corresponding to `function`
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-statements
|
||||
|
||||
signature = inspect.signature(function)
|
||||
|
||||
missing_args = list(signature.parameters)
|
||||
@@ -102,7 +104,10 @@ class Tracer:
|
||||
|
||||
output_tracers = tuple(sanitized_tracers)
|
||||
|
||||
def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph:
|
||||
def create_graph_from_output_tracers(
|
||||
arguments: Dict[str, Tracer],
|
||||
output_tracers: Tuple[Tracer, ...],
|
||||
) -> nx.MultiDiGraph:
|
||||
graph = nx.MultiDiGraph()
|
||||
|
||||
visited_tracers: Set[Tracer] = set()
|
||||
@@ -140,9 +145,12 @@ class Tracer:
|
||||
}
|
||||
assert_that(len(unique_edges) == len(graph.edges))
|
||||
|
||||
for tracer in arguments.values():
|
||||
graph.add_node(tracer.computation)
|
||||
|
||||
return graph
|
||||
|
||||
graph = create_graph_from_output_tracers(output_tracers)
|
||||
graph = create_graph_from_output_tracers(arguments, output_tracers)
|
||||
input_nodes = {
|
||||
input_indices[node]: node
|
||||
for node in graph.nodes()
|
||||
@@ -154,6 +162,8 @@ class Tracer:
|
||||
|
||||
return Graph(graph, input_nodes, output_nodes, is_direct)
|
||||
|
||||
# pylint: enable=too-many-statements
|
||||
|
||||
def __init__(self, computation: Node, input_tracers: List["Tracer"]):
|
||||
self.computation = computation
|
||||
self.input_tracers = input_tracers
|
||||
|
||||
@@ -334,3 +334,27 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
|
||||
]
|
||||
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]
|
||||
|
||||
|
||||
def test_circuit_run_with_unused_arg(helpers):
|
||||
"""
|
||||
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
|
||||
"""
|
||||
|
||||
configuration = helpers.configuration()
|
||||
|
||||
@compiler({"x": "encrypted", "y": "encrypted"})
|
||||
def f(x, y): # pylint: disable=unused-argument
|
||||
return x + 10
|
||||
|
||||
inputset = [
|
||||
(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)
|
||||
]
|
||||
circuit = f.compile(inputset, configuration)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected 2 inputs but got 1"):
|
||||
circuit.encrypt_run_decrypt(10)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(10, 0) == 20
|
||||
assert circuit.encrypt_run_decrypt(10, 10) == 20
|
||||
assert circuit.encrypt_run_decrypt(10, 20) == 20
|
||||
|
||||
Reference in New Issue
Block a user