mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
test: add tests for bivariate functions with tensor and multi TLU
closes #558 closes #567 closes #665
This commit is contained in:
@@ -363,6 +363,11 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
|
||||
"""Test biary functions which are in tracing.NPTracer.LIST_OF_SUPPORTED_UFUNC."""
|
||||
|
||||
sotc = partial(scalar_or_tensor_constructor, return_scalar=tensor_shape == ())
|
||||
run_multi_tlu_test = False
|
||||
if tensor_shape != ():
|
||||
run_multi_tlu_test = True
|
||||
tensor_for_multi_tlu = numpy.arange(numpy.prod(tensor_shape)).reshape(tensor_shape)
|
||||
tensor_for_multi_tlu_small_values = tensor_for_multi_tlu // 2
|
||||
|
||||
if ufunc in [numpy.power, numpy.float_power]:
|
||||
# Need small constants to keep results really small
|
||||
@@ -382,6 +387,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
if run_multi_tlu_test:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu_small_values,
|
||||
((0, 4), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu_small_values,
|
||||
((0, 4), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.floor_divide, numpy.fmod, numpy.remainder, numpy.true_divide]:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
@@ -391,6 +413,15 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
if run_multi_tlu_test:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu,
|
||||
((1, 5), (1, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.lcm, numpy.left_shift]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
@@ -409,15 +440,42 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.ldexp]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
# TODO: https://github.com/zama-ai/concretefhe-internal/issues/665
|
||||
# ldexp requires multi TLU for the tensor case for now
|
||||
if tensor_shape == ():
|
||||
if run_multi_tlu_test:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu
|
||||
if ufunc != numpy.left_shift
|
||||
else tensor_for_multi_tlu_small_values,
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
sotc(2),
|
||||
tensor_for_multi_tlu
|
||||
if ufunc != numpy.left_shift
|
||||
else tensor_for_multi_tlu_small_values,
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
elif ufunc in [numpy.ldexp]:
|
||||
# Need small constants to keep results sufficiently small
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
sotc(2),
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
if run_multi_tlu_test:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu // 2,
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
@@ -440,6 +498,23 @@ def test_binary_ufunc_operations(ufunc, default_compilation_configuration, tenso
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
if run_multi_tlu_test:
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_one(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu,
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
subtest_compile_and_run_binary_ufunc_correctness(
|
||||
ufunc,
|
||||
lambda func, c, x, y: mix_x_and_y_and_call_binary_f_two(func, sotc, c, x, y),
|
||||
tensor_for_multi_tlu,
|
||||
((0, 5), (0, 5)),
|
||||
tensor_shape,
|
||||
default_compilation_configuration,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user