diff --git a/mockups/tfhe-hpu-mockup/src/lib.rs b/mockups/tfhe-hpu-mockup/src/lib.rs index c9a9d1cfb..722126689 100644 --- a/mockups/tfhe-hpu-mockup/src/lib.rs +++ b/mockups/tfhe-hpu-mockup/src/lib.rs @@ -13,7 +13,6 @@ use tfhe::core_crypto::entities::{ Cleartext, LweCiphertextOwned, LweCiphertextView, LweKeyswitchKey, NttLweBootstrapKey, Plaintext, }; -use tfhe::core_crypto::hpu::algorithms::lwe_keyswitch::hpu_keyswitch_lwe_ciphertext; use tfhe::core_crypto::hpu::glwe_lookuptable::create_hpu_lookuptable; use tfhe::core_crypto::prelude::*; use tfhe::shortint::prelude::ClassicPBSParameters; @@ -654,7 +653,7 @@ impl HpuSim { // TODO add a check on trivialness for fast simulation ? // TODO assert ordering (i.e. KS+PBS) - hpu_keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks); + keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks); blind_rotate_ntt64_bnf_assign(bfr_after_ks, &mut tfhe_lut, &bsk); assert_eq!( diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf_pbs.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf_pbs.rs index 1cc6304b8..4ca802576 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf_pbs.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping/ntt64_bnf_pbs.rs @@ -584,11 +584,9 @@ pub(crate) fn add_external_product_ntt64_bnf_assign( // DOMAIN In this section, we perform the external product in the ntt // domain, and accumulate the result in the output_fft_buffer variable. let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new( - glwe.as_ref().iter().map(|s| { - decomposer.closest_representable(*s) - >> (u64::BITS - - (decomposer.level_count().0 * decomposer.base_log().0) as u32) - }), + glwe.as_ref() + .iter() + .map(|s| decomposer.init_decomposer_state(*s)), decomposer.base_log(), decomposer.level_count(), substack0, diff --git a/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs b/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs deleted file mode 100644 index 2edf5ff76..000000000 --- a/tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs +++ /dev/null @@ -1,532 +0,0 @@ -//! Module containing primitives pertaining to [`LWE ciphertext -//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`). - -use crate::core_crypto::algorithms::slice_algorithms::*; -use crate::core_crypto::commons::math::decomposition::SignedDecomposer; -use crate::core_crypto::commons::parameters::{ - DecompositionBaseLog, DecompositionLevelCount, ThreadCount, -}; -use crate::core_crypto::commons::traits::*; -use crate::core_crypto::entities::*; -use rayon::prelude::*; - -/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) encrypted under an -/// [`LWE secret key`](`LweSecretKey`) to another [`LWE secret key`](`LweSecretKey`). -/// -/// # Panics -/// -/// Panics if the modulus of the inputs are not power of twos. -/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key` -/// modulus. -/// -/// # Formal Definition -/// -/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`). -/// -/// # Example -/// -/// ```rust -/// use tfhe::core_crypto::prelude::*; -/// -/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct -/// // computations -/// // Define parameters for LweKeyswitchKey creation -/// let input_lwe_dimension = LweDimension(742); -/// let lwe_noise_distribution = -/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); -/// let output_lwe_dimension = LweDimension(2048); -/// let decomp_base_log = DecompositionBaseLog(3); -/// let decomp_level_count = DecompositionLevelCount(5); -/// let ciphertext_modulus = CiphertextModulus::new_native(); -/// -/// // Create the PRNG -/// let mut seeder = new_seeder(); -/// let seeder = seeder.as_mut(); -/// let mut encryption_generator = -/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); -/// let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); -/// -/// // Create the LweSecretKey -/// let input_lwe_secret_key = -/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); -/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( -/// output_lwe_dimension, -/// &mut secret_generator, -/// ); -/// -/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( -/// &input_lwe_secret_key, -/// &output_lwe_secret_key, -/// decomp_base_log, -/// decomp_level_count, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// // Create the plaintext -/// let msg = 3u64; -/// let plaintext = Plaintext(msg << 60); -/// -/// // Create a new LweCiphertext -/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( -/// &input_lwe_secret_key, -/// plaintext, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// let mut output_lwe = LweCiphertext::new( -/// 0, -/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), -/// ciphertext_modulus, -/// ); -/// -/// hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe); -/// -/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); -/// -/// // Round and remove encoding -/// // First create a decomposer working on the high 4 bits corresponding to our encoding. -/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); -/// -/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); -/// -/// // Remove the encoding -/// let cleartext = rounded >> 60; -/// -/// // Check we recovered the original message -/// assert_eq!(cleartext, msg); -/// ``` -pub fn hpu_keyswitch_lwe_ciphertext( - lwe_keyswitch_key: &LweKeyswitchKey, - input_lwe_ciphertext: &LweCiphertext, - output_lwe_ciphertext: &mut LweCiphertext, -) where - Scalar: UnsignedInteger, - KSKCont: Container, - InputCont: Container, - OutputCont: ContainerMut, -{ - assert!( - lwe_keyswitch_key.input_key_lwe_dimension() - == input_lwe_ciphertext.lwe_size().to_lwe_dimension(), - "Mismatched input LweDimension. \ - LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.", - lwe_keyswitch_key.input_key_lwe_dimension(), - input_lwe_ciphertext.lwe_size().to_lwe_dimension(), - ); - assert!( - lwe_keyswitch_key.output_key_lwe_dimension() - == output_lwe_ciphertext.lwe_size().to_lwe_dimension(), - "Mismatched output LweDimension. \ - LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.", - lwe_keyswitch_key.output_key_lwe_dimension(), - output_lwe_ciphertext.lwe_size().to_lwe_dimension(), - ); - - let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus(); - - assert_eq!( - lwe_keyswitch_key.ciphertext_modulus(), - output_ciphertext_modulus, - "Mismatched CiphertextModulus. \ - LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.", - lwe_keyswitch_key.ciphertext_modulus(), - output_ciphertext_modulus - ); - assert!( - output_ciphertext_modulus.is_compatible_with_native_modulus(), - "This operation currently only supports power of 2 moduli" - ); - - let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus(); - - assert!( - input_ciphertext_modulus.is_compatible_with_native_modulus(), - "This operation currently only supports power of 2 moduli" - ); - - // Clear the output ciphertext, as it will get updated gradually - output_lwe_ciphertext.as_mut().fill(Scalar::ZERO); - - // Copy the input body to the output ciphertext - *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data; - - // If the moduli are not the same, we need to round the body in the output ciphertext - if output_ciphertext_modulus != input_ciphertext_modulus - && !output_ciphertext_modulus.is_native_modulus() - { - let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize; - let output_decomposer = SignedDecomposer::new( - DecompositionBaseLog(modulus_bits), - DecompositionLevelCount(1), - ); - - *output_lwe_ciphertext.get_mut_body().data = - output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data); - } - - // We instantiate a decomposer - let decomposer = SignedDecomposer::new( - lwe_keyswitch_key.decomposition_base_log(), - lwe_keyswitch_key.decomposition_level_count(), - ); - - for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key - .iter() - .zip(input_lwe_ciphertext.get_mask().as_ref()) - { - let decomposition_iter = decomposer.decompose_raw(input_mask_element); - // Loop over the levels - for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter) - { - slice_wrapping_sub_scalar_mul_assign( - output_lwe_ciphertext.as_mut(), - level_key_ciphertext.as_ref(), - decomposed.value(), - ); - } - } -} - -/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`]. -/// -/// This will use all threads available in the current rayon thread pool. -/// -/// # Example -/// -/// ```rust -/// use tfhe::core_crypto::prelude::*; -/// -/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct -/// // computations -/// // Define parameters for LweKeyswitchKey creation -/// let input_lwe_dimension = LweDimension(742); -/// let lwe_noise_distribution = -/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); -/// let output_lwe_dimension = LweDimension(2048); -/// let decomp_base_log = DecompositionBaseLog(3); -/// let decomp_level_count = DecompositionLevelCount(5); -/// let ciphertext_modulus = CiphertextModulus::new_native(); -/// -/// // Create the PRNG -/// let mut seeder = new_seeder(); -/// let seeder = seeder.as_mut(); -/// let mut encryption_generator = -/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); -/// let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); -/// -/// // Create the LweSecretKey -/// let input_lwe_secret_key = -/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); -/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( -/// output_lwe_dimension, -/// &mut secret_generator, -/// ); -/// -/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( -/// &input_lwe_secret_key, -/// &output_lwe_secret_key, -/// decomp_base_log, -/// decomp_level_count, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// // Create the plaintext -/// let msg = 3u64; -/// let plaintext = Plaintext(msg << 60); -/// -/// // Create a new LweCiphertext -/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( -/// &input_lwe_secret_key, -/// plaintext, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// let mut output_lwe = LweCiphertext::new( -/// 0, -/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), -/// ciphertext_modulus, -/// ); -/// -/// // Use all threads available in the current rayon thread pool -/// par_hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe); -/// -/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); -/// -/// // Round and remove encoding -/// // First create a decomposer working on the high 4 bits corresponding to our encoding. -/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); -/// -/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); -/// -/// // Remove the encoding -/// let cleartext = rounded >> 60; -/// -/// // Check we recovered the original message -/// assert_eq!(cleartext, msg); -/// ``` -pub fn par_hpu_keyswitch_lwe_ciphertext( - lwe_keyswitch_key: &LweKeyswitchKey, - input_lwe_ciphertext: &LweCiphertext, - output_lwe_ciphertext: &mut LweCiphertext, -) where - Scalar: UnsignedInteger + Send + Sync, - KSKCont: Container, - InputCont: Container, - OutputCont: ContainerMut, -{ - let thread_count = ThreadCount(rayon::current_num_threads()); - par_hpu_keyswitch_lwe_ciphertext_with_thread_count( - lwe_keyswitch_key, - input_lwe_ciphertext, - output_lwe_ciphertext, - thread_count, - ); -} - -/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`]. -/// -/// This will try to use `thread_count` threads for the computation, if this number is bigger than -/// the available number of threads in the current rayon thread pool then only the number of -/// available threads will be used. Note that `thread_count` cannot be 0. -/// -/// # Panics -/// -/// Panics if the modulus of the inputs are not power of twos. -/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key` -/// modulus. -/// -/// # Example -/// -/// ```rust -/// use tfhe::core_crypto::prelude::*; -/// -/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct -/// // computations -/// // Define parameters for LweKeyswitchKey creation -/// let input_lwe_dimension = LweDimension(742); -/// let lwe_noise_distribution = -/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0); -/// let output_lwe_dimension = LweDimension(2048); -/// let decomp_base_log = DecompositionBaseLog(3); -/// let decomp_level_count = DecompositionLevelCount(5); -/// let ciphertext_modulus = CiphertextModulus::new_native(); -/// -/// // Create the PRNG -/// let mut seeder = new_seeder(); -/// let seeder = seeder.as_mut(); -/// let mut encryption_generator = -/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); -/// let mut secret_generator = SecretRandomGenerator::::new(seeder.seed()); -/// -/// // Create the LweSecretKey -/// let input_lwe_secret_key = -/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); -/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( -/// output_lwe_dimension, -/// &mut secret_generator, -/// ); -/// -/// let ksk = allocate_and_generate_new_lwe_keyswitch_key( -/// &input_lwe_secret_key, -/// &output_lwe_secret_key, -/// decomp_base_log, -/// decomp_level_count, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// // Create the plaintext -/// let msg = 3u64; -/// let plaintext = Plaintext(msg << 60); -/// -/// // Create a new LweCiphertext -/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext( -/// &input_lwe_secret_key, -/// plaintext, -/// lwe_noise_distribution, -/// ciphertext_modulus, -/// &mut encryption_generator, -/// ); -/// -/// let mut output_lwe = LweCiphertext::new( -/// 0, -/// output_lwe_secret_key.lwe_dimension().to_lwe_size(), -/// ciphertext_modulus, -/// ); -/// -/// // Try to use 4 threads for the keyswitch if enough are available -/// // in the current rayon thread pool -/// par_hpu_keyswitch_lwe_ciphertext_with_thread_count( -/// &ksk, -/// &input_lwe, -/// &mut output_lwe, -/// ThreadCount(4), -/// ); -/// -/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe); -/// -/// // Round and remove encoding -/// // First create a decomposer working on the high 4 bits corresponding to our encoding. -/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1)); -/// -/// let rounded = decomposer.closest_representable(decrypted_plaintext.0); -/// -/// // Remove the encoding -/// let cleartext = rounded >> 60; -/// -/// // Check we recovered the original message -/// assert_eq!(cleartext, msg); -/// ``` -pub fn par_hpu_keyswitch_lwe_ciphertext_with_thread_count( - lwe_keyswitch_key: &LweKeyswitchKey, - input_lwe_ciphertext: &LweCiphertext, - output_lwe_ciphertext: &mut LweCiphertext, - thread_count: ThreadCount, -) where - Scalar: UnsignedInteger + Send + Sync, - KSKCont: Container, - InputCont: Container, - OutputCont: ContainerMut, -{ - assert!( - lwe_keyswitch_key.input_key_lwe_dimension() - == input_lwe_ciphertext.lwe_size().to_lwe_dimension(), - "Mismatched input LweDimension. \ - LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.", - lwe_keyswitch_key.input_key_lwe_dimension(), - input_lwe_ciphertext.lwe_size().to_lwe_dimension(), - ); - assert!( - lwe_keyswitch_key.output_key_lwe_dimension() - == output_lwe_ciphertext.lwe_size().to_lwe_dimension(), - "Mismatched output LweDimension. \ - LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.", - lwe_keyswitch_key.output_key_lwe_dimension(), - output_lwe_ciphertext.lwe_size().to_lwe_dimension(), - ); - - let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus(); - - assert_eq!( - lwe_keyswitch_key.ciphertext_modulus(), - output_ciphertext_modulus, - "Mismatched CiphertextModulus. \ - LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.", - lwe_keyswitch_key.ciphertext_modulus(), - output_ciphertext_modulus - ); - assert!( - output_ciphertext_modulus.is_compatible_with_native_modulus(), - "This operation currently only supports power of 2 moduli" - ); - - let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus(); - - assert!( - input_ciphertext_modulus.is_compatible_with_native_modulus(), - "This operation currently only supports power of 2 moduli" - ); - - assert!( - thread_count.0 != 0, - "Got thread_count == 0, this is not supported" - ); - - // Clear the output ciphertext, as it will get updated gradually - output_lwe_ciphertext.as_mut().fill(Scalar::ZERO); - - let output_lwe_size = output_lwe_ciphertext.lwe_size(); - - // Copy the input body to the output ciphertext - *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data; - - // If the moduli are not the same, we need to round the body in the output ciphertext - if output_ciphertext_modulus != input_ciphertext_modulus - && !output_ciphertext_modulus.is_native_modulus() - { - let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize; - let output_decomposer = SignedDecomposer::new( - DecompositionBaseLog(modulus_bits), - DecompositionLevelCount(1), - ); - - *output_lwe_ciphertext.get_mut_body().data = - output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data); - } - - // We instantiate a decomposer - let decomposer = SignedDecomposer::new( - lwe_keyswitch_key.decomposition_base_log(), - lwe_keyswitch_key.decomposition_level_count(), - ); - - // Don't go above the current number of threads - let thread_count = thread_count.0.min(rayon::current_num_threads()); - let mut intermediate_accumulators = Vec::with_capacity(thread_count); - - // Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size - let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count); - - lwe_keyswitch_key - .par_chunks(chunk_size) - .zip( - input_lwe_ciphertext - .get_mask() - .as_ref() - .par_chunks(chunk_size), - ) - .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| { - let mut buffer = - LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus); - - for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk - .iter() - .zip(input_mask_element_chunk.iter()) - { - let decomposition_iter = decomposer.decompose_raw(input_mask_element); - // Loop over the levels - for (level_key_ciphertext, decomposed) in - keyswitch_key_block.iter().zip(decomposition_iter) - { - slice_wrapping_sub_scalar_mul_assign( - buffer.as_mut(), - level_key_ciphertext.as_ref(), - decomposed.value(), - ); - } - } - buffer - }) - .collect_into_vec(&mut intermediate_accumulators); - - let reduced = intermediate_accumulators - .par_iter_mut() - .reduce_with(|lhs, rhs| { - lhs.as_mut() - .iter_mut() - .zip(rhs.as_ref().iter()) - .for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src)); - - lhs - }) - .unwrap(); - - output_lwe_ciphertext - .get_mut_mask() - .as_mut() - .copy_from_slice(reduced.get_mask().as_ref()); - let reduced_ksed_body = *reduced.get_body().data; - - // Add the reduced body of the keyswitch to the output body to complete the keyswitch - *output_lwe_ciphertext.get_mut_body().data = - (*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body); -} diff --git a/tfhe/src/core_crypto/hpu/algorithms/mod.rs b/tfhe/src/core_crypto/hpu/algorithms/mod.rs index 6cf39c3fb..bffde9c47 100644 --- a/tfhe/src/core_crypto/hpu/algorithms/mod.rs +++ b/tfhe/src/core_crypto/hpu/algorithms/mod.rs @@ -1,3 +1,2 @@ -pub mod lwe_keyswitch; pub mod modswitch; pub mod order;