diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index d9ac7b3fd..73f6e2388 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -9,7 +9,7 @@ from concrete.common.compilation import CompilationConfiguration from concrete.common.data_types.integers import Integer from concrete.common.debugging import draw_graph, get_printable_graph from concrete.common.extensions.table import LookupTable -from concrete.common.values import EncryptedScalar, EncryptedTensor +from concrete.common.values import ClearTensor, EncryptedScalar, EncryptedTensor from concrete.numpy.compile import ( compile_numpy_function, compile_numpy_function_into_op_graph, @@ -139,6 +139,54 @@ def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names): assert compiler_engine.run(*args) == function(*args) +@pytest.mark.parametrize( + "size, input_range", + [ + pytest.param( + 1, + (0, 8), + ), + pytest.param( + 4, + (0, 8), + ), + pytest.param( + 8, + (0, 8), + ), + pytest.param( + 16, + (0, 4), + ), + ], +) +def test_compile_and_run_dot_correctness(size, input_range): + """Test correctness of results when running a compiled function""" + + def data_gen(input_range, size): + for i in range(*input_range, size): + vec = list(range(i, min(i + size, input_range[1]))) + yield vec, vec[::-1] + + function_parameters = { + "x": EncryptedTensor(Integer(64, False), (size,)), + "y": ClearTensor(Integer(64, False), (size,)), + } + + def function(x, y): + return numpy.dot(x, y) + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(input_range, size), + ) + + low, high = input_range + args = [[random.randint(low, high) for _ in range(size)] for __ in range(2)] + assert compiler_engine.run(*args) == function(*args) + + def test_compile_function_with_direct_tlu(): """Test compile_numpy_function_into_op_graph for a program with direct table lookup"""