diff --git a/tfhe-benchmark/benches/integer/bench.rs b/tfhe-benchmark/benches/integer/bench.rs index 50bfd1350..493535b40 100644 --- a/tfhe-benchmark/benches/integer/bench.rs +++ b/tfhe-benchmark/benches/integer/bench.rs @@ -13,24 +13,30 @@ use rayon::prelude::*; use std::cell::LazyCell; use std::cmp::max; use std::env; +use tfhe::core_crypto::algorithms::modulus_switch::ModulusSwitchedLweCiphertext; use tfhe::core_crypto::algorithms::{ allocate_and_generate_new_binary_glwe_secret_key, allocate_and_generate_new_binary_lwe_secret_key, allocate_and_generate_new_lwe_keyswitch_key, extract_lwe_sample_from_glwe_ciphertext, keyswitch_lwe_ciphertext_with_scalar_change, + lwe_ciphertext_centered_binary_modulus_switch, }; +use tfhe::core_crypto::entities::modulus_switched_lwe_ciphertext::LazyStandardModulusSwitchedLweCiphertext; +use tfhe::core_crypto::entities::LweCiphertext; use tfhe::integer::keycache::KEY_CACHE; use tfhe::integer::prelude::*; use tfhe::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey, U256}; use tfhe::keycache::NamedParam; -use tfhe::shortint::atomic_pattern::{AtomicPattern, AtomicPatternMut}; +use tfhe::shortint::atomic_pattern::{AtomicPattern, AtomicPatternMut, AtomicPatternServerKey}; +use tfhe::shortint::ciphertext::{Degree, NoiseLevel}; use tfhe::shortint::engine::ShortintEngine; use tfhe::{get_pbs_count, reset_pbs_count}; +#[derive(Clone, Copy, Debug, PartialEq)] pub struct ParamKS32MB { pub lwe_dimension: LweDimension, pub glwe_dimension: GlweDimension, pub polynomial_size: PolynomialSize, - pub lwe_noise_distribution: DynamicDistribution, + pub lwe_noise_distribution: DynamicDistribution, pub glwe_noise_distribution: DynamicDistribution, pub pbs_base_log: DecompositionBaseLog, pub pbs_level: DecompositionLevelCount, @@ -69,68 +75,119 @@ impl ClientKeyKS32MB { Self { lwe_secret_key, glwe_secret_key, - parameters: *params.to_owned(), + parameters: *params, } } + // warning fake encryption for benchmark pub fn encrypt_shortint(&self, msg: u64) -> Ciphertext { - todo!() + let cont = (0..self + .parameters + .glwe_dimension + .to_equivalent_lwe_dimension(self.parameters.polynomial_size) + .to_lwe_size() + .0) + .map(|_| rand::random()) + .collect::>(); + Ciphertext::new( + LweCiphertext::from_container(cont, self.parameters.ciphertext_modulus), + Degree::new(self.parameters.message_modulus.0 - 1), + NoiseLevel::NOMINAL, + self.parameters.message_modulus, + self.parameters.carry_modulus, + tfhe::shortint::AtomicPatternKind::KeySwitch32, + ) } // tfhe/src/integer/ciphertext/base.rs:29 pub fn encrypt_radix(&self, msg: u64, block_count: usize) -> RadixCiphertext { - let blocks: Vec = todo!(); + let blocks: Vec = (0..block_count) + .map(|_| self.encrypt_shortint(msg)) + .collect(); RadixCiphertext::from(blocks) } } +#[derive(Clone, Debug, PartialEq)] pub struct ServerKeyKS32MB { lwe_ksk: LweKeyswitchKeyOwned, lwe_bsk: ShortintBootstrappingKey, is_2m128: bool, + parameters: ParamKS32MB, } impl ServerKeyKS32MB { pub fn new(cks: &ClientKeyKS32MB, is_2m128: bool) -> Self { let params = &cks.parameters; - let in_key = cks.lwe_secret_key.as_view(); + let in_key_u64 = { + // Convert to u64 as a workaround + let cont = cks + .lwe_secret_key + .as_ref() + .iter() + .map(|x| *x as u64) + .collect::>(); + LweSecretKey::from_container(cont) + }; let out_key = &cks.glwe_secret_key; - let (key_switching_key, bootstrapping_key) = ShortintEngine::with_thread_local_mut(|engine| { + let (key_switching_key, bootstrapping_key) = + ShortintEngine::with_thread_local_mut(|engine| { + let bootstrapping_key_base = engine.new_multibit_bootstrapping_key( + &in_key_u64, + &out_key, + params.glwe_noise_distribution, + params.pbs_base_log, + params.pbs_level, + params.grouping_factor, + params.ciphertext_modulus, + ); - let bootstrapping_key_base = engine.new_multibit_bootstrapping_key( - &in_key, - &out_key, - params.glwe_noise_distribution, - params.pbs_base_log, - params.pbs_level, - params.grouping_factor, - params.ciphertext_modulus, + // Creation of the key switching key + let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key( + &cks.glwe_secret_key.as_lwe_secret_key(), + &cks.lwe_secret_key, + params.ks_base_log, + params.ks_level, + params.lwe_noise_distribution, + CiphertextModulus32::new_native(), + &mut engine.encryption_generator, + ); + + (key_switching_key, bootstrapping_key_base) + }); + + let thread_count = ShortintEngine::get_thread_count_for_multi_bit_pbs( + bootstrapping_key.input_lwe_dimension(), + bootstrapping_key.glwe_size().to_glwe_dimension(), + bootstrapping_key.polynomial_size(), + bootstrapping_key.decomposition_base_log(), + bootstrapping_key.decomposition_level_count(), + bootstrapping_key.grouping_factor(), ); - // Creation of the key switching key - let key_switching_key = allocate_and_generate_new_lwe_keyswitch_key( - &cks.glwe_secret_key.as_lwe_secret_key(), - &in_key, - params.ks_base_log, - params.ks_level, - params.lwe_noise_distribution, - CiphertextModulus32::new_native(), - &mut engine.encryption_generator, - ); - - (key_switching_key, bootstrapping_key_base) - }); - Self { lwe_ksk: key_switching_key, - lwe_bsk: bootstrapping_key, + lwe_bsk: ShortintBootstrappingKey::MultiBit { + fourier_bsk: bootstrapping_key, + thread_count, + deterministic_execution: false, + }, is_2m128, + parameters: *params, } } + + pub fn intermediate_lwe_dimension(&self) -> LweDimension { + self.lwe_bsk.input_lwe_dimension() + } + + fn intermediate_ciphertext_modulus(&self) -> CiphertextModulus32 { + self.lwe_ksk.ciphertext_modulus() + } } impl AtomicPatternMut for ServerKeyKS32MB { @@ -150,7 +207,10 @@ impl AtomicPattern for ServerKeyKS32MB { } fn ciphertext_modulus_for_key(&self, key_choice: EncryptionKeyChoice) -> CiphertextModulus { - todo!() // not required normally + match key_choice { + EncryptionKeyChoice::Big => self.parameters.ciphertext_modulus, + EncryptionKeyChoice::Small => CiphertextModulus32::new_native().try_to().unwrap(), + } } fn ciphertext_decompression_method(&self) -> tfhe::core_crypto::prelude::MsDecompressionType { @@ -163,7 +223,48 @@ impl AtomicPattern for ServerKeyKS32MB { acc: &tfhe::shortint::server_key::LookupTableOwned, ) { if self.is_2m128 { - todo!("code bizarre"); + ShortintEngine::with_thread_local_mut(|engine| { + // :warning: the mean compensation is classical mean compensation + classical + // modulus switch + multi bit bootstrap (not sure this is + // implemented in tfhe-rs yet) :warning: + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.intermediate_lwe_dimension(), + self.intermediate_ciphertext_modulus(), + ); + + keyswitch_lwe_ciphertext_with_scalar_change( + &self.lwe_ksk, + &ct.ct, + &mut ciphertext_buffer, + ); + + let br_input_modulus_log = self + .lwe_bsk + .polynomial_size() + .to_blind_rotation_input_modulus_log(); + + let msed: LazyStandardModulusSwitchedLweCiphertext = + lwe_ciphertext_centered_binary_modulus_switch( + ciphertext_buffer.as_view(), + self.lwe_bsk + .polynomial_size() + .to_blind_rotation_input_modulus_log(), + ); + + let mut lwe_rescaled_cont = + Vec::with_capacity(msed.lwe_dimension().to_lwe_size().0); + for mask in msed.mask() { + lwe_rescaled_cont.push(mask << (u32::BITS - br_input_modulus_log.0 as u32)); + } + lwe_rescaled_cont.push(msed.body() << (u32::BITS - br_input_modulus_log.0 as u32)); + + let lwe = LweCiphertext::from_container( + lwe_rescaled_cont, + self.lwe_ksk.ciphertext_modulus(), + ); + + apply_programmable_bootstrap(&self.lwe_bsk, &lwe, &mut ct.ct, &acc.acc, buffers); + }) } else { ShortintEngine::with_thread_local_mut(|engine| { let (mut ciphertext_buffer, buffers) = engine.get_buffers( @@ -195,26 +296,71 @@ impl AtomicPattern for ServerKeyKS32MB { ) -> Vec { let mut acc = lut.acc.clone(); - ShortintEngine::with_thread_local_mut(|engine| { - let (mut ciphertext_buffer, buffers) = engine.get_buffers( - self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small), - self.lwe_ksk.ciphertext_modulus(), - ); + if self.is_2m128 { + ShortintEngine::with_thread_local_mut(|engine| { + // :warning: the mean compensation is classical mean compensation + classical + // modulus switch + multi bit bootstrap (not sure this is + // implemented in tfhe-rs yet) :warning: + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.intermediate_lwe_dimension(), + self.intermediate_ciphertext_modulus(), + ); - // Compute a key switch - keyswitch_lwe_ciphertext_with_scalar_change( - &self.lwe_ksk, - &ct.ct, - &mut ciphertext_buffer, - ); + keyswitch_lwe_ciphertext_with_scalar_change( + &self.lwe_ksk, + &ct.ct, + &mut ciphertext_buffer, + ); - apply_ms_blind_rotate( - &self.lwe_bsk, - &ciphertext_buffer.as_view(), - &mut acc, - buffers, - ); - }); + let br_input_modulus_log = self + .lwe_bsk + .polynomial_size() + .to_blind_rotation_input_modulus_log(); + + let msed: LazyStandardModulusSwitchedLweCiphertext = + lwe_ciphertext_centered_binary_modulus_switch( + ciphertext_buffer.as_view(), + self.lwe_bsk + .polynomial_size() + .to_blind_rotation_input_modulus_log(), + ); + + let mut lwe_rescaled_cont = + Vec::with_capacity(msed.lwe_dimension().to_lwe_size().0); + for mask in msed.mask() { + lwe_rescaled_cont.push(mask << (u32::BITS - br_input_modulus_log.0 as u32)); + } + lwe_rescaled_cont.push(msed.body() << (u32::BITS - br_input_modulus_log.0 as u32)); + + let lwe = LweCiphertext::from_container( + lwe_rescaled_cont, + self.lwe_ksk.ciphertext_modulus(), + ); + + apply_ms_blind_rotate(&self.lwe_bsk, &lwe, &mut acc, buffers); + }) + } else { + ShortintEngine::with_thread_local_mut(|engine| { + let (mut ciphertext_buffer, buffers) = engine.get_buffers( + self.ciphertext_lwe_dimension_for_key(EncryptionKeyChoice::Small), + self.lwe_ksk.ciphertext_modulus(), + ); + + // Compute a key switch + keyswitch_lwe_ciphertext_with_scalar_change( + &self.lwe_ksk, + &ct.ct, + &mut ciphertext_buffer, + ); + + apply_ms_blind_rotate( + &self.lwe_bsk, + &ciphertext_buffer.as_view(), + &mut acc, + buffers, + ); + }); + } // The accumulator has been rotated, we can now proceed with the various sample extractions let function_count = lut.function_count(); @@ -239,12 +385,14 @@ impl AtomicPattern for ServerKeyKS32MB { } fn lookup_table_size(&self) -> tfhe::shortint::server_key::LookupTableSize { - // See KS32 shortint - todo!() + tfhe::shortint::server_key::LookupTableSize::new( + self.lwe_bsk.glwe_size(), + self.lwe_bsk.polynomial_size(), + ) } fn kind(&self) -> tfhe::shortint::AtomicPatternKind { - todo!() // maybe + tfhe::shortint::AtomicPatternKind::KeySwitch32 } fn generate_oblivious_pseudo_random( @@ -282,7 +430,7 @@ impl AtomicPattern for ServerKeyKS32MB { pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey { let key = ServerKeyKS32MB::new(cks, is_2m128); - let boxed: Box = Box::new(key); + let boxed = Box::new(key); // tfhe/src/shortint/engine/server_side.rs:21 @@ -292,7 +440,7 @@ pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey { ); let shortint_sk = tfhe::shortint::ServerKey::from_raw_parts( - *boxed, + AtomicPatternServerKey::Dynamic(boxed), cks.parameters.message_modulus, cks.parameters.carry_modulus, max_degree, @@ -303,7 +451,7 @@ pub fn create_sk(cks: &ClientKeyKS32MB, is_2m128: bool) -> ServerKey { } const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M40: ParamKS32MB = ParamKS32MB { - lwe_dimension: LweDimension(429), + lwe_dimension: LweDimension(429 * 2), glwe_dimension: GlweDimension(1), polynomial_size: PolynomialSize(2048), lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(2.34899e-6)), @@ -324,7 +472,7 @@ const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M40: ParamKS32M }; const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M64: ParamKS32MB = ParamKS32MB { - lwe_dimension: LweDimension(444), + lwe_dimension: LweDimension(444 * 2), glwe_dimension: GlweDimension(1), polynomial_size: PolynomialSize(2048), lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(1.39987e-6)), @@ -345,7 +493,7 @@ const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M64: ParamKS32M }; const BENCH_MULTI_BIT_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_GAUSSIAN_2M128: ParamKS32MB = ParamKS32MB { - lwe_dimension: LweDimension(457), + lwe_dimension: LweDimension(457 * 2), glwe_dimension: GlweDimension(1), polynomial_size: PolynomialSize(2048), lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev( @@ -742,38 +890,44 @@ fn bench_server_key_binary_function_clean_inputs_with_ks32( .measurement_time(std::time::Duration::from_secs(60)); let mut rng = rand::thread_rng(); + let bit_sizes = [4u32, 8, 16, 32, 64, 128, 256]; + for (param_name, param, is_2m128) in KS32_BENCHMARK_PARAM_SET.iter() { - let bench_id; + for bit_size in bit_sizes { + let num_block = bit_size.div_ceil(param.message_modulus.0.ilog2()) as usize; - let bench_data = LazyCell::new(|| { - let cks = ClientKeyKS32MB::new(param); - let sks = create_sk(&cks, *is_2m128); + let bench_id; - let clear_0 = gen_random_u256(&mut rng); - let clear_1 = gen_random_u256(&mut rng); + let bench_data = LazyCell::new(|| { + let cks = ClientKeyKS32MB::new(param); + let sks = create_sk(&cks, *is_2m128); - let ct_0 = cks.encrypt_radix(clear_0, num_block); - let ct_1 = cks.encrypt_radix(clear_1, num_block); - (sks, ct_0, ct_1) - }); + let clear_0 = rand::random::(); + let clear_1 = rand::random::(); - bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits"); - bench_group.bench_function(&bench_id, |b| { - let (sks, ct_0, ct_1) = (&bench_data.0, &bench_data.1, &bench_data.2); - b.iter(|| { - binary_op(sks, ct_0, ct_1); - }) - }); + let ct_0 = cks.encrypt_radix(clear_0, num_block); + let ct_1 = cks.encrypt_radix(clear_1, num_block); + (sks, ct_0, ct_1) + }); - // write_to_json::( - // &bench_id, - // param, - // *param_name, - // display_name, - // &OperatorType::Atomic, - // bit_size as u32, - // vec![param.message_modulus.0.ilog2(); num_block], - // ); + bench_id = format!("{bench_name}::{param_name}::{bit_size}_bits"); + bench_group.bench_function(&bench_id, |b| { + let (sks, ct_0, ct_1) = (&bench_data.0, &bench_data.1, &bench_data.2); + b.iter(|| { + binary_op(sks, ct_0, ct_1); + }) + }); + + // write_to_json::( + // &bench_id, + // param, + // *param_name, + // display_name, + // &OperatorType::Atomic, + // bit_size as u32, + // vec![param.message_modulus.0.ilog2(); num_block], + // ); + } } bench_group.finish() @@ -3382,16 +3536,15 @@ use cuda::{ cuda_cast_ops, default_cuda_dedup_ops, default_cuda_ops, default_scalar_cuda_ops, unchecked_cuda_ops, unchecked_scalar_cuda_ops, }; -use tfhe::boolean::parameters::{ - DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, GlweDimension, - LweDimension, PolynomialSize, StandardDev, -}; use tfhe::core_crypto::entities::LweKeyswitchKeyOwned; use tfhe::core_crypto::prelude::{ GlweSecretKey, LweBskGroupingFactor, LweSecretKey, MonomialDegree, }; use tfhe::shortint::ciphertext::MaxDegree; -use tfhe::shortint::parameters::CiphertextModulus32; +use tfhe::shortint::parameters::{ + CiphertextModulus32, DecompositionBaseLog, DecompositionLevelCount, DynamicDistribution, + GlweDimension, LweDimension, PolynomialSize, StandardDev, +}; use tfhe::shortint::server_key::{ apply_ms_blind_rotate, apply_programmable_bootstrap, ShortintBootstrappingKey, }; diff --git a/tfhe/src/shortint/engine/server_side.rs b/tfhe/src/shortint/engine/server_side.rs index 40a2630ec..77de31dd6 100644 --- a/tfhe/src/shortint/engine/server_side.rs +++ b/tfhe/src/shortint/engine/server_side.rs @@ -26,7 +26,7 @@ impl ShortintEngine { self.new_server_key_with_max_degree(cks, max_degree) } - pub(crate) fn get_thread_count_for_multi_bit_pbs( + pub fn get_thread_count_for_multi_bit_pbs( lwe_dimension: LweDimension, glwe_dimension: GlweDimension, polynomial_size: PolynomialSize,