test: dot compilation and execution

This commit is contained in:
youben11
2021-09-10 11:49:32 +01:00
committed by Ayoub Benaissa
parent d793bffc52
commit 845558d3a5

View File

@@ -9,7 +9,7 @@ from concrete.common.compilation import CompilationConfiguration
from concrete.common.data_types.integers import Integer
from concrete.common.debugging import draw_graph, get_printable_graph
from concrete.common.extensions.table import LookupTable
from concrete.common.values import EncryptedScalar, EncryptedTensor
from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor
from concrete.numpy.compile import (
compile_numpy_function,
compile_numpy_function_into_op_graph,
@@ -139,6 +139,54 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
assert compiler_engine.run(*args) == function(*args)
@pytest.mark.parametrize(
"size, input_range",
[
pytest.param(
1,
(0, 8),
),
pytest.param(
4,
(0, 8),
),
pytest.param(
8,
(0, 8),
),
pytest.param(
16,
(0, 4),
),
],
)
def test_compile_and_run_dot_correctness(size, input_range):
"""Test correctness of results when running a compiled function"""
def data_gen(input_range, size):
for i in range(*input_range, size):
vec = list(range(i, min(i + size, input_range[1])))
yield vec, vec[::-1]
function_parameters = {
"x": EncryptedTensor(Integer(64, False), (size,)),
"y": ClearTensor(Integer(64, False), (size,)),
}
def function(x, y):
return numpy.dot(x, y)
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(input_range, size),
)
low, high = input_range
args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)]
assert compiler_engine.run(*args) == function(*args)
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function_into_op_graph for a program with direct table lookup"""