diff --git a/tests/hnumpy/test_compile.py b/tests/hnumpy/test_compile.py index 67f93dc3a..939d3f42c 100644 --- a/tests/hnumpy/test_compile.py +++ b/tests/hnumpy/test_compile.py @@ -98,6 +98,38 @@ def test_compile_and_run_function_multiple_outputs(function, input_ranges, list_ compiler_engine.run(*args) +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(lambda x: x + 64, ((0, 10),), ["x"]), + pytest.param(lambda x: x * 3, ((0, 40),), ["x"]), + pytest.param(lambda x: 120 - x, ((40, 80),), ["x"]), + pytest.param(lambda x, y: x + y + 64, ((0, 20), (0, 20)), ["x", "y"]), + pytest.param(lambda x, y: 100 - y + x, ((0, 20), (0, 20)), ["x", "y"]), + pytest.param(lambda x, y: 50 - y * 2 + x, ((0, 20), (0, 20)), ["x", "y"]), + ], +) +def test_compile_and_run_correctness(function, input_ranges, list_of_arg_names): + """Test correctness of results when running a compiled function""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedValue(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + ) + + args = [random.randint(low, high) for (low, high) in input_ranges] + assert compiler_engine.run(*args) == function(*args) + + def test_compile_function_with_direct_tlu(): """Test compile_numpy_function for a program with direct table lookup"""