test: add tests for bivariate functions with tensor and multi TLU

closes #558
closes #567
closes #665
This commit is contained in:
Arthur Meyre
2021-11-17 18:52:24 +01:00
parent ff03bc2220
commit f6aadd37e9

View File

@@ -363,6 +363,11 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
sotc = partial(scalar_or_tensor_constructor, return_scalar=tensor_shape == ())
run_multi_tlu_test = False
if tensor_shape != ():
run_multi_tlu_test = True
tensor_for_multi_tlu = numpy.arange(numpy.prod(tensor_shape)).reshape(tensor_shape)
tensor_for_multi_tlu_small_values = tensor_for_multi_tlu // 2
if ufunc in [numpy.power, numpy.float_power]:
# Need small constants to keep results really small
@@ -382,6 +387,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
tensor_shape,
default_compilation_configuration,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
tensor_for_multi_tlu_small_values,
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
tensor_for_multi_tlu_small_values,
((0, 4), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
@@ -391,6 +413,15 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
tensor_shape,
default_compilation_configuration,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
tensor_for_multi_tlu,
((1, 5), (1, 5)),
tensor_shape,
default_compilation_configuration,
)
elif ufunc in [numpy.lcm, numpy.left_shift]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(
@@ -409,15 +440,42 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
tensor_shape,
default_compilation_configuration,
)
elif ufunc in [numpy.ldexp]:
# Need small constants to keep results sufficiently small
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/665
# ldexp requires multi TLU for the tensor case for now
if tensor_shape == ():
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
tensor_for_multi_tlu
if ufunc != numpy.left_shift
else tensor_for_multi_tlu_small_values,
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
sotc(2),
tensor_for_multi_tlu
if ufunc != numpy.left_shift
else tensor_for_multi_tlu_small_values,
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
elif ufunc in [numpy.ldexp]:
# Need small constants to keep results sufficiently small
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
sotc(2),
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
tensor_for_multi_tlu // 2,
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
@@ -440,6 +498,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
tensor_shape,
default_compilation_configuration,
)
if run_multi_tlu_test:
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
tensor_for_multi_tlu,
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
subtest_compile_and_run_binary_ufunc_correctness(
ufunc,
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
tensor_for_multi_tlu,
((0, 5), (0, 5)),
tensor_shape,
default_compilation_configuration,
)
@pytest.mark.parametrize(