# Mirror of https://github.com/zama-ai/tfhe-rs.git
# (synced 2026-01-06 21:34:05 -05:00; 236 lines, 7.7 KiB, Python)
"""
Script that generates a cargo-nextest filter as an output.

The string result can be directly injected into a nextest command.
"""
|
|
|
|
import argparse
|
|
|
|
# Command-line interface of the script. Abbreviated option names are rejected
# (allow_abbrev=False) and --layer is the only mandatory option.
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
    "--layer",
    dest="layer",
    choices=["integer", "shortint"],
    required=True,
    help="tfhe-rs layer to use",
)
parser.add_argument(
    "--backend",
    dest="backend",
    choices=["cpu", "gpu"],
    default="cpu",
    help="tfhe-rs backend to use",
)

# Boolean switches as (option string, destination attribute, help text).
# All of them default to False and are registered in the order listed here.
_SWITCHES = (
    ("--fast-tests", "fast_tests", "Run only a small subset of test suite"),
    ("--long-tests", "long_tests", "Run only the long tests suite"),
    ("--nightly-tests", "nightly_tests", "Run only a subset of test suite"),
    ("--big-instance", "big_instance", "Backend is using a large instance"),
    ("--multi-bit", "multi_bit", "Include tests running on multi-bit parameters set"),
    ("--signed-only", "signed_only", "Include only signed integer tests"),
    ("--unsigned-only", "unsigned_only", "Include only unsigned integer tests"),
    (
        "--no-big-params",
        "no_big_params",
        "Do not run tests with big parameters set (e.g. 4bits message with 4 bits carry)",
    ),
    (
        "--no-big-params-gpu",
        "no_big_params_gpu",
        "Do not run tests with big parameters set (e.g. 3bits message with 3 bits carry) for GPU",
    ),
)
for _option, _dest, _help_text in _SWITCHES:
    parser.add_argument(_option, dest=_dest, action="store_true", help=_help_text)
|
|
|
|
# nextest name matchers excluded from integer runs when --fast-tests is not set.
# Each entry is wrapped as "not test(<entry>)" by filter_integer_tests.
EXCLUDED_INTEGER_TESTS = [
    # skip smart_div, smart_rem which are already covered by the smart_div_rem test
    "/.*integer_smart_div_param/",
    "/.*integer_smart_rem_param/",
    # skip default_div, default_rem which are covered by default_div_rem
    "/.*integer_default_div_param/",
    "/.*integer_default_rem_param/",
    # block PBS are too slow for high params
    "/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs_gaussian_2m64$/",
    # mul_crt_4_4 is extremely flaky (~80% failure)
    "~mul_crt_param_message_4_carry_4_ks_pbs_gaussian_2m64",
    # test_wopbs_bivariate_crt_wopbs_param_message generates tables that are too big at the moment
    "/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs_gaussian_2m64$/",
    # these two are too slow
    "/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
    "/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
    # long-run tests are selected separately (see --long-tests)
    "/.*::tests_long_run::.*/",
]
|
|
|
|
# nextest name matchers excluded from integer runs when --fast-tests is set
# (used in place of EXCLUDED_INTEGER_TESTS).
EXCLUDED_INTEGER_FAST_TESTS = [
    # skip default_div, default_rem which are covered by default_div_rem
    "/.*integer_default_div_param/",
    "/.*integer_default_rem_param/",
    "/.*_param_message_[14]_carry_[14]_ks_pbs_gaussian_2m64$/",
    "/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs_gaussian_2m64$/",
]

# Additional exclusions applied to integer runs when --no-big-params is set.
EXCLUDED_BIG_PARAMETERS = [
    "/.*_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
]

