diff --git a/mockups/tfhe-hpu-mockup/src/lib.rs b/mockups/tfhe-hpu-mockup/src/lib.rs index 59cf75308..42a309e28 100644 --- a/mockups/tfhe-hpu-mockup/src/lib.rs +++ b/mockups/tfhe-hpu-mockup/src/lib.rs @@ -13,6 +13,7 @@ use tfhe::core_crypto::entities::{ Cleartext, LweCiphertextOwned, LweCiphertextView, LweKeyswitchKey, NttLweBootstrapKey, Plaintext, }; +use tfhe::core_crypto::hpu::algorithms::lwe_keyswitch::hpu_keyswitch_lwe_ciphertext; use tfhe::core_crypto::hpu::from_with::FromWith; use tfhe::core_crypto::hpu::glwe_lookuptable::create_hpu_lookuptable; use tfhe::core_crypto::prelude::*; @@ -648,7 +649,7 @@ impl HpuSim { // TODO add a check on trivialness for fast simulation ? // TODO assert ordering (i.e. KS+PBS) - keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks); + hpu_keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks); blind_rotate_ntt64_bnf_assign(bfr_after_ks, &mut tfhe_lut, &bsk); assert_eq!( diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf.rs index 37696554c..cb39142a7 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf.rs @@ -587,9 +587,11 @@ pub(crate) fn add_external_product_ntt64_bnf_assign( // DOMAIN In this section, we perform the external product in the ntt // domain, and accumulate the result in the output_fft_buffer variable. 
let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new( - glwe.as_ref() - .iter() - .map(|s| decomposer.init_decomposer_state(*s)), + glwe.as_ref().iter().map(|s| { + decomposer.closest_representable(*s) + >> (u64::BITS + - (decomposer.level_count().0 * decomposer.base_log().0) as u32) + }), decomposer.base_log(), decomposer.level_count(), substack0, diff --git a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs index 95f3cb187..4374d9d44 100644 --- a/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs +++ b/tfhe/src/core_crypto/commons/math/decomposition/decomposer.rs @@ -55,7 +55,6 @@ pub fn native_closest_representable( /// returns 1 if the following if condition is true otherwise 0 /// /// (val > B / 2) || ((val == B / 2) && (random == 1)) -#[cfg(any(not(feature = "hpu"), test))] #[inline(always)] fn balanced_rounding_condition_bit_trick( val: Scalar, @@ -147,43 +146,33 @@ where #[inline(always)] pub fn init_decomposer_state(&self, input: Scalar) -> Scalar { - // Default mode -> use balanced decomposition - #[cfg(not(feature = "hpu"))] - { - // The closest number representable by the decomposition can be computed by performing - // the rounding at the appropriate bit. + // The closest number representable by the decomposition can be computed by performing + // the rounding at the appropriate bit. 
- // We compute the number of least significant bits which can not be represented by the - // decomposition - // Example with level_count = 3, base_log = 4 and BITS == 64 -> 52 - let rep_bit_count = self.level_count * self.base_log; - let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count; - // Move the representable bits + 1 to the LSB, with our example : - // |-----| 64 - (64 - 12 - 1) == 13 bits - // 0....0XX...XX - let mut res = input >> (non_rep_bit_count - 1); - // Fetch the first bit value as we need it for a balanced rounding - let rounding_bit = res & Scalar::ONE; - // Add one to do the rounding by adding the half interval - res += Scalar::ONE; - // Discard the LSB which was the one deciding in which direction we round - res >>= 1; - // Keep the low base_log * level bits - let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count); - res &= mod_mask; - // Control bit about whether we should balance the state - // This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1) - let need_balance = - balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit); - // Balance depending on the control bit - res.wrapping_sub(need_balance << rep_bit_count) - } - - // Hpu used unbalanced decomposition - #[cfg(feature = "hpu")] - { - self.closest_representable(input) >> (Scalar::BITS - (self.level_count * self.base_log)) - } + // We compute the number of least significant bits which can not be represented by the + // decomposition + // Example with level_count = 3, base_log = 4 and BITS == 64 -> 52 + let rep_bit_count = self.level_count * self.base_log; + let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count; + // Move the representable bits + 1 to the LSB, with our example : + // |-----| 64 - (64 - 12 - 1) == 13 bits + // 0....0XX...XX + let mut res = input >> (non_rep_bit_count - 1); + // Fetch the first bit value as we need it for a balanced rounding + let rounding_bit = res & Scalar::ONE; + // Add one to do the 
rounding by adding the half interval + res += Scalar::ONE; + // Discard the LSB which was the one deciding in which direction we round + res >>= 1; + // Keep the low base_log * level bits + let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count); + res &= mod_mask; + // Control bit about whether we should balance the state + // This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == + // 1) + let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit); + // Balance depending on the control bit + res.wrapping_sub(need_balance << rep_bit_count) } /// Generate an iterator over the terms of the decomposition of the input. @@ -228,6 +217,48 @@ where ) } + /// Generate an iterator over the terms of the decomposition of the input. + /// # Warning + /// This uses unbalanced decomposition and must not be used with a one-level decomposition. + /// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$. + /// + /// # Example + /// + /// ```rust + /// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer; + /// use tfhe::core_crypto::commons::numeric::UnsignedInteger; + /// use tfhe::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount}; + /// let decomposer = + /// SignedDecomposer::<u32>::new(DecompositionBaseLog(4), DecompositionLevelCount(3)); + /// // 2147483647 == 2^31 - 1 and has a decomposition term == to half_basis + /// for term in decomposer.decompose_raw(2147483647u32) { + /// assert!(1 <= term.level().0); + /// assert!(term.level().0 <= 3); + /// let signed_term = term.value().into_signed(); + /// let half_basis = 2i32.pow(4) / 2i32; + /// assert!( + /// -half_basis <= signed_term, + /// "{} <= {signed_term} failed", + /// -half_basis + /// ); + /// assert!( + /// signed_term <= half_basis, + /// "{signed_term} <= {half_basis} failed" + /// ); + /// } + /// assert_eq!(decomposer.decompose_raw(1).count(), 3); + /// ``` + pub fn 
decompose_raw(&self, input: Scalar) -> SignedDecompositionIter { + // Decomposing an input that was not first rounded to the closest representable value + // would make no sense, so the rounding is performed before decomposing. + SignedDecompositionIter::new( + self.closest_representable(input) + >> (Scalar::BITS - (self.level_count * self.base_log)), + DecompositionBaseLog(self.base_log), + DecompositionLevelCount(self.level_count), + ) + } + /// Recomposes a decomposed value by summing all the terms. /// /// If the input iterator yields $\tilde{\theta}\_i$, this returns diff --git a/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs b/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs new file mode 100644 index 000000000..2edf5ff76 --- /dev/null +++ b/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs @@ -0,0 +1,532 @@ +//! Module containing primitives pertaining to [`LWE ciphertext +//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`). + +use crate::core_crypto::algorithms::slice_algorithms::*; +use crate::core_crypto::commons::math::decomposition::SignedDecomposer; +use crate::core_crypto::commons::parameters::{ + DecompositionBaseLog, DecompositionLevelCount, ThreadCount, +}; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use rayon::prelude::*; + +/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) encrypted under an +/// [`LWE secret key`](`LweSecretKey`) to another [`LWE secret key`](`LweSecretKey`). +/// +/// # Panics +/// +/// Panics if the moduli of the inputs are not powers of two. +/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key` +/// modulus. +/// +/// # Formal Definition +/// +/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`). 
+/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweKeyswitchKey creation +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let output_lwe_dimension = LweDimension(2048); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder); +/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( +/// output_lwe_dimension, +/// &mut secret_generator, +/// ); +/// +/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( +/// &input_lwe_secret_key, +/// &output_lwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Create the plaintext +/// let msg = 3u64; +/// let plaintext = Plaintext(msg << 60); +/// +/// // Create a new LweCiphertext +/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( +/// &input_lwe_secret_key, +/// plaintext, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_lwe = LweCiphertext::new( +/// 0, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// ciphertext_modulus, +/// ); +/// +/// 
hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe); +/// +/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); +/// +/// // Remove the encoding +/// let cleartext = rounded >> 60; +/// +/// // Check we recovered the original message +/// assert_eq!(cleartext, msg); +/// ``` +pub fn hpu_keyswitch_lwe_ciphertext( + lwe_keyswitch_key: &LweKeyswitchKey, + input_lwe_ciphertext: &LweCiphertext, + output_lwe_ciphertext: &mut LweCiphertext, +) where + Scalar: UnsignedInteger, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + lwe_keyswitch_key.input_key_lwe_dimension() + == input_lwe_ciphertext.lwe_size().to_lwe_dimension(), + "Mismatched input LweDimension. \ + LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.", + lwe_keyswitch_key.input_key_lwe_dimension(), + input_lwe_ciphertext.lwe_size().to_lwe_dimension(), + ); + assert!( + lwe_keyswitch_key.output_key_lwe_dimension() + == output_lwe_ciphertext.lwe_size().to_lwe_dimension(), + "Mismatched output LweDimension. \ + LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.", + lwe_keyswitch_key.output_key_lwe_dimension(), + output_lwe_ciphertext.lwe_size().to_lwe_dimension(), + ); + + let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus(); + + assert_eq!( + lwe_keyswitch_key.ciphertext_modulus(), + output_ciphertext_modulus, + "Mismatched CiphertextModulus. 
\ + LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.", + lwe_keyswitch_key.ciphertext_modulus(), + output_ciphertext_modulus + ); + assert!( + output_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" + ); + + let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus(); + + assert!( + input_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" + ); + + // Clear the output ciphertext, as it will get updated gradually + output_lwe_ciphertext.as_mut().fill(Scalar::ZERO); + + // Copy the input body to the output ciphertext + *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data; + + // If the moduli are not the same, we need to round the body in the output ciphertext + if output_ciphertext_modulus != input_ciphertext_modulus + && !output_ciphertext_modulus.is_native_modulus() + { + let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize; + let output_decomposer = SignedDecomposer::new( + DecompositionBaseLog(modulus_bits), + DecompositionLevelCount(1), + ); + + *output_lwe_ciphertext.get_mut_body().data = + output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data); + } + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new( + lwe_keyswitch_key.decomposition_base_log(), + lwe_keyswitch_key.decomposition_level_count(), + ); + + for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key + .iter() + .zip(input_lwe_ciphertext.get_mask().as_ref()) + { + let decomposition_iter = decomposer.decompose_raw(input_mask_element); + // Loop over the levels + for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter) + { + slice_wrapping_sub_scalar_mul_assign( + output_lwe_ciphertext.as_mut(), + level_key_ciphertext.as_ref(), + decomposed.value(), + ); + } + } +} + 
+/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`]. +/// +/// This will use all threads available in the current rayon thread pool. +/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweKeyswitchKey creation +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let output_lwe_dimension = LweDimension(2048); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder); +/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( +/// output_lwe_dimension, +/// &mut secret_generator, +/// ); +/// +/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( +/// &input_lwe_secret_key, +/// &output_lwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Create the plaintext +/// let msg = 3u64; +/// let plaintext = Plaintext(msg << 60); +/// +/// // Create a new LweCiphertext +/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( +/// &input_lwe_secret_key, +/// plaintext, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let mut 
output_lwe = LweCiphertext::new( +/// 0, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// ciphertext_modulus, +/// ); +/// +/// // Use all threads available in the current rayon thread pool +/// par_hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe); +/// +/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); +/// +/// // Remove the encoding +/// let cleartext = rounded >> 60; +/// +/// // Check we recovered the original message +/// assert_eq!(cleartext, msg); +/// ``` +pub fn par_hpu_keyswitch_lwe_ciphertext( + lwe_keyswitch_key: &LweKeyswitchKey, + input_lwe_ciphertext: &LweCiphertext, + output_lwe_ciphertext: &mut LweCiphertext, +) where + Scalar: UnsignedInteger + Send + Sync, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + let thread_count = ThreadCount(rayon::current_num_threads()); + par_hpu_keyswitch_lwe_ciphertext_with_thread_count( + lwe_keyswitch_key, + input_lwe_ciphertext, + output_lwe_ciphertext, + thread_count, + ); +} + +/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`]. +/// +/// This will try to use `thread_count` threads for the computation; if this number is bigger than +/// the available number of threads in the current rayon thread pool then only the number of +/// available threads will be used. Note that `thread_count` cannot be 0. +/// +/// # Panics +/// +/// Panics if the moduli of the inputs are not powers of two. +/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key` +/// modulus. Panics if `thread_count` is 0. 
+/// +/// # Example +/// +/// ```rust +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweKeyswitchKey creation +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_noise_distribution = +/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); +/// let output_lwe_dimension = LweDimension(2048); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let ciphertext_modulus = CiphertextModulus::new_native(); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder); +/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( +/// output_lwe_dimension, +/// &mut secret_generator, +/// ); +/// +/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( +/// &input_lwe_secret_key, +/// &output_lwe_secret_key, +/// decomp_base_log, +/// decomp_level_count, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// // Create the plaintext +/// let msg = 3u64; +/// let plaintext = Plaintext(msg << 60); +/// +/// // Create a new LweCiphertext +/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( +/// &input_lwe_secret_key, +/// plaintext, +/// lwe_noise_distribution, +/// ciphertext_modulus, +/// &mut encryption_generator, +/// ); +/// +/// let mut output_lwe = LweCiphertext::new( +/// 0, +/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), +/// ciphertext_modulus, +/// ); +/// +/// 
// Try to use 4 threads for the keyswitch if enough are available +/// // in the current rayon thread pool +/// par_hpu_keyswitch_lwe_ciphertext_with_thread_count( +/// &ksk, +/// &input_lwe, +/// &mut output_lwe, +/// ThreadCount(4), +/// ); +/// +/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); +/// +/// // Round and remove encoding +/// // First create a decomposer working on the high 4 bits corresponding to our encoding. +/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); +/// +/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); +/// +/// // Remove the encoding +/// let cleartext = rounded >> 60; +/// +/// // Check we recovered the original message +/// assert_eq!(cleartext, msg); +/// ``` +pub fn par_hpu_keyswitch_lwe_ciphertext_with_thread_count( + lwe_keyswitch_key: &LweKeyswitchKey, + input_lwe_ciphertext: &LweCiphertext, + output_lwe_ciphertext: &mut LweCiphertext, + thread_count: ThreadCount, +) where + Scalar: UnsignedInteger + Send + Sync, + KSKCont: Container, + InputCont: Container, + OutputCont: ContainerMut, +{ + assert!( + lwe_keyswitch_key.input_key_lwe_dimension() + == input_lwe_ciphertext.lwe_size().to_lwe_dimension(), + "Mismatched input LweDimension. \ + LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.", + lwe_keyswitch_key.input_key_lwe_dimension(), + input_lwe_ciphertext.lwe_size().to_lwe_dimension(), + ); + assert!( + lwe_keyswitch_key.output_key_lwe_dimension() + == output_lwe_ciphertext.lwe_size().to_lwe_dimension(), + "Mismatched output LweDimension. 
\ + LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.", + lwe_keyswitch_key.output_key_lwe_dimension(), + output_lwe_ciphertext.lwe_size().to_lwe_dimension(), + ); + + let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus(); + + assert_eq!( + lwe_keyswitch_key.ciphertext_modulus(), + output_ciphertext_modulus, + "Mismatched CiphertextModulus. \ + LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.", + lwe_keyswitch_key.ciphertext_modulus(), + output_ciphertext_modulus + ); + assert!( + output_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" + ); + + let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus(); + + assert!( + input_ciphertext_modulus.is_compatible_with_native_modulus(), + "This operation currently only supports power of 2 moduli" + ); + + assert!( + thread_count.0 != 0, + "Got thread_count == 0, this is not supported" + ); + + // Clear the output ciphertext, as it will get updated gradually + output_lwe_ciphertext.as_mut().fill(Scalar::ZERO); + + let output_lwe_size = output_lwe_ciphertext.lwe_size(); + + // Copy the input body to the output ciphertext + *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data; + + // If the moduli are not the same, we need to round the body in the output ciphertext + if output_ciphertext_modulus != input_ciphertext_modulus + && !output_ciphertext_modulus.is_native_modulus() + { + let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize; + let output_decomposer = SignedDecomposer::new( + DecompositionBaseLog(modulus_bits), + DecompositionLevelCount(1), + ); + + *output_lwe_ciphertext.get_mut_body().data = + output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data); + } + + // We instantiate a decomposer + let decomposer = SignedDecomposer::new( + 
lwe_keyswitch_key.decomposition_base_log(), + lwe_keyswitch_key.decomposition_level_count(), + ); + + // Don't go above the current number of threads + let thread_count = thread_count.0.min(rayon::current_num_threads()); + let mut intermediate_accumulators = Vec::with_capacity(thread_count); + + // Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size + let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count); + + lwe_keyswitch_key + .par_chunks(chunk_size) + .zip( + input_lwe_ciphertext + .get_mask() + .as_ref() + .par_chunks(chunk_size), + ) + .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| { + let mut buffer = + LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus); + + for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk + .iter() + .zip(input_mask_element_chunk.iter()) + { + let decomposition_iter = decomposer.decompose_raw(input_mask_element); + // Loop over the levels + for (level_key_ciphertext, decomposed) in + keyswitch_key_block.iter().zip(decomposition_iter) + { + slice_wrapping_sub_scalar_mul_assign( + buffer.as_mut(), + level_key_ciphertext.as_ref(), + decomposed.value(), + ); + } + } + buffer + }) + .collect_into_vec(&mut intermediate_accumulators); + + let reduced = intermediate_accumulators + .par_iter_mut() + .reduce_with(|lhs, rhs| { + lhs.as_mut() + .iter_mut() + .zip(rhs.as_ref().iter()) + .for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src)); + + lhs + }) + .unwrap(); + + output_lwe_ciphertext + .get_mut_mask() + .as_mut() + .copy_from_slice(reduced.get_mask().as_ref()); + let reduced_ksed_body = *reduced.get_body().data; + + // Add the reduced body of the keyswitch to the output body to complete the keyswitch + *output_lwe_ciphertext.get_mut_body().data = + (*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body); +} diff --git a/tfhe/src/core_crypto/hpu/algorithms/mod.rs b/tfhe/src/core_crypto/hpu/algorithms/mod.rs index 
bffde9c47..6cf39c3fb 100644 --- a/tfhe/src/core_crypto/hpu/algorithms/mod.rs +++ b/tfhe/src/core_crypto/hpu/algorithms/mod.rs @@ -1,2 +1,3 @@ +pub mod lwe_keyswitch; pub mod modswitch; pub mod order;