mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
chore(ci): filter integer and shortint tests using python script
Backend support for GPU has been added to integer tests.
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
set -e
|
||||
|
||||
function usage() {
|
||||
echo "$0: shortint test runner"
|
||||
echo "$0: integer test runner"
|
||||
echo
|
||||
echo "--help Print this message"
|
||||
echo "--rust-toolchain The toolchain to run the tests with default: stable"
|
||||
@@ -11,18 +11,19 @@ function usage() {
|
||||
echo "--unsigned-only Run only unsigned integer tests, by default both signed and unsigned tests are run"
|
||||
echo "--signed-only Run only signed integer tests, by default both signed and unsigned tests are run"
|
||||
echo "--cargo-profile The cargo profile used to build tests"
|
||||
echo "--backend Backend to use with tfhe-rs"
|
||||
echo "--avx512-support Set to ON to enable avx512"
|
||||
echo "--tfhe-package The package spec like tfhe@0.4.2, default=tfhe"
|
||||
echo
|
||||
}
|
||||
|
||||
RUST_TOOLCHAIN="+stable"
|
||||
multi_bit=""
|
||||
not_multi_bit="_multi_bit"
|
||||
# Run signed test by default
|
||||
signed=""
|
||||
not_signed=""
|
||||
multi_bit_argument=
|
||||
sign_argument=
|
||||
fast_tests_argument=
|
||||
cargo_profile="release"
|
||||
backend="cpu"
|
||||
gpu_feature=""
|
||||
avx512_feature=""
|
||||
tfhe_package="tfhe"
|
||||
|
||||
@@ -40,18 +41,15 @@ do
|
||||
;;
|
||||
|
||||
"--multi-bit" )
|
||||
multi_bit="_multi_bit"
|
||||
not_multi_bit=""
|
||||
multi_bit_argument=--multi-bit
|
||||
;;
|
||||
|
||||
"--unsigned-only" )
|
||||
signed=""
|
||||
not_signed="_signed"
|
||||
sign_argument=--unsigned-only
|
||||
;;
|
||||
|
||||
"--signed-only" )
|
||||
signed="_signed"
|
||||
not_signed=""
|
||||
sign_argument=--signed-only
|
||||
;;
|
||||
|
||||
"--cargo-profile" )
|
||||
@@ -59,6 +57,10 @@ do
|
||||
cargo_profile="$1"
|
||||
;;
|
||||
|
||||
"--backend" )
|
||||
shift
|
||||
backend="$1"
|
||||
;;
|
||||
"--avx512-support" )
|
||||
shift
|
||||
if [[ "$1" == "ON" ]]; then
|
||||
@@ -83,6 +85,14 @@ if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
|
||||
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
|
||||
fi
|
||||
|
||||
if [[ "${FAST_TESTS}" == TRUE ]]; then
|
||||
fast_tests_argument=--fast-tests
|
||||
fi
|
||||
|
||||
if [[ "${backend}" == "gpu" ]]; then
|
||||
gpu_feature="gpu"
|
||||
fi
|
||||
|
||||
CURR_DIR="$(dirname "$0")"
|
||||
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
|
||||
|
||||
@@ -107,47 +117,12 @@ else
|
||||
doctest_threads="${num_cpu_threads}"
|
||||
fi
|
||||
|
||||
# block pbs are too slow for high params
|
||||
# mul_crt_4_4 is extremely flaky (~80% failure)
|
||||
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
|
||||
# test_integer_smart_mul_param_message_4_carry_4_ks_pbs is too slow
|
||||
# so is test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs
|
||||
# we skip smart_div, smart_rem which are already covered by the smar_div_rem test
|
||||
# we similarly skip default_div, default_rem which are covered by default_div_rem
|
||||
full_test_filter_expression="""\
|
||||
test(/^integer::.*${multi_bit}/) \
|
||||
${signed:+"and test(/^integer::.*${signed}/)"} \
|
||||
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
|
||||
${not_signed:+"and not test(~${not_signed})"} \
|
||||
and not test(/.*integer_smart_div_param/) \
|
||||
and not test(/.*integer_smart_rem_param/) \
|
||||
and not test(/.*integer_default_div_param/) \
|
||||
and not test(/.*integer_default_rem_param/) \
|
||||
and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs$/) \
|
||||
and not test(~mul_crt_param_message_4_carry_4_ks_pbs) \
|
||||
and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs$/) \
|
||||
and not test(/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs$/) \
|
||||
and not test(/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs$/)"""
|
||||
|
||||
# test only fast default operations with only two set of parameters
|
||||
# we skip default_div, default_rem which are covered by default_div_rem
|
||||
fast_test_filter_expression="""\
|
||||
test(/^integer::.*${multi_bit}/) \
|
||||
${signed:+"and test(/^integer::.*${signed}/)"} \
|
||||
${not_multi_bit:+"and not test(~${not_multi_bit})"} \
|
||||
${not_signed:+"and not test(~${not_signed})"} \
|
||||
and test(/.*_default_.*?_param${multi_bit}_message_[2-3]_carry_[2-3]${multi_bit:+"_group_2"}_ks_pbs/) \
|
||||
and not test(/.*integer_default_div_param/) \
|
||||
and not test(/.*integer_default_rem_param/) \
|
||||
and not test(/.*_param_message_[14]_carry_[14]_ks_pbs$/) \
|
||||
and not test(/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs$/)"""
|
||||
filter_expression=$(/usr/bin/python3 scripts/test_filtering.py --layer integer --backend "${backend}" ${fast_tests_argument} ${multi_bit_argument} ${sign_argument})
|
||||
|
||||
if [[ "${FAST_TESTS}" == "TRUE" ]]; then
|
||||
echo "Running 'fast' test set'"
|
||||
filter_expression="${fast_test_filter_expression}"
|
||||
else
|
||||
echo "Running 'slow' test set"
|
||||
filter_expression="${full_test_filter_expression}"
|
||||
fi
|
||||
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
@@ -155,17 +130,17 @@ cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
--cargo-profile "${cargo_profile}" \
|
||||
--package "${tfhe_package}" \
|
||||
--profile ci \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,"${avx512_feature}" \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache,zk-pok,"${avx512_feature}","${gpu_feature}" \
|
||||
--test-threads "${test_threads}" \
|
||||
-E "$filter_expression"
|
||||
|
||||
if [[ "${multi_bit}" == "" ]]; then
|
||||
if [[ -z ${multi_bit_argument} ]]; then
|
||||
cargo "${RUST_TOOLCHAIN}" test \
|
||||
--profile "${cargo_profile}" \
|
||||
--package "${tfhe_package}" \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}" \
|
||||
--features="${ARCH_FEATURE}",integer,internal-keycache,"${avx512_feature}","${gpu_feature}" \
|
||||
--doc \
|
||||
-- --test-threads="${doctest_threads}" integer::
|
||||
-- --test-threads="${doctest_threads}" integer::"${gpu_feature}"
|
||||
fi
|
||||
|
||||
echo "Test ran in $SECONDS seconds"
|
||||
|
||||
@@ -15,6 +15,8 @@ function usage() {
|
||||
|
||||
RUST_TOOLCHAIN="+stable"
|
||||
multi_bit=""
|
||||
multi_bit_argument=
|
||||
fast_tests_argument=
|
||||
cargo_profile="release"
|
||||
tfhe_package="tfhe"
|
||||
|
||||
@@ -33,6 +35,7 @@ do
|
||||
|
||||
"--multi-bit" )
|
||||
multi_bit="_multi_bit"
|
||||
multi_bit_argument=--multi-bit
|
||||
;;
|
||||
|
||||
"--cargo-profile" )
|
||||
@@ -57,6 +60,10 @@ if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then
|
||||
RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}"
|
||||
fi
|
||||
|
||||
if [[ "${FAST_TESTS}" == TRUE ]]; then
|
||||
fast_tests_argument=--fast-tests
|
||||
fi
|
||||
|
||||
CURR_DIR="$(dirname "$0")"
|
||||
ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")"
|
||||
|
||||
@@ -86,33 +93,7 @@ else
|
||||
fi
|
||||
|
||||
if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then
|
||||
if [[ "${FAST_TESTS}" != TRUE ]]; then
|
||||
filter_expression_small_params="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_1_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_ci_run_filter/) \
|
||||
)\
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
else
|
||||
filter_expression_small_params="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
)\
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
fi
|
||||
filter_expression_small_params=$(/usr/bin/python3 scripts/test_filtering.py --layer shortint ${fast_tests_argument} ${multi_bit_argument})
|
||||
|
||||
# Run tests only no examples or benches with small params and more threads
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
@@ -151,34 +132,7 @@ and not test(~smart_add_and_mul)"""
|
||||
fi
|
||||
fi
|
||||
else
|
||||
if [[ "${FAST_TESTS}" != TRUE ]]; then
|
||||
filter_expression="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_1_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_5${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_1_carry_6${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_3_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_4_carry_4${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_ci_run_filter/) \
|
||||
)\
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
else
|
||||
filter_expression="""\
|
||||
(\
|
||||
test(/^shortint::.*_param${multi_bit}_message_2_carry_1${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_2${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
or test(/^shortint::.*_param${multi_bit}_message_2_carry_3${multi_bit:+"_group_[0-9]"}(_compact_pk)?_ks_pbs/) \
|
||||
)\
|
||||
and not test(~smart_add_and_mul)""" # This test is too slow
|
||||
fi
|
||||
filter_expression=$(/usr/bin/python3 scripts/test_filtering.py --layer shortint --big-instance ${fast_tests_argument} ${multi_bit_argument})
|
||||
|
||||
# Run tests only no examples or benches with small params and more threads
|
||||
cargo "${RUST_TOOLCHAIN}" nextest run \
|
||||
|
||||
159
scripts/test_filtering.py
Normal file
159
scripts/test_filtering.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Script that generates a cargo-nextest filter as an output.
|
||||
The string result can be directly injected into a nextest command.
|
||||
"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(allow_abbrev=False)
|
||||
parser.add_argument(
|
||||
"--layer",
|
||||
dest="layer",
|
||||
choices=["integer", "shortint"],
|
||||
required=True,
|
||||
help="tfhe-rs layer to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--backend",
|
||||
dest="backend",
|
||||
choices=["cpu", "gpu"],
|
||||
default="cpu",
|
||||
help="tfhe-rs backend to use",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-tests",
|
||||
dest="fast_tests",
|
||||
action="store_true",
|
||||
help="Run only a small subset of test suite",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--big-instance",
|
||||
dest="big_instance",
|
||||
action="store_true",
|
||||
help="Backend is using a large instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--multi-bit",
|
||||
dest="multi_bit",
|
||||
action="store_true",
|
||||
help="Include tests running on multi-bit parameters set",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--signed-only",
|
||||
dest="signed_only",
|
||||
action="store_true",
|
||||
help="Include only signed integer tests",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unsigned-only",
|
||||
dest="unsigned_only",
|
||||
action="store_true",
|
||||
help="Include only unsigned integer tests",
|
||||
)
|
||||
|
||||
# block PBS are too slow for high params
|
||||
# mul_crt_4_4 is extremely flaky (~80% failure)
|
||||
# test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment
|
||||
# test_integer_smart_mul_param_message_4_carry_4_ks_pbs is too slow
|
||||
# so is test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs
|
||||
# skip smart_div, smart_rem which are already covered by the smar_div_rem test
|
||||
# skip default_div, default_rem which are covered by default_div_rem
|
||||
EXCLUDED_INTEGER_TESTS = [
|
||||
"/.*integer_smart_div_param/",
|
||||
"/.*integer_smart_rem_param/",
|
||||
"/.*integer_default_div_param/",
|
||||
"/.*integer_default_rem_param/",
|
||||
"/.*_block_pbs(_base)?_param_message_[34]_carry_[34]_ks_pbs$/",
|
||||
"~mul_crt_param_message_4_carry_4_ks_pbs",
|
||||
"/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]_ks_pbs$/",
|
||||
"/.*test_integer_smart_mul_param_message_4_carry_4_ks_pbs$/",
|
||||
"/.*test_integer_default_add_sequence_multi_thread_param_message_4_carry_4_ks_pbs$/",
|
||||
]
|
||||
|
||||
# skip default_div, default_rem which are covered by default_div_rem
|
||||
EXCLUDED_INTEGER_FAST_TESTS = [
|
||||
"/.*integer_default_div_param/",
|
||||
"/.*integer_default_rem_param/",
|
||||
"/.*_param_message_[14]_carry_[14]_ks_pbs$/",
|
||||
"/.*default_add_sequence_multi_thread_param_message_3_carry_3_ks_pbs$/",
|
||||
]
|
||||
|
||||
|
||||
def filter_integer_tests(input_args):
|
||||
multi_bit_filter = "_multi_bit" if input_args.multi_bit else ""
|
||||
backend_filter = ""
|
||||
if input_args.backend == "gpu":
|
||||
backend_filter = "gpu::"
|
||||
|
||||
filter_expression = [f"test(/^integer::{backend_filter}.*/)"]
|
||||
|
||||
if input_args.multi_bit:
|
||||
filter_expression.append("test(~_multi_bit)")
|
||||
else:
|
||||
filter_expression.append("not test(~_multi_bit)")
|
||||
|
||||
if input_args.signed_only:
|
||||
filter_expression.append("test(~_signed)")
|
||||
if input_args.unsigned_only:
|
||||
filter_expression.append("not test(~_signed)")
|
||||
|
||||
if input_args.fast_tests:
|
||||
# Test only fast default operations with only two set of parameters
|
||||
param_group = "_group_2" if input_args.multi_bit else ""
|
||||
filter_expression.append(
|
||||
f"test(/.*_default_.*?_param{multi_bit_filter}_message_[2-3]_carry_[2-3]{param_group}_ks_pbs/)"
|
||||
)
|
||||
|
||||
excluded_tests = (
|
||||
EXCLUDED_INTEGER_FAST_TESTS if input_args.fast_tests else EXCLUDED_INTEGER_TESTS
|
||||
)
|
||||
for pattern in excluded_tests:
|
||||
filter_expression.append(f"not test({pattern})")
|
||||
|
||||
return " and ".join(filter_expression)
|
||||
|
||||
|
||||
def filter_shortint_tests(input_args):
|
||||
multi_bit_filter, group_filter = (
|
||||
("_multi_bit", "_group_[0-9]") if input_args.multi_bit else ("", "")
|
||||
)
|
||||
|
||||
if input_args.fast_tests:
|
||||
msg_carry_pairs = [(2, 1), (2, 2), (2, 3)]
|
||||
else:
|
||||
msg_carry_pairs = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 3),
|
||||
(1, 4),
|
||||
(1, 5),
|
||||
(1, 6),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 3),
|
||||
(3, 1),
|
||||
(3, 2),
|
||||
(3, 3),
|
||||
]
|
||||
if input_args.big_instance:
|
||||
msg_carry_pairs.append((4, 4))
|
||||
|
||||
filter_expression = [
|
||||
f"test(/^shortint::.*_param{multi_bit_filter}_message_{msg}_carry_{carry}{group_filter}(_compact_pk)?_ks_pbs/)"
|
||||
for msg, carry in msg_carry_pairs
|
||||
]
|
||||
filter_expression.append("test(/^shortint::.*_ci_run_filter/)")
|
||||
|
||||
return " or ".join(filter_expression)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
expression = ""
|
||||
|
||||
if args.layer == "integer":
|
||||
expression = filter_integer_tests(args)
|
||||
elif args.layer == "shortint":
|
||||
expression = filter_shortint_tests(args)
|
||||
|
||||
print(expression)
|
||||
Reference in New Issue
Block a user