diff --git a/scripts/test_filtering.py b/scripts/test_filtering.py index 4b2e5c924..1b47f744d 100644 --- a/scripts/test_filtering.py +++ b/scripts/test_filtering.py @@ -192,7 +192,7 @@ def filter_shortint_tests(input_args): msg_carry_pairs.append((4, 4)) filter_expression = [ - f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks_pbs.*/)" + f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)" for msg, carry in msg_carry_pairs ] filter_expression.append("test(/^shortint::.*_ci_run_filter/)") diff --git a/tfhe/src/shortint/atomic_pattern/mod.rs b/tfhe/src/shortint/atomic_pattern/mod.rs index 0d7e81a42..7f90ce2cd 100644 --- a/tfhe/src/shortint/atomic_pattern/mod.rs +++ b/tfhe/src/shortint/atomic_pattern/mod.rs @@ -551,3 +551,46 @@ impl From for AtomicPatternServerKey { Self::KeySwitch32(value) } } + +#[cfg(test)] +mod test { + use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128; + use crate::shortint::{gen_keys, ServerKey}; + + use super::AtomicPatternServerKey; + + // Test an implementation of the KS32 AP as a dynamic atomic pattern + #[test] + fn test_ks32_as_dyn_ap_ci_run_filter() { + let (client_key, server_key) = + gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128); + + // Convert the static ks 32 server key into a dynamic one + let AtomicPatternServerKey::KeySwitch32(ks32_key) = server_key.atomic_pattern else { + panic!("We know from parameters that AP is KS32") + }; + + let ap_key = AtomicPatternServerKey::Dynamic(Box::new(ks32_key)); + + // Re create the server key with the DAP + let server_key = ServerKey::from_raw_parts( + ap_key, + server_key.message_modulus, + server_key.carry_modulus, + server_key.max_degree, + server_key.max_noise_level, + ); + + // Do some operation + let msg1 = 1; + let msg2 = 0; + + let ct_1 = client_key.encrypt(msg1); + let ct_2 = client_key.encrypt(msg2); + + let ct_3 = server_key.add(&ct_1, &ct_2); + + let output = client_key.decrypt(&ct_3); + assert_eq!(output, 1); + } +} diff --git a/tfhe/src/shortint/oprf.rs b/tfhe/src/shortint/oprf.rs index 5a8fdd548..43a221870 100644 --- a/tfhe/src/shortint/oprf.rs +++ b/tfhe/src/shortint/oprf.rs @@ -207,9 +207,11 @@ impl GenericServerKey { #[cfg(test)] pub(crate) mod test { - use crate::core_crypto::prelude::decrypt_lwe_ciphertext; - use crate::shortint::oprf::create_random_from_seed_modulus_switched; - use crate::shortint::{ClientKey, ServerKey}; + use crate::core_crypto::prelude::{decrypt_lwe_ciphertext, LweSecretKey}; + use crate::shortint::{ClientKey, ServerKey, ShortintParameterSet}; + + use super::*; + use rayon::prelude::*; use statrs::distribution::ContinuousCDF; use std::collections::HashMap; @@ -222,22 +224,34 @@ pub(crate) mod test { #[test] fn oprf_compare_plain_ci_run_filter() { use crate::shortint::gen_keys; + use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); for seed in 0..1000 { - oprf_compare_plain_from_seed(Seed(seed), &ck, &sk); + oprf_compare_plain_from_seed::(Seed(seed), &ck, &sk); + } + + let (ck, sk) = gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128); + + for seed in 0..1000 { + oprf_compare_plain_from_seed::(Seed(seed), &ck, &sk); } } - fn oprf_compare_plain_from_seed(seed: Seed, ck: &ClientKey, sk: &ServerKey) { + fn oprf_compare_plain_from_seed + CastInto>( + seed: Seed, + ck: &ClientKey, + sk: &ServerKey, + ) { let params = ck.parameters; let random_bits_count = 2; let input_p = 2 * params.polynomial_size().0 as u64; - let log_input_p = input_p.ilog2(); + let log_input_p = input_p.ilog2() as usize; let p_prime = 1 << random_bits_count; @@ -255,15 +269,24 @@ pub(crate) mod test { params .polynomial_size() .to_blind_rotation_input_modulus_log(), - sk.ciphertext_modulus, + CiphertextModulus::new_native(), ); - let sk = ck.small_lwe_secret_key(); + let sk = LweSecretKey::from_container( + ck.small_lwe_secret_key() + .as_ref() + .iter() + .copied() + .map(|x| Scalar::cast_from(x)) + .collect::>(), + ); - let plain_prf_input = decrypt_lwe_ciphertext(&sk, &ct) - .0 - .wrapping_add(1 << (64 - log_input_p - 1)) - >> (64 - log_input_p); + let plain_prf_input = CastInto::::cast_into( + decrypt_lwe_ciphertext(&sk, &ct) + .0 + .wrapping_add(Scalar::ONE << (Scalar::BITS - log_input_p - 1)) + >> (Scalar::BITS - log_input_p), + ); let half_negacyclic_part = |x| 2 * (x / poly_delta) + 1; @@ -296,20 +319,28 @@ pub(crate) mod test { let p_value_limit: f64 = 0.000_01; use crate::shortint::gen_keys; + use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128; use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); - let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| { - test_uniformity(sample_count, p_value_limit, distinct_values, f) - }; + for params in [ + ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS), + ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128), + ] { + let (ck, sk) = gen_keys(params); - let random_bits_count = 2; + let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| { + test_uniformity(sample_count, p_value_limit, distinct_values, f) + }; - test_uniformity(1 << random_bits_count, &|seed| { - let img = sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count); + let random_bits_count = 2; - ck.decrypt_message_and_carry(&img) - }); + test_uniformity(1 << random_bits_count, &|seed| { + let img = + sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count); + + ck.decrypt_message_and_carry(&img) + }); + } } pub fn test_uniformity(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F) diff --git a/tfhe/src/shortint/parameters/aliases.rs b/tfhe/src/shortint/parameters/aliases.rs index 11b6ff539..1e45f7def 100644 --- a/tfhe/src/shortint/parameters/aliases.rs +++ b/tfhe/src/shortint/parameters/aliases.rs @@ -44,6 +44,7 @@ use current_params::multi_bit::tuniform::p_fail_2_minus_64::ks_pbs_gpu::{ V1_1_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64, }; use current_params::noise_squashing::p_fail_2_minus_128::V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + // Aliases // Compute Gaussian diff --git a/tfhe/src/shortint/parameters/test_params.rs b/tfhe/src/shortint/parameters/test_params.rs index 164adfacd..3af42c5d0 100644 --- a/tfhe/src/shortint/parameters/test_params.rs +++ b/tfhe/src/shortint/parameters/test_params.rs @@ -1,5 +1,5 @@ use super::current_params::*; -use super::AtomicPatternParameters; +use super::{AtomicPatternParameters, KeySwitch32PBSParameters}; use super::{ ClassicPBSParameters, CompactPublicKeyEncryptionParameters, CompressionParameters, @@ -209,3 +209,7 @@ pub const TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: CompressionPa pub const TEST_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: CompressionParameters = V1_1_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128; + +// KS32 PBS AP +pub const TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128: KeySwitch32PBSParameters = + V1_1_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128; diff --git a/tfhe/src/shortint/server_key/tests/parameterized_test.rs b/tfhe/src/shortint/server_key/tests/parameterized_test.rs index 0e21e2152..d525f6e04 100644 --- a/tfhe/src/shortint/server_key/tests/parameterized_test.rs +++ b/tfhe/src/shortint/server_key/tests/parameterized_test.rs @@ -55,7 +55,8 @@ macro_rules! create_parameterized_test{ TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64, TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, - TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64 + TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64, + TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128 }); }; }