From 879699c0727da54c4ac89e5bd42ad5a139477415 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Test=C3=A9?= Date: Tue, 4 Jun 2024 11:45:09 +0200 Subject: [PATCH] chore(ci): filter integer and shortint tests using python script Backend support for GPU has been added to integer tests. --- .github/workflows/aws_tfhe_gpu_4090_tests.yml | 2 + scripts/integer-tests.sh | 79 +++------ scripts/shortint-tests.sh | 64 +------ scripts/test_filtering.py | 159 ++++++++++++++++++ 4 files changed, 197 insertions(+), 107 deletions(-) create mode 100644 scripts/test_filtering.py diff --git a/.github/workflows/aws_tfhe_gpu_4090_tests.yml b/.github/workflows/aws_tfhe_gpu_4090_tests.yml index 06e32ddd4..220ac3590 100644 --- a/.github/workflows/aws_tfhe_gpu_4090_tests.yml +++ b/.github/workflows/aws_tfhe_gpu_4090_tests.yml @@ -18,6 +18,8 @@ on: pull_request: types: [ labeled ] + # FIXME: remove pull_request event and schedule it as nightly + jobs: cuda-tests-linux: name: CUDA tests (RTX 4090) diff --git a/scripts/integer-tests.sh b/scripts/integer-tests.sh index 6402950b8..45a3aa2c6 100755 --- a/scripts/integer-tests.sh +++ b/scripts/integer-tests.sh @@ -3,7 +3,7 @@ set -e function usage() { - echo "$0: shortint test runner" + echo "$0: integer test runner" echo echo "--help Print this message" echo "--rust-toolchain The toolchain to run the tests with default: stable" @@ -11,18 +11,19 @@ function usage() { echo "--unsigned-only Run only unsigned integer tests, by default both signed and unsigned tests are run" echo "--signed-only Run only signed integer tests, by default both signed and unsigned tests are run" echo "--cargo-profile The cargo profile used to build tests" + echo "--backend Backend to use with tfhe-rs" echo "--avx512-support Set to ON to enable avx512" echo "--tfhe-package The package spec like tfhe@0.4.2, default=tfhe" echo } RUST_TOOLCHAIN="+stable" -multi_bit="" -not_multi_bit="_multi_bit" -# Run signed test by default -signed="" -not_signed="" +multi_bit_argument= +sign_argument= +fast_tests_argument= cargo_profile="release" +backend="cpu" +gpu_feature="" avx512_feature="" tfhe_package="tfhe" @@ -40,18 +41,15 @@ do ;; "--multi-bit" ) - multi_bit="_multi_bit" - not_multi_bit="" + multi_bit_argument=--multi-bit ;; "--unsigned-only" ) - signed="" - not_signed="_signed" + sign_argument=--unsigned-only ;; "--signed-only" ) - signed="_signed" - not_signed="" + sign_argument=--signed-only ;; "--cargo-profile" ) @@ -59,6 +57,10 @@ do cargo_profile="$1" ;; + "--backend" ) + shift + backend="$1" + ;; "--avx512-support" ) shift if [[ "$1" == "ON" ]]; then @@ -83,6 +85,14 @@ if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}" fi +if [[ "${FAST_TESTS}" == TRUE ]]; then + fast_tests_argument=--fast-tests +fi + +if [[ "${backend}" == "gpu" ]]; then + gpu_feature="gpu" +fi + CURR_DIR="$(dirname "$0")" ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" @@ -107,47 +117,12 @@ else doctest_threads="${num_cpu_threads}" fi -# block pbs are too slow for high params -# mul_crt_4_4 is extremely flaky (~80% failure) -# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment -# test_integer_smart_mul_param_message_4_carry_4_ks_pbs is too slow -# so is test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs -# we skip smart_div, smart_rem which are already covered by the smar_div_rem test -# we similarly skip default_div, default_rem which are covered by default_div_rem -full_test_filter_expression="""\ -test(/^integer::.*${multi_bit}/) \ -${signed:+"and test(/^integer::.*${signed}/)"} \ -${not_multi_bit:+"and not test(~${not_multi_bit})"} \ -${not_signed:+"and not test(~${not_signed})"} \ -and not test(/.*integer_smart_div_param/) \ -and not test(/.*integer_smart_rem_param/) \ -and not test(/.*integer_default_div_param/) \ -and not test(/.*integer_default_rem_param/) \ -and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs$/) \ -and not test(~mul_crt_param_message_4_carry_4_ks_pbs) \ -and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs$/) \ -and not test(/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs$/) \ -and not test(/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs$/)""" - -# test only fast default operations with only two set of parameters -# we skip default_div, default_rem which are covered by default_div_rem -fast_test_filter_expression="""\ -test(/^integer::.*${multi_bit}/) \ -${signed:+"and test(/^integer::.*${signed}/)"} \ -${not_multi_bit:+"and not test(~${not_multi_bit})"} \ -${not_signed:+"and not test(~${not_signed})"} \ -and test(/.*_default_.*?_param${multi_bit}_message_[2-3]_carry_[2-3]${multi_bit:+"_group_2"}_ks_pbs/) \ -and not test(/.*integer_default_div_param/) \ -and not test(/.*integer_default_rem_param/) \ -and not test(/.*_param_message_[14]_carry_[14]_ks_pbs$/) \ -and not test(/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs$/)""" +filter_expression=$(/usr/bin/python3 scripts/test_filtering.py --layer integer --backend "${backend}" ${fast_tests_argument} ${multi_bit_argument} ${sign_argument}) if [[ "${FAST_TESTS}" == "TRUE" ]]; then echo "Running 'fast' test set'" - filter_expression="${fast_test_filter_expression}" else echo "Running 'slow' test set" - filter_expression="${full_test_filter_expression}" fi cargo "${RUST_TOOLCHAIN}" nextest run \ @@ -155,17 +130,17 @@ cargo "${RUST_TOOLCHAIN}" nextest run \ --cargo-profile "${cargo_profile}" \ --package "${tfhe_package}" \ --profile ci \ - --features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,"${avx512_feature}" \ + --features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,"${avx512_feature}","${gpu_feature}" \ --test-threads "${test_threads}" \ -E "$filter_expression" -if [[ "${multi_bit}" == "" ]]; then +if [[ -z ${multi_bit_argument} ]]; then cargo "${RUST_TOOLCHAIN}" test \ --profile "${cargo_profile}" \ --package "${tfhe_package}" \ - --features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}" \ + --features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}","${gpu_feature}" \ --doc \ - -- --test-threads="${doctest_threads}" integer:: + -- --test-threads="${doctest_threads}" integer::"${gpu_feature}" fi echo "Test ran in $SECONDS seconds" diff --git a/scripts/shortint-tests.sh b/scripts/shortint-tests.sh index dc7df504f..3faa9a575 100755 --- a/scripts/shortint-tests.sh +++ b/scripts/shortint-tests.sh @@ -15,6 +15,8 @@ function usage() { RUST_TOOLCHAIN="+stable" multi_bit="" +multi_bit_argument= +fast_tests_argument= cargo_profile="release" tfhe_package="tfhe" @@ -33,6 +35,7 @@ do "--multi-bit" ) multi_bit="_multi_bit" + multi_bit_argument=--multi-bit ;; "--cargo-profile" ) @@ -57,6 +60,10 @@ if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}" fi +if [[ "${FAST_TESTS}" == TRUE ]]; then + fast_tests_argument=--fast-tests +fi + CURR_DIR="$(dirname "$0")" ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" @@ -86,33 +93,7 @@ else fi if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then - if [[ "${FAST_TESTS}" != TRUE ]]; then - filter_expression_small_params="""\ -(\ - test(/^shortint::.*_param${multi_bit}_message_1_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_5${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_6${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_ci_run_filter/) \ -)\ -and not test(~smart_add_and_mul)""" # This test is too slow - else - filter_expression_small_params="""\ -(\ - test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -)\ -and not test(~smart_add_and_mul)""" # This test is too slow - fi + filter_expression_small_params=$(/usr/bin/python3 scripts/test_filtering.py --layer shortint ${fast_tests_argument} ${multi_bit_argument}) # Run tests only no examples or benches with small params and more threads cargo "${RUST_TOOLCHAIN}" nextest run \ @@ -151,34 +132,7 @@ and not test(~smart_add_and_mul)""" fi fi else - if [[ "${FAST_TESTS}" != TRUE ]]; then - filter_expression="""\ -(\ - test(/^shortint::.*_param${multi_bit}_message_1_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_5${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_1_carry_6${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_3_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_4_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_ci_run_filter/) \ -)\ -and not test(~smart_add_and_mul)""" # This test is too slow - else - filter_expression="""\ -(\ - test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \ -)\ -and not test(~smart_add_and_mul)""" # This test is too slow - fi + filter_expression=$(/usr/bin/python3 scripts/test_filtering.py --layer shortint --big-instance ${fast_tests_argument} ${multi_bit_argument}) # Run tests only no examples or benches with small params and more threads cargo "${RUST_TOOLCHAIN}" nextest run \ diff --git a/scripts/test_filtering.py b/scripts/test_filtering.py new file mode 100644 index 000000000..992a8d25f --- /dev/null +++ b/scripts/test_filtering.py @@ -0,0 +1,159 @@ +""" +Script that generates a cargo-nextest filter as an output. +The string result can be directly injected into a nextest command. +""" +import argparse + +parser = argparse.ArgumentParser(allow_abbrev=False) +parser.add_argument( + "--layer", + dest="layer", + choices=["integer", "shortint"], + required=True, + help="tfhe-rs layer to use", +) +parser.add_argument( + "--backend", + dest="backend", + choices=["cpu", "gpu"], + default="cpu", + help="tfhe-rs backend to use", +) +parser.add_argument( + "--fast-tests", + dest="fast_tests", + action="store_true", + help="Run only a small subset of test suite", +) +parser.add_argument( + "--big-instance", + dest="big_instance", + action="store_true", + help="Backend is using a large instance", +) +parser.add_argument( + "--multi-bit", + dest="multi_bit", + action="store_true", + help="Include tests running on multi-bit parameters set", +) +parser.add_argument( + "--signed-only", + dest="signed_only", + action="store_true", + help="Include only signed integer tests", +) +parser.add_argument( + "--unsigned-only", + dest="unsigned_only", + action="store_true", + help="Include only unsigned integer tests", +) + +# block PBS are too slow for high params +# mul_crt_4_4 is extremely flaky (~80% failure) +# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment +# test_integer_smart_mul_param_message_4_carry_4_ks_pbs is too slow +# so is test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs +# skip smart_div, smart_rem which are already covered by the smar_div_rem test +# skip default_div, default_rem which are covered by default_div_rem +EXCLUDED_INTEGER_TESTS = [ + "/.*integer_smart_div_param/", + "/.*integer_smart_rem_param/", + "/.*integer_default_div_param/", + "/.*integer_default_rem_param/", + "/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs$/", + "~mul_crt_param_message_4_carry_4_ks_pbs", + "/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs$/", + "/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs$/", + "/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs$/", +] + +# skip default_div, default_rem which are covered by default_div_rem +EXCLUDED_INTEGER_FAST_TESTS = [ + "/.*integer_default_div_param/", + "/.*integer_default_rem_param/", + "/.*_param_message_[14]_carry_[14]_ks_pbs$/", + "/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs$/", +] + + +def filter_integer_tests(input_args): + multi_bit_filter = "_multi_bit" if input_args.multi_bit else "" + backend_filter = "" + if input_args.backend == "gpu": + backend_filter = "gpu::" + + filter_expression = [f"test(/^integer::{backend_filter}.*/)"] + + if input_args.multi_bit: + filter_expression.append("test(~_multi_bit)") + else: + filter_expression.append("not test(~_multi_bit)") + + if input_args.signed_only: + filter_expression.append("test(~_signed)") + if input_args.unsigned_only: + filter_expression.append("not test(~_signed)") + + if input_args.fast_tests: + # Test only fast default operations with only two set of parameters + param_group = "_group_2" if input_args.multi_bit else "" + filter_expression.append( + f"test(/.*_default_.*?_param{multi_bit_filter}_message_[2-3]_carry_[2-3]{param_group}_ks_pbs/)" + ) + + excluded_tests = ( + EXCLUDED_INTEGER_FAST_TESTS if input_args.fast_tests else EXCLUDED_INTEGER_TESTS + ) + for pattern in excluded_tests: + filter_expression.append(f"not test({pattern})") + + return " and ".join(filter_expression) + + +def filter_shortint_tests(input_args): + multi_bit_filter, group_filter = ( + ("_multi_bit", "_group_[0-9]") if input_args.multi_bit else ("", "") + ) + + if input_args.fast_tests: + msg_carry_pairs = [(2, 1), (2, 2), (2, 3)] + else: + msg_carry_pairs = [ + (1, 1), + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 6), + (2, 1), + (2, 2), + (2, 3), + (3, 1), + (3, 2), + (3, 3), + ] + if input_args.big_instance: + msg_carry_pairs.append((4, 4)) + + filter_expression = [ + f"test(/^shortint::.*_param{multi_bit_filter}_message_{msg}_carry_{carry}{group_filter}(_compact_pk)?_ks_pbs/)" + for msg, carry in msg_carry_pairs + ] + filter_expression.append("test(/^shortint::.*_ci_run_filter/)") + + return " or ".join(filter_expression) + + +if __name__ == "__main__": + args = parser.parse_args() + + expression = "" + + if args.layer == "integer": + expression = filter_integer_tests(args) + elif args.layer == "shortint": + expression = filter_shortint_tests(args) + + print(expression)