From d1cb55ba24b749498c8f5873b4491c139a53dba0 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Thu, 1 Jun 2023 17:11:48 +0200 Subject: [PATCH] chore(tfhe): add multi bit shortint and integer tests - default tests do not run multi bit PBS as it's not yet deterministic - only radix parallel currently use multi bit pbs in integer - remove determinism checks for some unchecked ops - 4_4 multi bit parameters are disabled for now as they seem to introduce too much noise --- .../workflows/aws_tfhe_multi_bit_tests.yml | 42 +- .github/workflows/trigger_aws_tests_on_pr.yml | 1 + Makefile | 20 +- ci/slab.toml | 5 + scripts/integer-tests.sh | 101 ++-- scripts/shortint-tests.sh | 151 +++--- tfhe/examples/generates_test_keys.rs | 118 +++-- .../server_key/radix_parallel/tests.rs | 442 ++++++++++++------ tfhe/src/shortint/server_key/tests.rs | 412 ++++++++++++---- 9 files changed, 921 insertions(+), 371 deletions(-) diff --git a/.github/workflows/aws_tfhe_multi_bit_tests.yml b/.github/workflows/aws_tfhe_multi_bit_tests.yml index 87b934911..65b3d2856 100644 --- a/.github/workflows/aws_tfhe_multi_bit_tests.yml +++ b/.github/workflows/aws_tfhe_multi_bit_tests.yml @@ -1,4 +1,4 @@ -name: AWS Tests on CPU +name: AWS Multi Bit Tests on CPU env: CARGO_TERM_COLOR: always @@ -48,3 +48,43 @@ jobs: echo "Request ID: ${{ inputs.request_id }}" echo "Fork repo: ${{ inputs.fork_repo }}" echo "Fork git sha: ${{ inputs.fork_git_sha }}" + + - name: Checkout tfhe-rs + uses: actions/checkout@8e5e7e5ab8b370d6c329ec480221332ada57f0ab + with: + repository: ${{ inputs.fork_repo }} + ref: ${{ inputs.fork_git_sha }} + + - name: Set up home + run: | + echo "HOME=/home/ubuntu" >> "${GITHUB_ENV}" + + - name: Install latest stable + uses: actions-rs/toolchain@16499b5e05bf2e26879000db0c1d13f7e13fa3af + with: + toolchain: stable + default: true + + - name: Gen Keys if required + run: | + MULTI_BIT_ONLY=TRUE make gen_key_cache + + - name: Run shortint multi-bit tests + run: | + make test_shortint_multi_bit_ci + + - name: Run integer multi-bit tests + run: | + make test_integer_multi_bit_ci + + - name: Slack Notification + if: ${{ always() }} + continue-on-error: true + uses: rtCamp/action-slack-notify@12e36fc18b0689399306c2e0b3e0f2978b7f1ee7 + env: + SLACK_COLOR: ${{ job.status }} + SLACK_CHANNEL: ${{ secrets.SLACK_CHANNEL }} + SLACK_ICON: https://pbs.twimg.com/profile_images/1274014582265298945/OjBKP9kn_400x400.png + SLACK_MESSAGE: "Shortint tests finished with status: ${{ job.status }}. (${{ env.ACTION_RUN_URL }})" + SLACK_USERNAME: ${{ secrets.BOT_USERNAME }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.github/workflows/trigger_aws_tests_on_pr.yml b/.github/workflows/trigger_aws_tests_on_pr.yml index 2dee81587..65f1915f5 100644 --- a/.github/workflows/trigger_aws_tests_on_pr.yml +++ b/.github/workflows/trigger_aws_tests_on_pr.yml @@ -16,3 +16,4 @@ jobs: message: | @slab-ci cpu_test @slab-ci cpu_integer_test + @slab-ci cpu_multi_bit_test diff --git a/Makefile b/Makefile index 0fd2fab2b..9db26e76c 100644 --- a/Makefile +++ b/Makefile @@ -142,9 +142,13 @@ clippy_fast: clippy clippy_all_targets clippy_c_api clippy_js_wasm_api clippy_ta .PHONY: gen_key_cache # Run the script to generate keys and cache them for shortint tests gen_key_cache: install_rs_build_toolchain + if [[ "$${MULTI_BIT_ONLY}" == TRUE ]]; then \ + multi_bit_flag="--multi-bit-only"; \ + fi && \ RUSTFLAGS="$(RUSTFLAGS)" cargo $(CARGO_RS_BUILD_TOOLCHAIN) run --profile $(CARGO_PROFILE) \ --example generates_test_keys \ - --features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe + --features=$(TARGET_ARCH_FEATURE),shortint,internal-keycache -p tfhe -- \ + $${multi_bit_flag:+"$${multi_bit_flag}"} .PHONY: build_core # Build core_crypto without experimental features build_core: install_rs_build_toolchain install_rs_check_toolchain @@ -234,7 +238,12 @@ test_c_api: build_c_api .PHONY: test_shortint_ci # Run the tests for shortint ci test_shortint_ci: install_rs_build_toolchain install_cargo_nextest BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \ - ./scripts/shortint-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN) + ./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) + +.PHONY: test_shortint_multi_bit_ci # Run the tests for shortint ci running only multibit tests +test_shortint_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest + BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \ + ./scripts/shortint-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit .PHONY: test_shortint # Run all the tests for shortint test_shortint: install_rs_build_toolchain @@ -244,7 +253,12 @@ test_shortint: install_rs_build_toolchain .PHONY: test_integer_ci # Run the tests for integer ci test_integer_ci: install_rs_build_toolchain install_cargo_nextest BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \ - ./scripts/integer-tests.sh $(CARGO_RS_BUILD_TOOLCHAIN) + ./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) + +.PHONY: test_integer_multi_bit_ci # Run the tests for integer ci running only multibit tests +test_integer_multi_bit_ci: install_rs_build_toolchain install_cargo_nextest + BIG_TESTS_INSTANCE="$(BIG_TESTS_INSTANCE)" \ + ./scripts/integer-tests.sh --rust-toolchain $(CARGO_RS_BUILD_TOOLCHAIN) --multi-bit .PHONY: test_integer # Run all the tests for integer test_integer: install_rs_build_toolchain diff --git a/ci/slab.toml b/ci/slab.toml index 68406e33f..886aa5687 100644 --- a/ci/slab.toml +++ b/ci/slab.toml @@ -18,6 +18,11 @@ workflow = "aws_tfhe_integer_tests.yml" profile = "cpu-big" check_run_name = "CPU Integer AWS Tests" +[command.cpu_multi_bit_test] +workflow = "aws_tfhe_multi_bit_tests.yml" +profile = "cpu-big" +check_run_name = "CPU AWS Multi Bit Tests" + [command.integer_bench] workflow = "integer_benchmark.yml" profile = "bench" diff --git a/scripts/integer-tests.sh b/scripts/integer-tests.sh index 05a0fbc0e..a145c7a9a 100755 --- a/scripts/integer-tests.sh +++ b/scripts/integer-tests.sh @@ -2,6 +2,49 @@ set -e +function usage() { + echo "$0: shortint test runner" + echo + echo "--help Print this message" + echo "--rust-toolchain The toolchain to run the tests with default: stable" + echo "--multi-bit Run multi-bit tests only: default off" + echo +} + +RUST_TOOLCHAIN="+stable" +multi_bit="" +not_multi_bit="_multi_bit" + +while [ -n "$1" ] +do + case "$1" in + "--help" | "-h" ) + usage + exit 0 + ;; + + "--rust-toolchain" ) + shift + RUST_TOOLCHAIN="$1" + ;; + + "--multi-bit" ) + multi_bit="_multi_bit" + not_multi_bit="" + ;; + + *) + echo "Unknown param : $1" + exit 1 + ;; + esac + shift +done + +if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then + RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}" +fi + CURR_DIR="$(dirname "$0")" ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" @@ -29,14 +72,15 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then # mul_crt_4_4 is extremely flaky (~80% failure) # test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment # test_integer_smart_mul_param_message_4_carry_4 is too slow - filter_expression=''\ -'test(/^integer::.*$/)'\ -'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\ -'and not test(~mul_crt_param_message_4_carry_4)'\ -'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\ -'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)' + filter_expression="""\ +test(/^integer::.*${multi_bit}/) \ +${not_multi_bit:+"and not test(~${not_multi_bit})"} \ +and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \ +and not test(~mul_crt_param_message_4_carry_4) \ +and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \ +and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)""" - cargo ${1:+"${1}"} nextest run \ + cargo "${RUST_TOOLCHAIN}" nextest run \ --tests \ --release \ --package tfhe \ @@ -45,27 +89,30 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then --test-threads "${n_threads}" \ -E "$filter_expression" - cargo ${1:+"${1}"} test \ - --release \ - --package tfhe \ - --features="${ARCH_FEATURE}",integer,internal-keycache \ - --doc \ - integer:: + if [[ "${multi_bit}" == "" ]]; then + cargo "${RUST_TOOLCHAIN}" test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",integer,internal-keycache \ + --doc \ + integer:: + fi else # block pbs are too slow for high params # mul_crt_4_4 is extremely flaky (~80% failure) # test_wopbs_bivariate_crt_wopbs_param_message generate tables that are too big at the moment # test_integer_smart_mul_param_message_4_carry_4 is too slow - filter_expression=''\ -'test(/^integer::.*$/)'\ -'and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/)'\ -'and not test(~mul_crt_param_message_4_carry_4)'\ -'and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/)'\ -'and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)' + filter_expression="""\ +test(/^integer::.*${multi_bit}/) \ +${not_multi_bit:+"and not test(~${not_multi_bit})"} \ +and not test(/.*_block_pbs(_base)?_param_message_[34]_carry_[34]$/) \ +and not test(~mul_crt_param_message_4_carry_4) \ +and not test(/.*test_wopbs_bivariate_crt_wopbs_param_message_[34]_carry_[34]$/) \ +and not test(/.*test_integer_smart_mul_param_message_4_carry_4$/)""" num_cpu_threads="$(${nproc_bin})" num_threads=$((num_cpu_threads * 2 / 3)) - cargo ${1:+"${1}"} nextest run \ + cargo "${RUST_TOOLCHAIN}" nextest run \ --tests \ --release \ --package tfhe \ @@ -74,12 +121,14 @@ else --test-threads $num_threads \ -E "$filter_expression" - cargo ${1:+"${1}"} test \ - --release \ - --package tfhe \ - --features="${ARCH_FEATURE}",integer,internal-keycache \ - --doc \ - integer:: -- --test-threads="$(${nproc_bin})" + if [[ "${multi_bit}" == "" ]]; then + cargo "${RUST_TOOLCHAIN}" test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",integer,internal-keycache \ + --doc \ + integer:: -- --test-threads="$(${nproc_bin})" + fi fi echo "Test ran in $SECONDS seconds" diff --git a/scripts/shortint-tests.sh b/scripts/shortint-tests.sh index a3cc25aee..0652f286d 100755 --- a/scripts/shortint-tests.sh +++ b/scripts/shortint-tests.sh @@ -2,6 +2,47 @@ set -e +function usage() { + echo "$0: shortint test runner" + echo + echo "--help Print this message" + echo "--rust-toolchain The toolchain to run the tests with default: stable" + echo "--multi-bit Run multi-bit tests only: default off" + echo +} + +RUST_TOOLCHAIN="+stable" +multi_bit="" + +while [ -n "$1" ] +do + case "$1" in + "--help" | "-h" ) + usage + exit 0 + ;; + + "--rust-toolchain" ) + shift + RUST_TOOLCHAIN="$1" + ;; + + "--multi-bit" ) + multi_bit="_multi_bit" + ;; + + *) + echo "Unknown param : $1" + exit 1 + ;; + esac + shift +done + +if [[ "${RUST_TOOLCHAIN::1}" != "+" ]]; then + RUST_TOOLCHAIN="+${RUST_TOOLCHAIN}" +fi + CURR_DIR="$(dirname "$0")" ARCH_FEATURE="$("${CURR_DIR}/get_arch_feature.sh")" @@ -31,25 +72,25 @@ else fi if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then - filter_expression_small_params=''\ -'('\ -' test(/^shortint::.*_param_message_1_carry_1$/)'\ -'or test(/^shortint::.*_param_message_1_carry_2$/)'\ -'or test(/^shortint::.*_param_message_1_carry_3$/)'\ -'or test(/^shortint::.*_param_message_1_carry_4$/)'\ -'or test(/^shortint::.*_param_message_1_carry_5$/)'\ -'or test(/^shortint::.*_param_message_1_carry_6$/)'\ -'or test(/^shortint::.*_param_message_2_carry_1$/)'\ -'or test(/^shortint::.*_param_message_2_carry_2$/)'\ -'or test(/^shortint::.*_param_message_2_carry_3$/)'\ -'or test(/^shortint::.*_param_message_3_carry_1$/)'\ -'or test(/^shortint::.*_param_message_3_carry_2$/)'\ -'or test(/^shortint::.*_param_message_3_carry_3$/)'\ -')'\ -'and not test(~smart_add_and_mul)' # This test is too slow + filter_expression_small_params="""\ +(\ + test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \ +) \ +and not test(~smart_add_and_mul)""" # This test is too slow # Run tests only no examples or benches with small params and more threads - cargo ${1:+"${1}"} nextest run \ + cargo "${RUST_TOOLCHAIN}" nextest run \ --tests \ --release \ --package tfhe \ @@ -58,14 +99,14 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then --test-threads "${n_threads_small}" \ -E "${filter_expression_small_params}" - filter_expression_big_params=''\ -'('\ -' test(/^shortint::.*_param_message_4_carry_4$/)'\ -')'\ -'and not test(~smart_add_and_mul)' + filter_expression_big_params="""\ +(\ + test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \ +) \ +and not test(~smart_add_and_mul)""" # Run tests only no examples or benches with big params and less threads - cargo ${1:+"${1}"} nextest run \ + cargo "${RUST_TOOLCHAIN}" nextest run \ --tests \ --release \ --package tfhe \ @@ -74,33 +115,35 @@ if [[ "${BIG_TESTS_INSTANCE}" != TRUE ]]; then --test-threads "${n_threads_big}" \ -E "${filter_expression_big_params}" - cargo ${1:+"${1}"} test \ - --release \ - --package tfhe \ - --features="${ARCH_FEATURE}",shortint,internal-keycache \ - --doc \ - shortint:: + if [[ "${multi_bit}" == "" ]]; then + cargo "${RUST_TOOLCHAIN}" test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",shortint,internal-keycache \ + --doc \ + shortint:: + fi else - filter_expression=''\ -'('\ -' test(/^shortint::.*_param_message_1_carry_1$/)'\ -'or test(/^shortint::.*_param_message_1_carry_2$/)'\ -'or test(/^shortint::.*_param_message_1_carry_3$/)'\ -'or test(/^shortint::.*_param_message_1_carry_4$/)'\ -'or test(/^shortint::.*_param_message_1_carry_5$/)'\ -'or test(/^shortint::.*_param_message_1_carry_6$/)'\ -'or test(/^shortint::.*_param_message_2_carry_1$/)'\ -'or test(/^shortint::.*_param_message_2_carry_2$/)'\ -'or test(/^shortint::.*_param_message_2_carry_3$/)'\ -'or test(/^shortint::.*_param_message_3_carry_1$/)'\ -'or test(/^shortint::.*_param_message_3_carry_2$/)'\ -'or test(/^shortint::.*_param_message_3_carry_3$/)'\ -'or test(/^shortint::.*_param_message_4_carry_4$/)'\ -')'\ -'and not test(~smart_add_and_mul)' # This test is too slow + filter_expression="""\ +(\ + test(/^shortint::.*_param${multi_bit}_message_1_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_3/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_4/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_5/) \ +or test(/^shortint::.*_param${multi_bit}_message_1_carry_6/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_2_carry_3/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_1/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_2/) \ +or test(/^shortint::.*_param${multi_bit}_message_3_carry_3/) \ +or test(/^shortint::.*_param${multi_bit}_message_4_carry_4/) \ +)\ +and not test(~smart_add_and_mul)""" # This test is too slow # Run tests only no examples or benches with small params and more threads - cargo ${1:+"${1}"} nextest run \ + cargo "${RUST_TOOLCHAIN}" nextest run \ --tests \ --release \ --package tfhe \ @@ -109,12 +152,14 @@ else --test-threads "$(${nproc_bin})" \ -E "${filter_expression}" - cargo ${1:+"${1}"} test \ - --release \ - --package tfhe \ - --features="${ARCH_FEATURE}",shortint,internal-keycache \ - --doc \ - shortint:: -- --test-threads="$(${nproc_bin})" + if [[ "${multi_bit}" == "" ]]; then + cargo "${RUST_TOOLCHAIN}" test \ + --release \ + --package tfhe \ + --features="${ARCH_FEATURE}",shortint,internal-keycache \ + --doc \ + shortint:: -- --test-threads="$(${nproc_bin})" + fi fi echo "Test ran in $SECONDS seconds" diff --git a/tfhe/examples/generates_test_keys.rs b/tfhe/examples/generates_test_keys.rs index 6a4dcb69c..6b4d23332 100644 --- a/tfhe/examples/generates_test_keys.rs +++ b/tfhe/examples/generates_test_keys.rs @@ -1,3 +1,4 @@ +use clap::{Arg, ArgAction, Command}; use tfhe::shortint::keycache::{NamedParam, KEY_CACHE, KEY_CACHE_WOPBS}; use tfhe::shortint::parameters::parameters_wopbs_message_carry::{ WOPBS_PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_3_CARRY_3, @@ -10,75 +11,90 @@ use tfhe::shortint::parameters::{ }; fn client_server_keys() { - println!("Generating shortint (ClientKey, ServerKey)"); - for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() { - println!( - "Generating [{} / {}] : {}", - i + 1, - ALL_PARAMETER_VEC.len(), - params.name() - ); + let matches = Command::new("test key gen") + .arg( + Arg::new("multi_bit_only") + .long("multi-bit-only") + .help("Set to generate only multi bit keys, otherwise only PBS and WoPBS keys are generated") + .action(ArgAction::SetTrue), + ) + .get_matches(); - let start = std::time::Instant::now(); + // If set using the command line flag "--ladner-fischer" this algorithm will be used in + // additions + let multi_bit_only: bool = matches.get_flag("multi_bit_only"); - let _ = KEY_CACHE.get_from_param(params); + if multi_bit_only { + println!("Generating shortint multibit (ClientKey, ServerKey)"); + for (i, params) in ALL_MULTI_BIT_PARAMETER_VEC.iter().copied().enumerate() { + println!( + "Generating [{} / {}] : {}", + i + 1, + ALL_MULTI_BIT_PARAMETER_VEC.len(), + params.name() + ); - let stop = start.elapsed().as_secs(); + let start = std::time::Instant::now(); - println!("Generation took {stop} seconds"); + let _ = KEY_CACHE.get_from_param(params); - // Clear keys as we go to avoid filling the RAM - KEY_CACHE.clear_in_memory_cache() - } + let stop = start.elapsed().as_secs(); - println!("Generating shortint multibit (ClientKey, ServerKey)"); - for (i, params) in ALL_MULTI_BIT_PARAMETER_VEC.iter().copied().enumerate() { - println!( - "Generating [{} / {}] : {}", - i + 1, - ALL_MULTI_BIT_PARAMETER_VEC.len(), - params.name() - ); + println!("Generation took {stop} seconds"); - let start = std::time::Instant::now(); + // Clear keys as we go to avoid filling the RAM + KEY_CACHE.clear_in_memory_cache() + } + } else { + println!("Generating shortint (ClientKey, ServerKey)"); + for (i, params) in ALL_PARAMETER_VEC.iter().copied().enumerate() { + println!( + "Generating [{} / {}] : {}", + i + 1, + ALL_PARAMETER_VEC.len(), + params.name() + ); - let _ = KEY_CACHE.get_from_param(params); + let start = std::time::Instant::now(); - let stop = start.elapsed().as_secs(); + let _ = KEY_CACHE.get_from_param(params); - println!("Generation took {stop} seconds"); + let stop = start.elapsed().as_secs(); - // Clear keys as we go to avoid filling the RAM - KEY_CACHE.clear_in_memory_cache() - } + println!("Generation took {stop} seconds"); - const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [ - (PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1), - (PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2), - (PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3), - (PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4), - ]; + // Clear keys as we go to avoid filling the RAM + KEY_CACHE.clear_in_memory_cache() + } - println!("Generating woPBS keys"); - for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() { - println!( - "Generating [{} / {}] : {}, {}", - i + 1, - WOPBS_PARAMS.len(), - params_shortint.name(), - params_wopbs.name(), - ); + const WOPBS_PARAMS: [(ClassicPBSParameters, WopbsParameters); 4] = [ + (PARAM_MESSAGE_1_CARRY_1, WOPBS_PARAM_MESSAGE_1_CARRY_1), + (PARAM_MESSAGE_2_CARRY_2, WOPBS_PARAM_MESSAGE_2_CARRY_2), + (PARAM_MESSAGE_3_CARRY_3, WOPBS_PARAM_MESSAGE_3_CARRY_3), + (PARAM_MESSAGE_4_CARRY_4, WOPBS_PARAM_MESSAGE_4_CARRY_4), + ]; - let start = std::time::Instant::now(); + println!("Generating woPBS keys"); + for (i, (params_shortint, params_wopbs)) in WOPBS_PARAMS.iter().copied().enumerate() { + println!( + "Generating [{} / {}] : {}, {}", + i + 1, + WOPBS_PARAMS.len(), + params_shortint.name(), + params_wopbs.name(), + ); - let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs)); + let start = std::time::Instant::now(); - let stop = start.elapsed().as_secs(); + let _ = KEY_CACHE_WOPBS.get_from_param((params_shortint, params_wopbs)); - println!("Generation took {stop} seconds"); + let stop = start.elapsed().as_secs(); - // Clear keys as we go to avoid filling the RAM - KEY_CACHE_WOPBS.clear_in_memory_cache() + println!("Generation took {stop} seconds"); + + // Clear keys as we go to avoid filling the RAM + KEY_CACHE_WOPBS.clear_in_memory_cache() + } } } diff --git a/tfhe/src/integer/server_key/radix_parallel/tests.rs b/tfhe/src/integer/server_key/radix_parallel/tests.rs index 4621209df..d5bcc4d3d 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests.rs @@ -1,7 +1,6 @@ use crate::integer::keycache::KEY_CACHE; use crate::integer::{RadixClientKey, ServerKey}; use crate::shortint::parameters::*; -use crate::shortint::ClassicPBSParameters; use paste::paste; use rand::Rng; @@ -27,6 +26,27 @@ macro_rules! create_parametrized_test{ ($name:ident)=> { create_parametrized_test!($name { + PARAM_MESSAGE_1_CARRY_1, + PARAM_MESSAGE_2_CARRY_2, + PARAM_MESSAGE_3_CARRY_3, + PARAM_MESSAGE_4_CARRY_4, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2, + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3 + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_3 + }); + }; +} + +macro_rules! create_parametrized_test_no_multi_bit { + ($name:ident) => { + create_parametrized_test!($name { PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3, @@ -38,14 +58,14 @@ macro_rules! create_parametrized_test{ create_parametrized_test!(integer_smart_add); create_parametrized_test!(integer_smart_add_sequence_multi_thread); create_parametrized_test!(integer_smart_add_sequence_single_thread); -create_parametrized_test!(integer_default_add); +create_parametrized_test_no_multi_bit!(integer_default_add); create_parametrized_test!(integer_default_add_work_efficient { // This algorithm requires 3 bits PARAM_MESSAGE_2_CARRY_2, PARAM_MESSAGE_3_CARRY_3, PARAM_MESSAGE_4_CARRY_4 }); -create_parametrized_test!(integer_default_add_sequence_multi_thread); +create_parametrized_test_no_multi_bit!(integer_default_add_sequence_multi_thread); // Other tests are pretty slow, and the code is the same as a smart add but slower #[test] fn test_integer_default_add_sequence_single_thread_param_message_2_carry_2() { @@ -54,12 +74,12 @@ fn test_integer_default_add_sequence_single_thread_param_message_2_carry_2() { create_parametrized_test!(integer_smart_bitand); create_parametrized_test!(integer_smart_bitor); create_parametrized_test!(integer_smart_bitxor); -create_parametrized_test!(integer_default_bitand); -create_parametrized_test!(integer_default_bitor); -create_parametrized_test!(integer_default_bitxor); +create_parametrized_test_no_multi_bit!(integer_default_bitand); +create_parametrized_test_no_multi_bit!(integer_default_bitor); +create_parametrized_test_no_multi_bit!(integer_default_bitxor); create_parametrized_test!(integer_unchecked_small_scalar_mul); create_parametrized_test!(integer_smart_small_scalar_mul); -create_parametrized_test!(integer_default_small_scalar_mul); +create_parametrized_test_no_multi_bit!(integer_default_small_scalar_mul); create_parametrized_test!(integer_smart_scalar_mul_u128_fix_non_reg_test { PARAM_MESSAGE_1_CARRY_1, PARAM_MESSAGE_2_CARRY_2 @@ -69,12 +89,12 @@ create_parametrized_test!(integer_default_scalar_mul_u128_fix_non_reg_test { PARAM_MESSAGE_2_CARRY_2 }); create_parametrized_test!(integer_smart_scalar_mul); -create_parametrized_test!(integer_default_scalar_mul); +create_parametrized_test_no_multi_bit!(integer_default_scalar_mul); // scalar left/right shifts create_parametrized_test!(integer_unchecked_scalar_left_shift); -create_parametrized_test!(integer_default_scalar_left_shift); +create_parametrized_test_no_multi_bit!(integer_default_scalar_left_shift); create_parametrized_test!(integer_unchecked_scalar_right_shift); -create_parametrized_test!(integer_default_scalar_right_shift); +create_parametrized_test_no_multi_bit!(integer_default_scalar_right_shift); // left/right shifts create_parametrized_test!(integer_unchecked_left_shift { // This algorithm requires 3 bits @@ -104,13 +124,13 @@ create_parametrized_test!(integer_unchecked_rotate_right { // left/right rotations create_parametrized_test!(integer_unchecked_scalar_rotate_right); create_parametrized_test!(integer_unchecked_scalar_rotate_left); -create_parametrized_test!(integer_scalar_rotate_right); -create_parametrized_test!(integer_scalar_rotate_left); +create_parametrized_test_no_multi_bit!(integer_default_scalar_rotate_right); +create_parametrized_test_no_multi_bit!(integer_default_scalar_rotate_left); // negations create_parametrized_test!(integer_smart_neg); -create_parametrized_test!(integer_default_neg); +create_parametrized_test_no_multi_bit!(integer_default_neg); create_parametrized_test!(integer_smart_sub); -create_parametrized_test!(integer_default_sub); +create_parametrized_test_no_multi_bit!(integer_default_sub); create_parametrized_test!(integer_default_sub_work_efficient { // This algorithm requires 3 bits PARAM_MESSAGE_2_CARRY_2, @@ -119,19 +139,18 @@ create_parametrized_test!(integer_default_sub_work_efficient { }); create_parametrized_test!(integer_unchecked_block_mul); create_parametrized_test!(integer_smart_block_mul); -create_parametrized_test!(integer_default_block_mul); +create_parametrized_test_no_multi_bit!(integer_default_block_mul); create_parametrized_test!(integer_smart_mul); -#[test] -fn test_integer_smart_mul_param_multi_bit_message_2_carry_2_group_2() { - integer_smart_mul(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2) -} -create_parametrized_test!(integer_default_mul); +create_parametrized_test_no_multi_bit!(integer_default_mul); create_parametrized_test!(integer_smart_scalar_sub); -create_parametrized_test!(integer_default_scalar_sub); +create_parametrized_test_no_multi_bit!(integer_default_scalar_sub); create_parametrized_test!(integer_smart_scalar_add); -create_parametrized_test!(integer_default_scalar_add); +create_parametrized_test_no_multi_bit!(integer_default_scalar_add); -fn integer_smart_add(param: ClassicPBSParameters) { +fn integer_smart_add

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -139,7 +158,7 @@ fn integer_smart_add(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -174,7 +193,10 @@ fn integer_smart_add(param: ClassicPBSParameters) { } } -fn integer_smart_add_sequence_multi_thread(param: ClassicPBSParameters) { +fn integer_smart_add_sequence_multi_thread

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -182,7 +204,7 @@ fn integer_smart_add_sequence_multi_thread(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for len in [1, 2, 15, 16, 17, 64, 65] { for _ in 0..NB_TEST_SMALLER { @@ -208,7 +230,10 @@ fn integer_smart_add_sequence_multi_thread(param: ClassicPBSParameters) { } } -fn integer_smart_add_sequence_single_thread(param: ClassicPBSParameters) { +fn integer_smart_add_sequence_single_thread

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -216,7 +241,7 @@ fn integer_smart_add_sequence_single_thread(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for len in [1, 2, 15, 16, 17] { for _ in 0..NB_TEST_SMALLER { @@ -248,7 +273,10 @@ fn integer_smart_add_sequence_single_thread(param: ClassicPBSParameters) { } } -fn integer_default_add(param: ClassicPBSParameters) { +fn integer_default_add

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -256,7 +284,7 @@ fn integer_default_add(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -296,7 +324,10 @@ fn integer_default_add(param: ClassicPBSParameters) { } // Smaller test for this one -fn integer_default_add_work_efficient(param: ClassicPBSParameters) { +fn integer_default_add_work_efficient

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -304,7 +335,7 @@ fn integer_default_add_work_efficient(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { let clear_0 = rng.gen::() % modulus; @@ -326,7 +357,10 @@ fn integer_default_add_work_efficient(param: ClassicPBSParameters) { } } -fn integer_default_add_sequence_multi_thread(param: ClassicPBSParameters) { +fn integer_default_add_sequence_multi_thread

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -334,7 +368,7 @@ fn integer_default_add_sequence_multi_thread(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for len in [1, 2, 15, 16, 17, 64, 65] { for _ in 0..NB_TEST_SMALLER { @@ -365,7 +399,10 @@ fn integer_default_add_sequence_multi_thread(param: ClassicPBSParameters) { } } -fn integer_default_add_sequence_single_thread(param: ClassicPBSParameters) { +fn integer_default_add_sequence_single_thread

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -373,7 +410,7 @@ fn integer_default_add_sequence_single_thread(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for len in [1, 2, 15, 16, 17] { for _ in 0..NB_TEST_SMALLER { @@ -406,7 +443,10 @@ fn integer_default_add_sequence_single_thread(param: ClassicPBSParameters) { } } -fn integer_smart_bitand(param: ClassicPBSParameters) { +fn integer_smart_bitand

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -414,7 +454,7 @@ fn integer_smart_bitand(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -451,7 +491,10 @@ fn integer_smart_bitand(param: ClassicPBSParameters) { } } -fn integer_smart_bitor(param: ClassicPBSParameters) { +fn integer_smart_bitor

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -459,7 +502,7 @@ fn integer_smart_bitor(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -496,7 +539,10 @@ fn integer_smart_bitor(param: ClassicPBSParameters) { } } -fn integer_smart_bitxor(param: ClassicPBSParameters) { +fn integer_smart_bitxor

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -504,7 +550,7 @@ fn integer_smart_bitxor(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -541,7 +587,10 @@ fn integer_smart_bitxor(param: ClassicPBSParameters) { } } -fn integer_default_bitand(param: ClassicPBSParameters) { +fn integer_default_bitand

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -549,7 +598,7 @@ fn integer_default_bitand(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -590,7 +639,10 @@ fn integer_default_bitand(param: ClassicPBSParameters) { } } -fn integer_default_bitor(param: ClassicPBSParameters) { +fn integer_default_bitor

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -598,7 +650,7 @@ fn integer_default_bitor(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -639,7 +691,10 @@ fn integer_default_bitor(param: ClassicPBSParameters) { } } -fn integer_default_bitxor(param: ClassicPBSParameters) { +fn integer_default_bitxor

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -647,7 +702,7 @@ fn integer_default_bitxor(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -688,7 +743,10 @@ fn integer_default_bitxor(param: ClassicPBSParameters) { } } -fn integer_unchecked_small_scalar_mul(param: ClassicPBSParameters) { +fn integer_unchecked_small_scalar_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -696,9 +754,9 @@ fn integer_unchecked_small_scalar_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let scalar_modulus = param.message_modulus.0 as u64; + let scalar_modulus = cks.parameters().message_modulus().0 as u64; for _ in 0..NB_TEST { let clear = rng.gen::() % modulus; @@ -718,7 +776,10 @@ fn integer_unchecked_small_scalar_mul(param: ClassicPBSParameters) { } } -fn integer_smart_small_scalar_mul(param: ClassicPBSParameters) { +fn integer_smart_small_scalar_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -726,9 +787,9 @@ fn integer_smart_small_scalar_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let scalar_modulus = param.message_modulus.0 as u64; + let scalar_modulus = cks.parameters().message_modulus().0 as u64; let mut clear_res; for _ in 0..NB_TEST_SMALLER { @@ -756,7 +817,10 @@ fn integer_smart_small_scalar_mul(param: ClassicPBSParameters) { } } -fn integer_default_small_scalar_mul(param: ClassicPBSParameters) { +fn integer_default_small_scalar_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -764,9 +828,9 @@ fn integer_default_small_scalar_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let scalar_modulus = param.message_modulus.0 as u64; + let scalar_modulus = cks.parameters().message_modulus().0 as u64; let mut clear_res; for _ in 0..NB_TEST_SMALLER { @@ -798,7 +862,10 @@ fn integer_default_small_scalar_mul(param: ClassicPBSParameters) { } } -fn integer_smart_scalar_mul(param: ClassicPBSParameters) { +fn integer_smart_scalar_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -806,7 +873,7 @@ fn integer_smart_scalar_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST { let clear = rng.gen::() % modulus; @@ -827,7 +894,10 @@ fn integer_smart_scalar_mul(param: ClassicPBSParameters) { } } -fn integer_default_scalar_mul(param: ClassicPBSParameters) { +fn integer_default_scalar_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -835,7 +905,7 @@ fn integer_default_scalar_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST { let clear = rng.gen::() % modulus; @@ -859,14 +929,18 @@ fn integer_default_scalar_mul(param: ClassicPBSParameters) { } } -fn integer_unchecked_mul_corner_cases(param: ClassicPBSParameters) { +fn integer_unchecked_mul_corner_cases

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); // This example will not pass if the terms reduction is wrong // on the chunk size it uses to reduce the 'terms' resulting // from blockmuls { - let nb_ct = (128f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let clear = 307096569525960547621731375222677666984u128; let scalar = 5207034748027904122u64; @@ -875,7 +949,8 @@ fn integer_unchecked_mul_corner_cases(param: ClassicPBSParameters) { let dec_res: u128 = cks.decrypt_radix(&ct_res); assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); - let nb_ct = (128f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let clear = 307096569525960547621731375222677666984u128; let scalar = 5207034748027904122u64; @@ -888,7 +963,8 @@ fn integer_unchecked_mul_corner_cases(param: ClassicPBSParameters) { } { - let nb_ct = (128f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let clear = u128::MAX; let scalar = u64::MAX; @@ -910,7 +986,8 @@ fn integer_unchecked_mul_corner_cases(param: ClassicPBSParameters) { // Trying to multiply a ciphertext with a scalar value // bigger than the ciphertext modulus should work { - let nb_ct = (8f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (8f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let clear = 123u64; let scalar = 17823812983255694336u64; assert_eq!(scalar % 256, 0); @@ -922,9 +999,13 @@ fn integer_unchecked_mul_corner_cases(param: ClassicPBSParameters) { } } -fn integer_smart_scalar_mul_u128_fix_non_reg_test(param: ClassicPBSParameters) { +fn integer_smart_scalar_mul_u128_fix_non_reg_test

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); - let nb_ct = (128f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); //RNG @@ -947,9 +1028,13 @@ fn integer_smart_scalar_mul_u128_fix_non_reg_test(param: ClassicPBSParameters) { assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); } -fn integer_default_scalar_mul_u128_fix_non_reg_test(param: ClassicPBSParameters) { +fn integer_default_scalar_mul_u128_fix_non_reg_test

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); - let nb_ct = (128f64 / (param.message_modulus.0 as f64).log2().ceil()).ceil() as usize; + let nb_ct = + (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); //RNG @@ -972,14 +1057,17 @@ fn integer_default_scalar_mul_u128_fix_non_reg_test(param: ClassicPBSParameters) assert_eq!(clear.wrapping_mul(scalar as u128), dec_res); } -fn integer_unchecked_scalar_left_shift(param: ClassicPBSParameters) { +fn integer_unchecked_scalar_left_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); for _ in 0..NB_TEST { @@ -1008,7 +1096,7 @@ fn integer_unchecked_scalar_left_shift(param: ClassicPBSParameters) { let clear = rng.gen::() % modulus; let ct = cks.encrypt(clear); - let nb_bits_in_block = param.message_modulus.0.ilog2(); + let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2(); for scalar in 0..nb_bits_in_block { let ct_res = sks.unchecked_scalar_left_shift_parallelized(&ct, scalar as u64); let dec_res: u64 = cks.decrypt(&ct_res); @@ -1016,14 +1104,17 @@ fn integer_unchecked_scalar_left_shift(param: ClassicPBSParameters) { } } -fn integer_unchecked_left_shift(param: PBSParameters) { +fn integer_unchecked_left_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; assert!(modulus.is_power_of_two()); let nb_bits = modulus.ilog2(); @@ -1061,14 +1152,17 @@ fn integer_unchecked_left_shift(param: PBSParameters) { } } -fn integer_unchecked_right_shift(param: PBSParameters) { +fn integer_unchecked_right_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; assert!(modulus.is_power_of_two()); let nb_bits = modulus.ilog2(); @@ -1107,14 +1201,17 @@ fn integer_unchecked_right_shift(param: PBSParameters) { } } -fn integer_unchecked_rotate_left(param: PBSParameters) { +fn integer_unchecked_rotate_left

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; assert!(modulus.is_power_of_two()); let nb_bits = modulus.ilog2(); @@ -1153,14 +1250,17 @@ fn integer_unchecked_rotate_left(param: PBSParameters) { } } -fn integer_unchecked_rotate_right(param: PBSParameters) { +fn integer_unchecked_rotate_right

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; assert!(modulus.is_power_of_two()); let nb_bits = modulus.ilog2(); @@ -1199,13 +1299,16 @@ fn integer_unchecked_rotate_right(param: PBSParameters) { } } -fn integer_default_scalar_left_shift(param: ClassicPBSParameters) { +fn integer_default_scalar_left_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); for _ in 0..NB_TEST { @@ -1240,7 +1343,7 @@ fn integer_default_scalar_left_shift(param: ClassicPBSParameters) { let clear = rng.gen::() % modulus; let ct = cks.encrypt(clear); - let nb_bits_in_block = param.message_modulus.0.ilog2(); + let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2(); for scalar in 0..nb_bits_in_block { let ct_res = sks.scalar_left_shift_parallelized(&ct, scalar as u64); let dec_res: u64 = cks.decrypt(&ct_res); @@ -1248,14 +1351,17 @@ fn integer_default_scalar_left_shift(param: ClassicPBSParameters) { } } -fn integer_unchecked_scalar_right_shift(param: ClassicPBSParameters) { +fn integer_unchecked_scalar_right_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); for _ in 0..NB_TEST { @@ -1268,9 +1374,7 @@ fn integer_unchecked_scalar_right_shift(param: ClassicPBSParameters) { { let scalar = scalar % nb_bits; let ct_res = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); assert_eq!(clear.checked_shr(scalar).unwrap_or(0) % modulus, dec_res); } @@ -1279,9 +1383,7 @@ fn integer_unchecked_scalar_right_shift(param: ClassicPBSParameters) { { let scalar = scalar.saturating_add(nb_bits); let ct_res = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); assert_eq!(clear.wrapping_shr(scalar % nb_bits) % modulus, dec_res); } @@ -1290,25 +1392,26 @@ fn integer_unchecked_scalar_right_shift(param: ClassicPBSParameters) { let clear = rng.gen::() % modulus; let ct = cks.encrypt(clear); - let nb_bits_in_block = param.message_modulus.0.ilog2(); + let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2(); for scalar in 0..nb_bits_in_block { let ct_res = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_right_shift_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); assert_eq!(clear.checked_shr(scalar).unwrap_or(0) % modulus, dec_res); } } -fn integer_default_scalar_right_shift(param: ClassicPBSParameters) { +fn integer_default_scalar_right_shift

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); for _ in 0..NB_TEST { @@ -1343,7 +1446,7 @@ fn integer_default_scalar_right_shift(param: ClassicPBSParameters) { let clear = rng.gen::() % modulus; let ct = cks.encrypt(clear); - let nb_bits_in_block = param.message_modulus.0.ilog2(); + let nb_bits_in_block = cks.parameters().message_modulus().0.ilog2(); for scalar in 0..nb_bits_in_block { let ct_res = sks.scalar_right_shift_parallelized(&ct, scalar as u64); let tmp = sks.scalar_right_shift_parallelized(&ct, scalar as u64); @@ -1399,16 +1502,19 @@ fn rotate_right_helper(value: u64, n: u32, actual_bit_size: u32) -> u64 { (rotated & mask) | ((rotated & shifted_mask) >> (u64::BITS - actual_bit_size)) } -fn integer_unchecked_scalar_rotate_right(param: ClassicPBSParameters) { +fn integer_unchecked_scalar_rotate_right

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); - let bits_per_block = param.message_modulus.0.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); for _ in 0..(NB_TEST / 3).max(1) { let clear = rng.gen::() % modulus; @@ -1420,9 +1526,7 @@ fn integer_unchecked_scalar_rotate_right(param: ClassicPBSParameters) { { let scalar = scalar - (scalar % bits_per_block); let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_right_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1437,9 +1541,7 @@ fn integer_unchecked_scalar_rotate_right(param: ClassicPBSParameters) { scalar }; let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_right_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1452,9 +1554,7 @@ fn integer_unchecked_scalar_rotate_right(param: ClassicPBSParameters) { let value = rng.gen_range(1..=u32::MAX); let scalar = value.trailing_zeros() + rng.gen_range(1..nb_bits); let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_right_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1462,16 +1562,19 @@ fn integer_unchecked_scalar_rotate_right(param: ClassicPBSParameters) { } } -fn integer_unchecked_scalar_rotate_left(param: ClassicPBSParameters) { +fn integer_unchecked_scalar_rotate_left

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); - let bits_per_block = param.message_modulus.0.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); for _ in 0..(NB_TEST / 3).max(1) { let clear = rng.gen::() % modulus; @@ -1483,9 +1586,7 @@ fn integer_unchecked_scalar_rotate_left(param: ClassicPBSParameters) { { let scalar = scalar - (scalar % bits_per_block); let ct_res = sks.unchecked_scalar_rotate_left_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_left_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_left_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1500,9 +1601,7 @@ fn integer_unchecked_scalar_rotate_left(param: ClassicPBSParameters) { scalar }; let ct_res = sks.unchecked_scalar_rotate_left_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_left_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_left_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1515,9 +1614,7 @@ fn integer_unchecked_scalar_rotate_left(param: ClassicPBSParameters) { let value = rng.gen_range(1..=u32::MAX); let scalar = value.leading_zeros() + rng.gen_range(1..nb_bits); let ct_res = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); - let tmp = sks.unchecked_scalar_rotate_right_parallelized(&ct, scalar as u64); assert!(ct_res.block_carries_are_empty()); - assert_eq!(ct_res, tmp); let dec_res: u64 = cks.decrypt(&ct_res); let expected = rotate_right_helper(clear, scalar, nb_bits); assert_eq!(expected, dec_res); @@ -1525,16 +1622,19 @@ fn integer_unchecked_scalar_rotate_left(param: ClassicPBSParameters) { } } -fn integer_scalar_rotate_right(param: ClassicPBSParameters) { +fn integer_default_scalar_rotate_right

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); - let bits_per_block = param.message_modulus.0.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); for _ in 0..(NB_TEST / 2).max(1) { let clear = rng.gen::() % modulus; @@ -1588,16 +1688,19 @@ fn integer_scalar_rotate_right(param: ClassicPBSParameters) { } } -fn integer_scalar_rotate_left(param: ClassicPBSParameters) { +fn integer_default_scalar_rotate_left

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let nb_bits = modulus.ilog2(); - let bits_per_block = param.message_modulus.0.ilog2(); + let bits_per_block = cks.parameters().message_modulus().0.ilog2(); for _ in 0..(NB_TEST / 3).max(1) { let clear = rng.gen::() % modulus; @@ -1651,7 +1754,10 @@ fn integer_scalar_rotate_left(param: ClassicPBSParameters) { } } -fn integer_smart_neg(param: ClassicPBSParameters) { +fn integer_smart_neg

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1659,7 +1765,7 @@ fn integer_smart_neg(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST { // Define the cleartexts @@ -1681,7 +1787,10 @@ fn integer_smart_neg(param: ClassicPBSParameters) { } } -fn integer_default_neg(param: ClassicPBSParameters) { +fn integer_default_neg

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1689,7 +1798,7 @@ fn integer_default_neg(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST { // Define the cleartexts @@ -1714,7 +1823,10 @@ fn integer_default_neg(param: ClassicPBSParameters) { } } -fn integer_smart_sub(param: ClassicPBSParameters) { +fn integer_smart_sub

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1722,7 +1834,7 @@ fn integer_smart_sub(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { // Define the cleartexts @@ -1749,7 +1861,10 @@ fn integer_smart_sub(param: ClassicPBSParameters) { } } -fn integer_default_sub(param: ClassicPBSParameters) { +fn integer_default_sub

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1757,7 +1872,7 @@ fn integer_default_sub(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { // Define the cleartexts @@ -1787,7 +1902,10 @@ fn integer_default_sub(param: ClassicPBSParameters) { } } -fn integer_default_sub_work_efficient(param: ClassicPBSParameters) { +fn integer_default_sub_work_efficient

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1795,7 +1913,7 @@ fn integer_default_sub_work_efficient(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { // Define the cleartexts @@ -1819,7 +1937,10 @@ fn integer_default_sub_work_efficient(param: ClassicPBSParameters) { } } -fn integer_unchecked_block_mul(param: ClassicPBSParameters) { +fn integer_unchecked_block_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1827,9 +1948,9 @@ fn integer_unchecked_block_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let block_modulus = param.message_modulus.0 as u64; + let block_modulus = cks.parameters().message_modulus().0 as u64; for _ in 0..NB_TEST { let clear_0 = rng.gen::() % modulus; @@ -1852,7 +1973,10 @@ fn integer_unchecked_block_mul(param: ClassicPBSParameters) { } } -fn integer_smart_block_mul(param: ClassicPBSParameters) { +fn integer_smart_block_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1860,9 +1984,9 @@ fn integer_smart_block_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let block_modulus = param.message_modulus.0 as u64; + let block_modulus = cks.parameters().message_modulus().0 as u64; for _ in 0..5 { // Define the cleartexts @@ -1890,7 +2014,10 @@ fn integer_smart_block_mul(param: ClassicPBSParameters) { } } -fn integer_default_block_mul(param: ClassicPBSParameters) { +fn integer_default_block_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1898,9 +2025,9 @@ fn integer_default_block_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; - let block_modulus = param.message_modulus.0 as u64; + let block_modulus = cks.parameters().message_modulus().0 as u64; for _ in 0..5 { // Define the cleartexts @@ -1944,7 +2071,7 @@ where let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { // Define the cleartexts @@ -1974,7 +2101,10 @@ where } } -fn integer_default_mul(param: ClassicPBSParameters) { +fn integer_default_mul

(param: P) +where + P: Into, +{ let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); @@ -1982,7 +2112,7 @@ fn integer_default_mul(param: ClassicPBSParameters) { let mut rng = rand::thread_rng(); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; for _ in 0..NB_TEST_SMALLER { // Define the cleartexts @@ -2016,13 +2146,16 @@ fn integer_default_mul(param: ClassicPBSParameters) { } } -fn integer_smart_scalar_add(param: ClassicPBSParameters) { +fn integer_smart_scalar_add

(param: P) +where + P: Into, +{ // generate the server-client key set let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -2057,13 +2190,16 @@ fn integer_smart_scalar_add(param: ClassicPBSParameters) { } } -fn integer_default_scalar_add(param: ClassicPBSParameters) { +fn integer_default_scalar_add

(param: P) +where + P: Into, +{ // generate the server-client key set let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -2102,13 +2238,16 @@ fn integer_default_scalar_add(param: ClassicPBSParameters) { } } -fn integer_smart_scalar_sub(param: ClassicPBSParameters) { +fn integer_smart_scalar_sub

(param: P) +where + P: Into, +{ // generate the server-client key set let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; @@ -2143,13 +2282,16 @@ fn integer_smart_scalar_sub(param: ClassicPBSParameters) { } } -fn integer_default_scalar_sub(param: ClassicPBSParameters) { +fn integer_default_scalar_sub

(param: P) +where + P: Into, +{ // generate the server-client key set let (cks, sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length - let modulus = param.message_modulus.0.pow(NB_CTXT as u32) as u64; + let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; let mut clear; diff --git a/tfhe/src/shortint/server_key/tests.rs b/tfhe/src/shortint/server_key/tests.rs index f5161fead..e075b81b7 100644 --- a/tfhe/src/shortint/server_key/tests.rs +++ b/tfhe/src/shortint/server_key/tests.rs @@ -49,7 +49,17 @@ macro_rules! create_parametrized_test{ PARAM_MESSAGE_5_CARRY_3, PARAM_MESSAGE_6_CARRY_1, PARAM_MESSAGE_6_CARRY_2, - PARAM_MESSAGE_7_CARRY_1 + PARAM_MESSAGE_7_CARRY_1, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2, + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3 + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_3 }); }; } @@ -84,7 +94,17 @@ macro_rules! create_parametrized_test_bivariate_pbs_compliant{ PARAM_MESSAGE_3_CARRY_3, PARAM_MESSAGE_3_CARRY_4, PARAM_MESSAGE_3_CARRY_5, - PARAM_MESSAGE_4_CARRY_4 + PARAM_MESSAGE_4_CARRY_4, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2, + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_2, + PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3, + PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3 + // // These parameters seem to introduce too much noise during computation + // PARAM_MULTI_BIT_MESSAGE_4_CARRY_4_GROUP_3 }); }; } @@ -125,11 +145,6 @@ create_parametrized_test!(shortint_default_sub); create_parametrized_test!(shortint_mul_small_carry); create_parametrized_test!(shortint_mux); -#[test] -fn test_shortint_mux_param_multi_bit_message_2_carry_2_group_2() { - shortint_mux(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2) -} - // Public key tests are limited to small parameter sets to avoid blowing up memory and large testing // times. Compressed keygen takes 20 minutes for params 2_2 and for encryption as well. // 2_2 uncompressed keys take ~2 GB and 3_3 about ~34 GB, hence why we stop at 2_2. @@ -195,7 +210,10 @@ create_parametrized_test_bivariate_pbs_compliant!( create_parametrized_test_bivariate_pbs_compliant!(shortint_unchecked_less_or_equal_trivial); /// test encryption and decryption with the LWE client key -fn shortint_encrypt_decrypt(param: ClassicPBSParameters) { +fn shortint_encrypt_decrypt

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let cks = keys.client_key(); @@ -217,7 +235,10 @@ fn shortint_encrypt_decrypt(param: ClassicPBSParameters) { } /// test encryption and decryption with the LWE client key -fn shortint_encrypt_with_message_modulus_decrypt(param: ClassicPBSParameters) { +fn shortint_encrypt_with_message_modulus_decrypt

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let cks = keys.client_key(); @@ -241,7 +262,10 @@ fn shortint_encrypt_with_message_modulus_decrypt(param: ClassicPBSParameters) { } } -fn shortint_encrypt_decrypt_without_padding(param: ClassicPBSParameters) { +fn shortint_encrypt_decrypt_without_padding

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let cks = keys.client_key(); @@ -263,7 +287,10 @@ fn shortint_encrypt_decrypt_without_padding(param: ClassicPBSParameters) { } } -fn shortint_keyswitch_bootstrap(param: ClassicPBSParameters) { +fn shortint_keyswitch_bootstrap

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); @@ -296,7 +323,10 @@ fn shortint_keyswitch_bootstrap(param: ClassicPBSParameters) { assert_eq!(0, failures); } -fn shortint_keyswitch_programmable_bootstrap(param: ClassicPBSParameters) { +fn shortint_keyswitch_programmable_bootstrap

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -323,7 +353,10 @@ fn shortint_keyswitch_programmable_bootstrap(param: ClassicPBSParameters) { } } -fn shortint_keyswitch_bivariate_programmable_bootstrap(param: ClassicPBSParameters) { +fn shortint_keyswitch_bivariate_programmable_bootstrap

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -352,7 +385,10 @@ fn shortint_keyswitch_bivariate_programmable_bootstrap(param: ClassicPBSParamete } /// test extraction of a carry -fn shortint_carry_extract(param: ClassicPBSParameters) { +fn shortint_carry_extract

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -385,15 +421,19 @@ fn shortint_carry_extract(param: ClassicPBSParameters) { } /// test extraction of a message -fn shortint_message_extract(param: ClassicPBSParameters) { +fn shortint_message_extract

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG let mut rng = rand::thread_rng(); - let modulus_sup = (param.message_modulus.0 * param.carry_modulus.0) as u64; + let modulus_sup = + (cks.parameters.message_modulus().0 * cks.parameters.carry_modulus().0) as u64; - let modulus = param.message_modulus.0 as u64; + let modulus = cks.parameters.message_modulus().0 as u64; for _ in 0..NB_TEST { let clear = rng.gen::() % modulus_sup; @@ -413,7 +453,10 @@ fn shortint_message_extract(param: ClassicPBSParameters) { } /// test multiplication with the LWE server key -fn shortint_generate_accumulator(param: ClassicPBSParameters) { +fn shortint_generate_accumulator

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let double = |x| 2 * x; @@ -441,7 +484,10 @@ fn shortint_generate_accumulator(param: ClassicPBSParameters) { } /// test addition with the LWE server key -fn shortint_unchecked_add(param: ClassicPBSParameters) { +fn shortint_unchecked_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -476,7 +522,10 @@ fn shortint_unchecked_add(param: ClassicPBSParameters) { } /// test addition with the LWE server key -fn shortint_smart_add(param: ClassicPBSParameters) { +fn shortint_smart_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); @@ -515,7 +564,10 @@ fn shortint_smart_add(param: ClassicPBSParameters) { } /// test default addition with the LWE server key -fn shortint_default_add(param: ClassicPBSParameters) { +fn shortint_default_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); @@ -554,7 +606,10 @@ fn shortint_default_add(param: ClassicPBSParameters) { } /// test addition with the LWE server key using the a public key for encryption -fn shortint_compressed_public_key_smart_add(param: ClassicPBSParameters) { +fn shortint_compressed_public_key_smart_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let pk = crate::shortint::CompressedPublicKeyBig::new(cks); @@ -594,7 +649,10 @@ fn shortint_compressed_public_key_smart_add(param: ClassicPBSParameters) { } /// test addition with the LWE server key using the a public key for encryption -fn shortint_public_key_smart_add(param: ClassicPBSParameters) { +fn shortint_public_key_smart_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let pk = crate::shortint::PublicKeyBig::new(cks); @@ -634,7 +692,10 @@ fn shortint_public_key_smart_add(param: ClassicPBSParameters) { } /// test bitwise 'and' with the LWE server key -fn shortint_unchecked_bitand(param: ClassicPBSParameters) { +fn shortint_unchecked_bitand

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -664,7 +725,10 @@ fn shortint_unchecked_bitand(param: ClassicPBSParameters) { } /// test bitwise 'or' with the LWE server key -fn shortint_unchecked_bitor(param: ClassicPBSParameters) { +fn shortint_unchecked_bitor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -694,7 +758,10 @@ fn shortint_unchecked_bitor(param: ClassicPBSParameters) { } /// test bitwise 'xor' with the LWE server key -fn shortint_unchecked_bitxor(param: ClassicPBSParameters) { +fn shortint_unchecked_bitxor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -724,7 +791,10 @@ fn shortint_unchecked_bitxor(param: ClassicPBSParameters) { } /// test bitwise 'and' with the LWE server key -fn shortint_smart_bitand(param: ClassicPBSParameters) { +fn shortint_smart_bitand

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -762,7 +832,10 @@ fn shortint_smart_bitand(param: ClassicPBSParameters) { } /// test default bitwise 'and' with the LWE server key -fn shortint_default_bitand(param: ClassicPBSParameters) { +fn shortint_default_bitand

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -800,7 +873,10 @@ fn shortint_default_bitand(param: ClassicPBSParameters) { } /// test bitwise 'or' with the LWE server key -fn shortint_smart_bitor(param: ClassicPBSParameters) { +fn shortint_smart_bitor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -838,7 +914,10 @@ fn shortint_smart_bitor(param: ClassicPBSParameters) { } /// test default bitwise 'or' with the LWE server key -fn shortint_default_bitor(param: ClassicPBSParameters) { +fn shortint_default_bitor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -876,7 +955,10 @@ fn shortint_default_bitor(param: ClassicPBSParameters) { } /// test bitwise 'xor' with the LWE server key -fn shortint_smart_bitxor(param: ClassicPBSParameters) { +fn shortint_smart_bitxor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -914,7 +996,10 @@ fn shortint_smart_bitxor(param: ClassicPBSParameters) { } /// test default bitwise 'xor' with the LWE server key -fn shortint_default_bitxor(param: ClassicPBSParameters) { +fn shortint_default_bitxor

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -952,7 +1037,10 @@ fn shortint_default_bitxor(param: ClassicPBSParameters) { } /// test '>' with the LWE server key -fn shortint_unchecked_greater(param: ClassicPBSParameters) { +fn shortint_unchecked_greater

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -982,7 +1070,10 @@ fn shortint_unchecked_greater(param: ClassicPBSParameters) { } /// test '>' with the LWE server key -fn shortint_smart_greater(param: ClassicPBSParameters) { +fn shortint_smart_greater

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1012,7 +1103,10 @@ fn shortint_smart_greater(param: ClassicPBSParameters) { } /// test default '>' with the LWE server key -fn shortint_default_greater(param: ClassicPBSParameters) { +fn shortint_default_greater

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1042,7 +1136,10 @@ fn shortint_default_greater(param: ClassicPBSParameters) { } /// test '>=' with the LWE server key -fn shortint_unchecked_greater_or_equal(param: ClassicPBSParameters) { +fn shortint_unchecked_greater_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1072,7 +1169,10 @@ fn shortint_unchecked_greater_or_equal(param: ClassicPBSParameters) { } /// test '>=' with the LWE server key -fn shortint_smart_greater_or_equal(param: ClassicPBSParameters) { +fn shortint_smart_greater_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1110,7 +1210,10 @@ fn shortint_smart_greater_or_equal(param: ClassicPBSParameters) { } /// test default '>=' with the LWE server key -fn shortint_default_greater_or_equal(param: ClassicPBSParameters) { +fn shortint_default_greater_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1148,7 +1251,10 @@ fn shortint_default_greater_or_equal(param: ClassicPBSParameters) { } /// test '<' with the LWE server key -fn shortint_unchecked_less(param: ClassicPBSParameters) { +fn shortint_unchecked_less

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1178,7 +1284,10 @@ fn shortint_unchecked_less(param: ClassicPBSParameters) { } /// test '<' with the LWE server key -fn shortint_smart_less(param: ClassicPBSParameters) { +fn shortint_smart_less

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1216,7 +1325,10 @@ fn shortint_smart_less(param: ClassicPBSParameters) { } /// test default '<' with the LWE server key -fn shortint_default_less(param: ClassicPBSParameters) { +fn shortint_default_less

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1254,7 +1366,10 @@ fn shortint_default_less(param: ClassicPBSParameters) { } /// test '<=' with the LWE server key -fn shortint_unchecked_less_or_equal(param: ClassicPBSParameters) { +fn shortint_unchecked_less_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1284,7 +1399,10 @@ fn shortint_unchecked_less_or_equal(param: ClassicPBSParameters) { } /// test '<=' with the LWE server key -fn shortint_unchecked_less_or_equal_trivial(param: ClassicPBSParameters) { +fn shortint_unchecked_less_or_equal_trivial

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1314,7 +1432,10 @@ fn shortint_unchecked_less_or_equal_trivial(param: ClassicPBSParameters) { } /// test '<=' with the LWE server key -fn shortint_smart_less_or_equal(param: ClassicPBSParameters) { +fn shortint_smart_less_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1352,7 +1473,10 @@ fn shortint_smart_less_or_equal(param: ClassicPBSParameters) { } /// test default '<=' with the LWE server key -fn shortint_default_less_or_equal(param: ClassicPBSParameters) { +fn shortint_default_less_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1389,7 +1513,10 @@ fn shortint_default_less_or_equal(param: ClassicPBSParameters) { } } -fn shortint_unchecked_equal(param: ClassicPBSParameters) { +fn shortint_unchecked_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1419,7 +1546,10 @@ fn shortint_unchecked_equal(param: ClassicPBSParameters) { } /// test '==' with the LWE server key -fn shortint_smart_equal(param: ClassicPBSParameters) { +fn shortint_smart_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1457,7 +1587,10 @@ fn shortint_smart_equal(param: ClassicPBSParameters) { } /// test default '==' with the LWE server key -fn shortint_default_equal(param: ClassicPBSParameters) { +fn shortint_default_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1495,7 +1628,10 @@ fn shortint_default_equal(param: ClassicPBSParameters) { } /// test '==' with the LWE server key -fn shortint_smart_scalar_equal(param: ClassicPBSParameters) { +fn shortint_smart_scalar_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1524,7 +1660,10 @@ fn shortint_smart_scalar_equal(param: ClassicPBSParameters) { } /// test '<' with the LWE server key -fn shortint_smart_scalar_less(param: ClassicPBSParameters) { +fn shortint_smart_scalar_less

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1553,7 +1692,10 @@ fn shortint_smart_scalar_less(param: ClassicPBSParameters) { } /// test '<=' with the LWE server key -fn shortint_smart_scalar_less_or_equal(param: ClassicPBSParameters) { +fn shortint_smart_scalar_less_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1582,7 +1724,10 @@ fn shortint_smart_scalar_less_or_equal(param: ClassicPBSParameters) { } /// test '>' with the LWE server key -fn shortint_smart_scalar_greater(param: ClassicPBSParameters) { +fn shortint_smart_scalar_greater

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1611,7 +1756,10 @@ fn shortint_smart_scalar_greater(param: ClassicPBSParameters) { } /// test '>' with the LWE server key -fn shortint_smart_scalar_greater_or_equal(param: ClassicPBSParameters) { +fn shortint_smart_scalar_greater_or_equal

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1640,7 +1788,10 @@ fn shortint_smart_scalar_greater_or_equal(param: ClassicPBSParameters) { } /// test division with the LWE server key -fn shortint_unchecked_div(param: ClassicPBSParameters) { +fn shortint_unchecked_div

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1670,7 +1821,10 @@ fn shortint_unchecked_div(param: ClassicPBSParameters) { } /// test scalar division with the LWE server key -fn shortint_unchecked_scalar_div(param: ClassicPBSParameters) { +fn shortint_unchecked_scalar_div

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1697,7 +1851,10 @@ fn shortint_unchecked_scalar_div(param: ClassicPBSParameters) { } /// test modulus with the LWE server key -fn shortint_unchecked_mod(param: ClassicPBSParameters) { +fn shortint_unchecked_mod

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1724,7 +1881,10 @@ fn shortint_unchecked_mod(param: ClassicPBSParameters) { } /// test LSB multiplication with the LWE server key -fn shortint_unchecked_mul_lsb(param: ClassicPBSParameters) { +fn shortint_unchecked_mul_lsb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1754,7 +1914,10 @@ fn shortint_unchecked_mul_lsb(param: ClassicPBSParameters) { } /// test MSB multiplication with the LWE server key -fn shortint_unchecked_mul_msb(param: ClassicPBSParameters) { +fn shortint_unchecked_mul_msb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1784,7 +1947,10 @@ fn shortint_unchecked_mul_msb(param: ClassicPBSParameters) { } /// test LSB multiplication with the LWE server key -fn shortint_smart_mul_lsb(param: ClassicPBSParameters) { +fn shortint_smart_mul_lsb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1823,7 +1989,10 @@ fn shortint_smart_mul_lsb(param: ClassicPBSParameters) { } /// test default LSB multiplication with the LWE server key -fn shortint_default_mul_lsb(param: ClassicPBSParameters) { +fn shortint_default_mul_lsb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1862,7 +2031,10 @@ fn shortint_default_mul_lsb(param: ClassicPBSParameters) { } /// test MSB multiplication with the LWE server key -fn shortint_smart_mul_msb(param: ClassicPBSParameters) { +fn shortint_smart_mul_msb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1905,7 +2077,10 @@ fn shortint_smart_mul_msb(param: ClassicPBSParameters) { } /// test default MSB multiplication with the LWE server key -fn shortint_default_mul_msb(param: ClassicPBSParameters) { +fn shortint_default_mul_msb

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1948,7 +2123,10 @@ fn shortint_default_mul_msb(param: ClassicPBSParameters) { } /// test unchecked negation -fn shortint_unchecked_neg(param: ClassicPBSParameters) { +fn shortint_unchecked_neg

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -1977,7 +2155,10 @@ fn shortint_unchecked_neg(param: ClassicPBSParameters) { } /// test smart negation -fn shortint_smart_neg(param: ClassicPBSParameters) { +fn shortint_smart_neg

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2010,7 +2191,10 @@ fn shortint_smart_neg(param: ClassicPBSParameters) { } /// test default negation -fn shortint_default_neg(param: ClassicPBSParameters) { +fn shortint_default_neg

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2043,13 +2227,16 @@ fn shortint_default_neg(param: ClassicPBSParameters) { } /// test scalar add -fn shortint_unchecked_scalar_add(param: ClassicPBSParameters) { +fn shortint_unchecked_scalar_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let mut rng = rand::thread_rng(); - let message_modulus = param.message_modulus.0 as u8; + let message_modulus = cks.parameters.message_modulus().0 as u8; for _ in 0..NB_TEST { let clear = rng.gen::() % message_modulus; @@ -2071,7 +2258,10 @@ fn shortint_unchecked_scalar_add(param: ClassicPBSParameters) { } /// test smart scalar add -fn shortint_smart_scalar_add(param: ClassicPBSParameters) { +fn shortint_smart_scalar_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2106,7 +2296,10 @@ fn shortint_smart_scalar_add(param: ClassicPBSParameters) { } /// test default smart scalar add -fn shortint_default_scalar_add(param: ClassicPBSParameters) { +fn shortint_default_scalar_add

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2141,13 +2334,16 @@ fn shortint_default_scalar_add(param: ClassicPBSParameters) { } /// test unchecked scalar sub -fn shortint_unchecked_scalar_sub(param: ClassicPBSParameters) { +fn shortint_unchecked_scalar_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let mut rng = rand::thread_rng(); - let message_modulus = param.message_modulus.0 as u8; + let message_modulus = cks.parameters.message_modulus().0 as u8; for _ in 0..NB_TEST { let clear = rng.gen::() % message_modulus; @@ -2168,7 +2364,10 @@ fn shortint_unchecked_scalar_sub(param: ClassicPBSParameters) { } } -fn shortint_smart_scalar_sub(param: ClassicPBSParameters) { +fn shortint_smart_scalar_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2208,7 +2407,10 @@ fn shortint_smart_scalar_sub(param: ClassicPBSParameters) { } } -fn shortint_default_scalar_sub(param: ClassicPBSParameters) { +fn shortint_default_scalar_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2249,14 +2451,17 @@ fn shortint_default_scalar_sub(param: ClassicPBSParameters) { } /// test scalar multiplication with the LWE server key -fn shortint_unchecked_scalar_mul(param: ClassicPBSParameters) { +fn shortint_unchecked_scalar_mul

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); let mut rng = rand::thread_rng(); - let message_modulus = param.message_modulus.0 as u8; - let carry_modulus = param.carry_modulus.0 as u8; + let message_modulus = cks.parameters.message_modulus().0 as u8; + let carry_modulus = cks.parameters.carry_modulus().0 as u8; for _ in 0..NB_TEST { let clear = rng.gen::() % message_modulus; @@ -2278,7 +2483,10 @@ fn shortint_unchecked_scalar_mul(param: ClassicPBSParameters) { } /// test default smart scalar multiplication with the LWE server key -fn shortint_smart_scalar_mul(param: ClassicPBSParameters) { +fn shortint_smart_scalar_mul

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2314,7 +2522,10 @@ fn shortint_smart_scalar_mul(param: ClassicPBSParameters) { } /// test default smart scalar multiplication with the LWE server key -fn shortint_default_scalar_mul(param: ClassicPBSParameters) { +fn shortint_default_scalar_mul

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2350,7 +2561,10 @@ fn shortint_default_scalar_mul(param: ClassicPBSParameters) { } /// test unchecked '>>' operation -fn shortint_unchecked_right_shift(param: ClassicPBSParameters) { +fn shortint_unchecked_right_shift

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2377,7 +2591,10 @@ fn shortint_unchecked_right_shift(param: ClassicPBSParameters) { } /// test default unchecked '>>' operation -fn shortint_default_right_shift(param: ClassicPBSParameters) { +fn shortint_default_right_shift

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2404,7 +2621,10 @@ fn shortint_default_right_shift(param: ClassicPBSParameters) { } /// test '<<' operation -fn shortint_unchecked_left_shift(param: ClassicPBSParameters) { +fn shortint_unchecked_left_shift

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2431,7 +2651,10 @@ fn shortint_unchecked_left_shift(param: ClassicPBSParameters) { } /// test default '<<' operation -fn shortint_default_left_shift(param: ClassicPBSParameters) { +fn shortint_default_left_shift

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2458,7 +2681,10 @@ fn shortint_default_left_shift(param: ClassicPBSParameters) { } /// test unchecked subtraction -fn shortint_unchecked_sub(param: ClassicPBSParameters) { +fn shortint_unchecked_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2486,7 +2712,10 @@ fn shortint_unchecked_sub(param: ClassicPBSParameters) { } } -fn shortint_smart_sub(param: ClassicPBSParameters) { +fn shortint_smart_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2517,7 +2746,10 @@ fn shortint_smart_sub(param: ClassicPBSParameters) { } } -fn shortint_default_sub(param: ClassicPBSParameters) { +fn shortint_default_sub

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2549,7 +2781,10 @@ fn shortint_default_sub(param: ClassicPBSParameters) { } /// test multiplication -fn shortint_mul_small_carry(param: ClassicPBSParameters) { +fn shortint_mul_small_carry

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key()); //RNG @@ -2580,7 +2815,10 @@ fn shortint_mul_small_carry(param: ClassicPBSParameters) { } /// test encryption and decryption with the LWE client key -fn shortint_encrypt_with_message_modulus_smart_add_and_mul(param: ClassicPBSParameters) { +fn shortint_encrypt_with_message_modulus_smart_add_and_mul

(param: P) +where + P: Into, +{ let keys = KEY_CACHE.get_from_param(param); let (cks, sks) = (keys.client_key(), keys.server_key());