diff --git a/tests/numpy/test_compile.py b/tests/numpy/test_compile.py index 8ea1b407b..67c82528a 100644 --- a/tests/numpy/test_compile.py +++ b/tests/numpy/test_compile.py @@ -363,6 +363,11 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso """Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC.""" sotc = partial(scalar_or_tensor_constructor, return_scalar=tensor_shape == ()) + run_multi_tlu_test = False + if tensor_shape != (): + run_multi_tlu_test = True + tensor_for_multi_tlu = numpy.arange(numpy.prod(tensor_shape)).reshape(tensor_shape) + tensor_for_multi_tlu_small_values = tensor_for_multi_tlu // 2 if ufunc in [numpy.power, numpy.float_power]: # Need small constants to keep results really small @@ -382,6 +387,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso tensor_shape, default_compilation_configuration, ) + if run_multi_tlu_test: + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y), + tensor_for_multi_tlu_small_values, + ((0, 4), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), + tensor_for_multi_tlu_small_values, + ((0, 4), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]: subtest_compile_and_run_binary_ufunc_correctness( ufunc, @@ -391,6 +413,15 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso tensor_shape, default_compilation_configuration, ) + if run_multi_tlu_test: + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), + tensor_for_multi_tlu, + ((1, 5), (1, 5)), + tensor_shape, + default_compilation_configuration, + ) elif ufunc in [numpy.lcm, numpy.left_shift]: # Need small constants to keep results sufficiently small subtest_compile_and_run_binary_ufunc_correctness( @@ -409,15 +440,42 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso tensor_shape, default_compilation_configuration, ) - elif ufunc in [numpy.ldexp]: - # Need small constants to keep results sufficiently small - # TODO: https://github.com/zama-ai/concretefhe-internal/issues/665 - # ldexp requires multi TLU for the tensor case for now - if tensor_shape == (): + if run_multi_tlu_test: + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y), + tensor_for_multi_tlu + if ufunc != numpy.left_shift + else tensor_for_multi_tlu_small_values, + ((0, 5), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) subtest_compile_and_run_binary_ufunc_correctness( ufunc, lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), - sotc(2), + tensor_for_multi_tlu + if ufunc != numpy.left_shift + else tensor_for_multi_tlu_small_values, + ((0, 5), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) + elif ufunc in [numpy.ldexp]: + # Need small constants to keep results sufficiently small + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), + sotc(2), + ((0, 5), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) + if run_multi_tlu_test: + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), + tensor_for_multi_tlu // 2, ((0, 5), (0, 5)), tensor_shape, default_compilation_configuration, @@ -440,6 +498,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso tensor_shape, default_compilation_configuration, ) + if run_multi_tlu_test: + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y), + tensor_for_multi_tlu, + ((0, 5), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) + subtest_compile_and_run_binary_ufunc_correctness( + ufunc, + lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y), + tensor_for_multi_tlu, + ((0, 5), (0, 5)), + tensor_shape, + default_compilation_configuration, + ) @pytest.mark.parametrize(