From 548b755409d8baf3163397d0637884cfce892a09 Mon Sep 17 00:00:00 2001 From: Umut Date: Mon, 8 Nov 2021 18:04:09 +0300 Subject: [PATCH] fix(fhe_circuit): update type annotations for run method --- concrete/common/fhe_circuit.py | 7 ++++--- tests/numpy/test_compile.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/concrete/common/fhe_circuit.py b/concrete/common/fhe_circuit.py index 40f2a5e60..fbfd20353 100644 --- a/concrete/common/fhe_circuit.py +++ b/concrete/common/fhe_circuit.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import List, Optional, Union +import numpy from zamalang import CompilerEngine from .debugging import draw_graph, format_operation_graph @@ -43,14 +44,14 @@ class FHECircuit: return draw_graph(self.opgraph, show, vertical, save_to) - def run(self, *args: List[Union[int, List[int]]]) -> int: + def run(self, *args: List[Union[int, numpy.ndarray]]) -> Union[int, numpy.ndarray]: """Encrypt, evaluate, and decrypt the inputs on the circuit. Args: - *args (List[Union[int, List[int]]]): inputs to the circuit + *args (List[Union[int, numpy.ndarray]]): inputs to the circuit Returns: - int: homomorphic evaluation result + Union[int, numpy.ndarray]: homomorphic evaluation result """ diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 1c58d1e54..4505c753a 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -591,7 +591,7 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_ default_compilation_configuration, ) - args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)] + args = [numpy.random.randint(low, high + 1, size=(size,), dtype=numpy.uint8) for __ in range(2)] assert compiler_engine.run(*args) == function(*args) @@ -652,7 +652,7 @@ def test_compile_and_run_constant_dot_correctness( default_compilation_configuration, ) - args = (numpy.random.randint(low, high + 1, size=shape).tolist(),) + args = (numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8),) assert left_circuit.run(*args) == left(*args) assert right_circuit.run(*args) == right(*args)