fix(fhe_circuit): update type annotations for run method

This commit is contained in:
Umut
2021-11-08 18:04:09 +03:00
committed by Benoit Chevallier
parent f417246ea3
commit 548b755409
2 changed files with 6 additions and 5 deletions

View File

@@ -3,6 +3,7 @@
from pathlib import Path
from typing import List, Optional, Union
import numpy
from zamalang import CompilerEngine
from .debugging import draw_graph, format_operation_graph
@@ -43,14 +44,14 @@ class FHECircuit:
return draw_graph(self.opgraph, show, vertical, save_to)
def run(self, *args: List[Union[int, List[int]]]) -> int:
def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]:
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
Args:
*args (List[Union[int, List[int]]]): inputs to the circuit
*args (List[Union[int, numpy.ndarray]]): inputs to the circuit
Returns:
int: homomorphic evaluation result
Union[int, numpy.ndarray]: homomorphic evaluation result
"""

View File

@@ -591,7 +591,7 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
default_compilation_configuration,
)
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
args = [numpy.random.randint(low, high + 1, size=(size,), dtype=numpy.uint8) for __ in range(2)]
assert compiler_engine.run(*args) == function(*args)
@@ -652,7 +652,7 @@ def test_compile_and_run_constant_dot_correctness(
default_compilation_configuration,
)
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
args = (numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8),)
assert left_circuit.run(*args) == left(*args)
assert right_circuit.run(*args) == right(*args)