diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index b2f1aa847..ce02e0841 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -123,6 +123,28 @@ def mix_x_and_y_and_call_binary_f_two_avoid_0_input(func, c, x, y): return z +def check_is_good_execution(compiler_engine, function, args): + """Run several times the check compiler_engine.run(*args) == function(*args). If always wrong, + return an error. One can set the expected probability of success of one execution and the + number of tests, to finetune the probability of bad luck, ie that we run several times the + check and always have a wrong result.""" + expected_probability_of_success = 0.95 + nb_tries = 5 + expected_bad_luck = (1 - expected_probability_of_success) ** nb_tries + + for i in range(1, nb_tries + 1): + if compiler_engine.run(*args) == function(*args): + # Good computation after i tries + print(f"Good computation after {i} tries") + return + + # Bad computation after nb_tries + raise AssertionError( + f"bad computation after {nb_tries} tries, which was supposed to happen with a " + f"probability of {expected_bad_luck}" + ) + + def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input_ranges): """Test correctness of results when running a compiled function""" @@ -145,11 +167,7 @@ def subtest_compile_and_run_unary_ufunc_correctness(ufunc, upper_function, input args = [random.randint(low, high) for (low, high) in input_ranges] - # TODO: fix the check - # assert compiler_engine.run(*args) == function(*args) - - if compiler_engine.run(*args) != function(*args): - print("Warning, bad computation") + check_is_good_execution(compiler_engine, function, args) def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, input_ranges): @@ -174,11 +192,7 @@ def subtest_compile_and_run_binary_ufunc_correctness(ufunc, upper_function, c, i args = [random.randint(low, high) for (low, high) in input_ranges] - # TODO: fix the check - # assert compiler_engine.run(*args) == function(*args) - - if compiler_engine.run(*args) != function(*args): - print("Warning, bad computation") + check_is_good_execution(compiler_engine, function, args) @pytest.mark.parametrize(