Files
tfhe-rs/scripts/test_filtering.py
2025-12-12 09:41:11 +01:00

236 lines
7.7 KiB
Python

"""
Script that generates a cargo-nextest filter as an output.
The string result can be directly injected into a nextest command.
"""
import argparse
parser = argparse.ArgumentParser(allow_abbrev=False)
parser.add_argument(
"--layer",
dest="layer",
choices=["integer", "shortint"],
required=True,
help="tfhe-rs layer to use",
)
parser.add_argument(
"--backend",
dest="backend",
choices=["cpu", "gpu"],
default="cpu",
help="tfhe-rs backend to use",
)
parser.add_argument(
"--fast-tests",
dest="fast_tests",
action="store_true",
help="Run only a small subset of test suite",
)
parser.add_argument(
"--long-tests",
dest="long_tests",
action="store_true",
help="Run only the long tests suite",
)
parser.add_argument(
"--nightly-tests",
dest="nightly_tests",
action="store_true",
help="Run only a subset of test suite",
)
parser.add_argument(
"--big-instance",
dest="big_instance",
action="store_true",
help="Backend is using a large instance",
)
parser.add_argument(
"--multi-bit",
dest="multi_bit",
action="store_true",
help="Include tests running on multi-bit parameters set",
)
parser.add_argument(
"--signed-only",
dest="signed_only",
action="store_true",
help="Include only signed integer tests",
)
parser.add_argument(
"--unsigned-only",
dest="unsigned_only",
action="store_true",
help="Include only unsigned integer tests",
)
parser.add_argument(
"--no-big-params",
dest="no_big_params",
action="store_true",
help="Do not run tests with big parameters set (e.g. 4bits message with 4 bits carry)",
)
parser.add_argument(
"--no-big-params-gpu",
dest="no_big_params_gpu",
action="store_true",
help="Do not run tests with big parameters set (e.g. 3bits message with 3 bits carry) for GPU",
)
# block PBS are too slow for high params
# mul_crt_4_4 is extremely flaky (~80% failure)
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
# test_integer_smart_mul_param_message_4_carry_4_ks_pbs_gaussian_2m64 is too slow
# so is test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs_gaussian_2m64
# skip smart_div, smart_rem which are already covered by the smar_div_rem test
# skip default_div, default_rem which are covered by default_div_rem
EXCLUDED_INTEGER_TESTS = [
"/.*integer_smart_div_param/",
"/.*integer_smart_rem_param/",
"/.*integer_default_div_param/",
"/.*integer_default_rem_param/",
"/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs_gaussian_2m64$/",
"~mul_crt_param_message_4_carry_4_ks_pbs_gaussian_2m64",
"/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs_gaussian_2m64$/",
"/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
"/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
"/.*::tests_long_run::.*/",
]
# skip default_div, default_rem which are covered by default_div_rem
EXCLUDED_INTEGER_FAST_TESTS = [
"/.*integer_default_div_param/",
"/.*integer_default_rem_param/",
"/.*_param_message_[14]_carry_[14]_ks_pbs_gaussian_2m64$/",
"/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs_gaussian_2m64$/",
]
EXCLUDED_BIG_PARAMETERS = [
"/.*_param_message_4_carry_4_ks_pbs_gaussian_2m64$/",
]
EXCLUDED_BIG_PARAMETERS_GPU = [
"/.*_message_3_carry_3.*$/",
"/.*_group_3_message_2_carry_2.*$/",
]
def filter_integer_tests(input_args):
(multi_bit_filter, group_filter) = (
("_multi_bit", "_group_[0-9]") if input_args.multi_bit else ("", "")
)
backend_filter = ""
if not input_args.long_tests:
if input_args.backend == "gpu":
backend_filter = "gpu::"
if multi_bit_filter:
# For now, GPU only has specific parameters set for multi-bit
multi_bit_filter = "_gpu_multi_bit"
filter_expression = [f"test(/^integer::{backend_filter}.*/)"]
if input_args.multi_bit:
filter_expression.append("test(~_multi_bit)")
else:
filter_expression.append("not test(~_multi_bit)")
if input_args.signed_only:
filter_expression.append("test(~_signed)")
if input_args.unsigned_only:
filter_expression.append("not test(~_signed)")
if input_args.no_big_params:
for pattern in EXCLUDED_BIG_PARAMETERS:
filter_expression.append(f"not test({pattern})")
if input_args.no_big_params_gpu:
for pattern in EXCLUDED_BIG_PARAMETERS_GPU:
filter_expression.append(f"not test({pattern})")
if input_args.fast_tests and input_args.nightly_tests:
filter_expression.append(
f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_[2-3]_carry_[2-3]_.*/)"
)
elif input_args.fast_tests:
# Test only fast default operations with only one set of parameters
filter_expression.append(
f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_2_carry_2_.*/)"
)
elif input_args.nightly_tests:
# Test only fast default operations with only one set of parameters
# This subset would run slower than fast_tests hence the use of nightly_tests
filter_expression.append(
f"test(/.*_default_.*?_param{multi_bit_filter}{group_filter}_message_3_carry_3_.*/)"
)
excluded_tests = (
EXCLUDED_INTEGER_FAST_TESTS if input_args.fast_tests else EXCLUDED_INTEGER_TESTS
)
for pattern in excluded_tests:
filter_expression.append(f"not test({pattern})")
else:
if input_args.backend == "gpu":
filter_expression = ["test(/^integer::gpu::server_key::radix::tests_long_run.*/)"]
elif input_args.backend == "cpu":
filter_expression = ["test(/^integer::server_key::radix_parallel::tests_long_run.*/)"]
# Do not run noise check tests by default as they can be very slow
# they will be run e.g. nightly or on demand
filter = filter_expression.append(f"not test(/^integer::gpu::server_key::radix::tests_noise_distribution::.*::test_gpu_noise_check.*/)")
return " and ".join(filter_expression)
def filter_shortint_tests(input_args):
multi_bit_filter, group_filter = (
("_multi_bit", "_group_[0-9]") if input_args.multi_bit else ("", "")
)
if input_args.fast_tests:
msg_carry_pairs = [(2, 1), (2, 2), (2, 3)]
else:
msg_carry_pairs = [
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, 5),
(1, 6),
(2, 1),
(2, 2),
(2, 3),
(3, 1),
(3, 2),
(3, 3),
]
if input_args.big_instance:
msg_carry_pairs.append((4, 4))
filter_expression = [
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)"
for msg, carry in msg_carry_pairs
]
filter_expression.append("test(/^shortint::.*meta_param_cpu_2_2/)")
filter_expression.append("test(/^shortint::.*_ci_run_filter/)")
opt_in_tests = " or ".join(filter_expression)
# Do not run noise check tests by default as they can be very slow
# they will be run e.g. nightly or on demand
filter = f"({opt_in_tests}) and not test(/^shortint::.*test_noise_check/)"
return filter
if __name__ == "__main__":
args = parser.parse_args()
expression = ""
if args.layer == "integer":
expression = filter_integer_tests(args)
elif args.layer == "shortint":
expression = filter_shortint_tests(args)
print(expression)