From 11c38efa1eb8a4ec4796d1ac04b721332e034e15 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 24 Dec 2021 17:18:50 +0100 Subject: [PATCH] test: add dot tests with non program inputs args - also fix some inputsets which had wrongly shaped inputs --- tests/numpy/test_compile.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 5bf5fb5b5..ca0387c65 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1157,7 +1157,12 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_ ), ] for _ in range(8): - inputset.append((numpy.random.randint(low, high + 1), numpy.random.randint(low, high + 1))) + inputset.append( + ( + numpy.random.randint(low, high + 1, size=shape), + numpy.random.randint(low, high + 1, size=shape), + ) + ) function_parameters = { "x": EncryptedTensor(Integer(64, False), shape), @@ -1167,15 +1172,21 @@ def test_compile_and_run_dot_correctness(size, input_range, default_compilation_ def function(x, y): return numpy.dot(x, y) - compiler_engine = compile_numpy_function( - function, - function_parameters, - inputset, - default_compilation_configuration, - ) + def function_indirect_args(x, y): + return numpy.dot(x.flatten(), y.flatten()) - args = [numpy.random.randint(low, high + 1, size=(size,), dtype=numpy.uint8) for __ in range(2)] - assert compiler_engine.run(*args) == function(*args) + for func_to_compile in [function, function_indirect_args]: + compiler_engine = compile_numpy_function( + func_to_compile, + function_parameters, + inputset, + default_compilation_configuration, + ) + + args = [ + numpy.random.randint(low, high + 1, size=shape, dtype=numpy.uint8) for __ in range(2) + ] + assert compiler_engine.run(*args) == func_to_compile(*args) @pytest.mark.parametrize(