tests: result correctness of the compiled function

This commit is contained in:
youben11
2021-08-24 12:44:45 +01:00
committed by Ayoub Benaissa
parent 66d0c8dd62
commit 2585eb7ed8

View File

@@ -98,6 +98,38 @@ def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_
compiler_engine.run(*args)
@pytest.mark.parametrize(
"function,input_ranges,list_of_arg_names",
[
pytest.param(lambda x: x + 64, ((0, 10),), ["x"]),
pytest.param(lambda x: x * 3, ((0, 40),), ["x"]),
pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]),
pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]),
pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]),
],
)
def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names):
"""Test correctness of results when running a compiled function"""
def data_gen(args):
for prod in itertools.product(*args):
yield prod
function_parameters = {
arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names
}
compiler_engine = compile_numpy_function(
function,
function_parameters,
data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)),
)
args = [random.randint(low, high) for (low, high) in input_ranges]
assert compiler_engine.run(*args) == function(*args)
def test_compile_function_with_direct_tlu():
"""Test compile_numpy_function for a program with direct table lookup"""