diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index ac59eca02..161e0d852 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1513,3 +1513,54 @@ def test_wrong_inputs(default_compilation_configuration): str(excinfo.value) == f"wrong type for inputs {dict_for_inputs}, " f"needs to be one of {list_of_possible_basevalue}" ) + + +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param(lambda x: (x + (-27)) + 32, ((0, 10),), ["x"]), + pytest.param(lambda x: ((-3) * x) + (100 - (x + 1)), ((0, 10),), ["x"]), + # FIXME: doesn't work for now, #885 + # pytest.param( + # lambda x: (50 * (numpy.cos(x + 33.0))).astype(numpy.uint32), + # ((0, 31),), + # ["x"], + # ), + pytest.param( + lambda x: (20 * (numpy.cos(x + 33.0)) + 30).astype(numpy.uint32), + ((0, 31),), + ["x"], + ), + pytest.param( + lambda x, y: (-1) * x + (-2) * y + 40, + ( + (0, 10), + (0, 10), + ), + ["x", "y"], + ), + ], +) +def test_compile_and_run_correctness_with_negative_values( + function, input_ranges, list_of_arg_names, default_compilation_configuration +): + """Test correctness of results when running a compiled function, which has some negative + intermediate values.""" + + def data_gen(args): + for prod in itertools.product(*args): + yield prod + + function_parameters = { + arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + default_compilation_configuration, + ) + + args = [random.randint(low, high) for (low, high) in input_ranges] + assert compiler_engine.run(*args) == function(*args)