diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 52ac49a00..b8c5dd355 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -1828,16 +1828,6 @@ def test_wrong_inputs(default_compilation_configuration): [ pytest.param(lambda x: (x + (-27)) + 32, ((0, 10),), ["x"]), pytest.param(lambda x: ((-3) * x) + (100 - (x + 1)), ((0, 10),), ["x"]), - pytest.param( - lambda x: (20 + 10 * numpy.tanh(50 * (numpy.cos(x + 33.0)))).astype(numpy.uint32), - ((0, 31),), - ["x"], - ), - pytest.param( - lambda x: (20 * (numpy.cos(x + 33.0)) + 30).astype(numpy.uint32), - ((0, 31),), - ["x"], - ), pytest.param( lambda x, y: (-1) * x + (-2) * y + 40, ( @@ -1869,6 +1859,42 @@ def test_compile_and_run_correctness_with_negative_values( assert compiler_engine.run(*args) == function(*args) +@pytest.mark.parametrize( + "function,input_ranges,list_of_arg_names", + [ + pytest.param( + lambda x: (20 + 10 * numpy.tanh(50 * (numpy.cos(x + 33.0)))).astype(numpy.uint32), + ((0, 31),), + ["x"], + ), + pytest.param( + lambda x: (20 * (numpy.cos(x + 33.0)) + 30).astype(numpy.uint32), + ((0, 31),), + ["x"], + ), + ], +) +def test_compile_and_run_correctness_with_negative_values_and_pbs( + function, input_ranges, list_of_arg_names, default_compilation_configuration +): + """Test correctness of results when running a compiled function, which has some negative + intermediate values.""" + + function_parameters = { + arg_name: EncryptedScalar(Integer(64, False)) for arg_name in list_of_arg_names + } + + compiler_engine = compile_numpy_function( + function, + function_parameters, + data_gen(tuple(range(x[0], x[1] + 1) for x in input_ranges)), + default_compilation_configuration, + ) + + args = [random.randint(low, high) for (low, high) in input_ranges] + check_is_good_execution(compiler_engine, function, args, verbose=False) + + def check_equality_modulo(a, b, modulus): """Check that (a mod modulus) == (b mod modulus)""" return (a % modulus) == (b % modulus)