# Additional exclusions applied to integer runs when --no-big-params-gpu is set.
EXCLUDED_BIG_PARAMETERS_GPU = [
    "/.*_message_3_carry_3.*$/",
    "/.*_group_3_message_2_carry_2.*$/",
]
|
|
|
|
|
|
def filter_integer_tests(input_args):
    """
    Build the cargo-nextest filter expression for the integer layer.

    `input_args` is the namespace produced by `parser.parse_args()`. When
    `long_tests` is set, only the long-run test module of the selected backend
    is kept and every other option is ignored. Otherwise the expression is
    assembled from `backend`, `multi_bit`, `signed_only`, `unsigned_only`,
    `no_big_params`, `no_big_params_gpu`, `fast_tests` and `nightly_tests`.

    Returns the individual clauses joined with " and ".
    """
    (multi_bit_filter, group_filter) = (
        ("_multi_bit", "_group_[0-9]") if input_args.multi_bit else ("", "")
    )
    backend_filter = ""
    if not input_args.long_tests:
        if input_args.backend == "gpu":
            backend_filter = "gpu::"
            if multi_bit_filter:
                # For now, GPU only has specific parameters set for multi-bit
                multi_bit_filter = "_gpu_multi_bit"

        filter_expression = [f"test(/^integer::{backend_filter}.*/)"]

        if input_args.multi_bit:
            filter_expression.append("test(~_multi_bit)")
        else:
            filter_expression.append("not test(~_multi_bit)")

        # Both flags may be given together; the two clauses are then simply and-ed.
        if input_args.signed_only:
            filter_expression.append("test(~_signed)")
        if input_args.unsigned_only:
            filter_expression.append("not test(~_signed)")

        if input_args.no_big_params:
            filter_expression.extend(
                f"not test({pattern})" for pattern in EXCLUDED_BIG_PARAMETERS
            )

        if input_args.no_big_params_gpu:
            filter_expression.extend(
                f"not test({pattern})" for pattern in EXCLUDED_BIG_PARAMETERS_GPU
            )

        if input_args.fast_tests and input_args.nightly_tests:
            filter_expression.append(
                f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_[2-3]_carry_[2-3]_.*/)"
            )
        elif input_args.fast_tests:
            # Test only fast default operations with only one set of parameters
            filter_expression.append(
                f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_2_carry_2_.*/)"
            )
        elif input_args.nightly_tests:
            # Test only fast default operations with only one set of parameters
            # This subset would run slower than fast_tests hence the use of nightly_tests
            filter_expression.append(
                f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_3_carry_3_.*/)"
            )

        excluded_tests = (
            EXCLUDED_INTEGER_FAST_TESTS if input_args.fast_tests else EXCLUDED_INTEGER_TESTS
        )
        filter_expression.extend(f"not test({pattern})" for pattern in excluded_tests)

    else:
        if input_args.backend == "gpu":
            filter_expression = ["test(/^integer::gpu::server_key::radix::tests_long_run.*/)"]
        elif input_args.backend == "cpu":
            filter_expression = ["test(/^integer::server_key::radix_parallel::tests_long_run.*/)"]

    # Do not run noise check tests by default as they can be very slow
    # they will be run e.g. nightly or on demand
    # (list.append returns None, so its result is deliberately not bound to a name)
    filter_expression.append(
        "not test(/^integer::gpu::server_key::radix::tests_noise_distribution::.*::test_gpu_noise_check.*/)"
    )

    return " and ".join(filter_expression)
|
|
|
|
|
|
def filter_shortint_tests(input_args):
    """
    Build the cargo-nextest filter expression for the shortint layer.

    Reads `multi_bit`, `fast_tests` and `big_instance` from `input_args`.
    The result has the form "(<clause> or <clause> ...) and not test(...)",
    with one opt-in clause per selected (message, carry) pair plus two fixed
    clauses, and the noise check tests excluded.
    """
    if input_args.multi_bit:
        param_suffix = "_multi_bit" + "_group_[0-9]"
    else:
        param_suffix = ""

    if input_args.fast_tests:
        # (2, 1), (2, 2), (2, 3)
        pairs = [(2, carry) for carry in range(1, 4)]
    else:
        # (1, 1) .. (1, 6), then (2, 1) .. (2, 3) and (3, 1) .. (3, 3)
        pairs = [(1, carry) for carry in range(1, 7)]
        pairs += [(msg, carry) for msg in (2, 3) for carry in range(1, 4)]
        if input_args.big_instance:
            pairs.append((4, 4))

    selectors = [
        f"test(/^shortint::.*_param{param_suffix}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)"
        for (msg, carry) in pairs
    ]
    selectors += [
        "test(/^shortint::.*meta_param_cpu_2_2/)",
        "test(/^shortint::.*_ci_run_filter/)",
    ]
    opt_in = " or ".join(selectors)

    # Do not run noise check tests by default as they can be very slow
    # they will be run e.g. nightly or on demand
    return f"({opt_in}) and not test(/^shortint::.*test_noise_check/)"
|
|
|
|
|
|
if __name__ == "__main__":
    args = parser.parse_args()

    # One filter builder per supported --layer value. A layer without a
    # builder yields an empty expression.
    builders = {
        "integer": filter_integer_tests,
        "shortint": filter_shortint_tests,
    }
    build_filter = builders.get(args.layer)
    expression = build_filter(args) if build_filter is not None else ""

    # The filter is written to stdout so it can be injected into a nextest command.
    print(expression)
|