feat: allow unused arguments to be present to simplify development

This commit is contained in:
Umut
2023-01-05 12:57:39 +01:00
parent 01ef6ce206
commit 0c470852c3
2 changed files with 36 additions and 2 deletions

View File

@@ -54,6 +54,8 @@ class Tracer:
computation graph corresponding to `function`
"""
# pylint: disable=too-many-statements
signature = inspect.signature(function)
missing_args = list(signature.parameters)
@@ -102,7 +104,10 @@ class Tracer:
output_tracers = tuple(sanitized_tracers)
def create_graph_from_output_tracers(output_tracers: Tuple[Tracer, ...]) -> nx.MultiDiGraph:
def create_graph_from_output_tracers(
arguments: Dict[str, Tracer],
output_tracers: Tuple[Tracer, ...],
) -> nx.MultiDiGraph:
graph = nx.MultiDiGraph()
visited_tracers: Set[Tracer] = set()
@@ -140,9 +145,12 @@ class Tracer:
}
assert_that(len(unique_edges) == len(graph.edges))
for tracer in arguments.values():
graph.add_node(tracer.computation)
return graph
graph = create_graph_from_output_tracers(output_tracers)
graph = create_graph_from_output_tracers(arguments, output_tracers)
input_nodes = {
input_indices[node]: node
for node in graph.nodes()
@@ -154,6 +162,8 @@ class Tracer:
return Graph(graph, input_nodes, output_nodes, is_direct)
# pylint: enable=too-many-statements
def __init__(self, computation: Node, input_tracers: List["Tracer"]):
self.computation = computation
self.input_tracers = input_tracers

View File

@@ -334,3 +334,27 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
]
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]
def test_circuit_run_with_unused_arg(helpers):
"""
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
"""
configuration = helpers.configuration()
@compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y): # pylint: disable=unused-argument
return x + 10
inputset = [
(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)
]
circuit = f.compile(inputset, configuration)
with pytest.raises(ValueError, match="Expected 2 inputs but got 1"):
circuit.encrypt_run_decrypt(10)
assert circuit.encrypt_run_decrypt(10, 0) == 20
assert circuit.encrypt_run_decrypt(10, 10) == 20
assert circuit.encrypt_run_decrypt(10, 20) == 20