mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(fhe_circuit): update type annotations for run method
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy
|
||||
from zamalang import CompilerEngine
|
||||
|
||||
from .debugging import draw_graph, format_operation_graph
|
||||
@@ -43,14 +44,14 @@ class FHECircuit:
|
||||
|
||||
return draw_graph(self.opgraph, show, vertical, save_to)
|
||||
|
||||
def run(self, *args: List[Union[int, List[int]]]) -> int:
|
||||
def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]:
|
||||
"""Encrypt, evaluate, and decrypt the inputs on the circuit.
|
||||
|
||||
Args:
|
||||
*args (List[Union[int, List[int]]]): inputs to the circuit
|
||||
*args (List[Union[int, numpy.ndarray]]): inputs to the circuit
|
||||
|
||||
Returns:
|
||||
int: homomorphic evaluation result
|
||||
Union[int, numpy.ndarray]: homomorphic evaluation result
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -591,7 +591,7 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
|
||||
args = [numpy.random.randint(low, high + 1, size=(size,), dtype=numpy.uint8) for __ in range(2)]
|
||||
assert compiler_engine.run(*args) == function(*args)
|
||||
|
||||
|
||||
@@ -652,7 +652,7 @@ def test_compile_and_run_constant_dot_correctness(
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
args = (numpy.random.randint(low, high + 1, size=shape).tolist(),)
|
||||
args = (numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8),)
|
||||
assert left_circuit.run(*args) == left(*args)
|
||||
assert right_circuit.run(*args) == right(*args)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user