diff --git a/tfhe/src/core_crypto/algorithms/ggsw_encryption.rs b/tfhe/src/core_crypto/algorithms/ggsw_encryption.rs index e51cda784..b929a11cc 100644 --- a/tfhe/src/core_crypto/algorithms/ggsw_encryption.rs +++ b/tfhe/src/core_crypto/algorithms/ggsw_encryption.rs @@ -3,10 +3,14 @@ use crate::core_crypto::algorithms::*; use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; use crate::core_crypto::commons::math::decomposition::DecompositionLevel; use crate::core_crypto::commons::math::random::ByteRandomGenerator; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use crate::core_crypto::specification::dispersion::DispersionParameter; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; pub fn encrypt_ggsw_ciphertext( glwe_secret_key: &GlweSecretKeyBase, @@ -16,9 +20,9 @@ pub fn encrypt_ggsw_ciphertext( generator: &mut EncryptionRandomGenerator, ) where Scalar: UnsignedTorus, - Gen: ByteRandomGenerator, KeyCont: Container, OutputCont: ContainerMut, + Gen: ByteRandomGenerator, { assert!( output.polynomial_size() == glwe_secret_key.polynomial_size(), @@ -85,6 +89,84 @@ pub fn encrypt_ggsw_ciphertext( } } +#[cfg(feature = "__commons_parallel")] +pub fn par_encrypt_ggsw_ciphertext( + glwe_secret_key: &GlweSecretKeyBase, + output: &mut GgswCiphertextBase, + encoded: Plaintext, + noise_parameters: impl DispersionParameter + Sync, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: UnsignedTorus + Sync + Send, + KeyCont: Container + Sync, + OutputCont: ContainerMut, + Gen: ParallelByteRandomGenerator, +{ + assert!( + output.polynomial_size() == glwe_secret_key.polynomial_size(), + "Mismatch between polynomial sizes of output cipertexts and input secret key. \ + Got {:?} in output, and {:?} in secret key.", + output.polynomial_size(), + glwe_secret_key.polynomial_size() + ); + + assert!( + output.glwe_size().to_glwe_dimension() == glwe_secret_key.glwe_dimension(), + "Mismatch between GlweDimension of output cipertexts and input secret key. \ + Got {:?} in output, and {:?} in secret key.", + output.glwe_size().to_glwe_dimension(), + glwe_secret_key.glwe_dimension() + ); + + // Generators used to have same sequential and parallel key generation + let gen_iter = generator + .par_fork_ggsw_to_ggsw_levels::( + output.decomposition_level_count(), + output.glwe_size(), + output.polynomial_size(), + ) + .expect("Failed to split generator into ggsw levels"); + + let output_glwe_size = output.glwe_size(); + let output_polynomial_size = output.polynomial_size(); + let decomp_base_log = output.decomposition_base_log(); + + output.par_iter_mut().zip(gen_iter).enumerate().for_each( + |(level_index, (mut level_matrix, mut generator))| { + let decomp_level = DecompositionLevel(level_index + 1); + let factor = encoded + .0 + .wrapping_neg() + .wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0))); + + // We iterate over the rows of the level matrix, the last row needs special treatment + let gen_iter = generator + .par_fork_ggsw_level_to_glwe::(output_glwe_size, output_polynomial_size) + .expect("Failed to split generator into glwe"); + + let last_row_index = level_matrix.glwe_size().0 - 1; + let sk_poly_list = glwe_secret_key.as_polynomial_list(); + + level_matrix + .as_mut_glwe_list() + .par_iter_mut() + .enumerate() + .zip(gen_iter) + .for_each(|((row_index, mut row_as_glwe), mut generator)| { + encrypt_ggsw_level_matrix_row( + glwe_secret_key, + (row_index, last_row_index), + factor, + &sk_poly_list, + &mut row_as_glwe, + noise_parameters, + &mut generator, + ); + }); + }, + ); +} + fn encrypt_ggsw_level_matrix_row( glwe_secret_key: &GlweSecretKeyBase, (row_index, last_row_index): (usize, usize), diff --git a/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_generation.rs b/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_generation.rs index 89b24b86b..17e5cec76 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_generation.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_bootstrap_key_generation.rs @@ -1,10 +1,14 @@ use crate::core_crypto::algorithms::*; use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator; use crate::core_crypto::commons::math::random::ByteRandomGenerator; +#[cfg(feature = "__commons_parallel")] +use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::*; use crate::core_crypto::specification::dispersion::DispersionParameter; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; pub fn generate_lwe_bootstrap_key( input_lwe_secret_key: &LweSecretKeyBase, @@ -67,6 +71,68 @@ pub fn generate_lwe_bootstrap_key( + input_lwe_secret_key: &LweSecretKeyBase, + output_glwe_secret_key: &GlweSecretKeyBase, + output: &mut LweBootstrapKeyBase, + noise_parameters: impl DispersionParameter + Sync, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: UnsignedTorus + Sync + Send, + InputKeyCont: Container, + OutputKeyCont: Container + Sync, + OutputCont: ContainerMut, + Gen: ParallelByteRandomGenerator, +{ + assert!( + output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(), + "Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \ + Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.", + input_lwe_secret_key.lwe_dimension(), + output.input_lwe_dimension() + ); + + assert!( + output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(), + "Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.", + output_glwe_secret_key.glwe_dimension().to_glwe_size(), + output.glwe_size() + ); + + assert!( + output.polynomial_size() == output_glwe_secret_key.polynomial_size(), + "Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.", + output_glwe_secret_key.polynomial_size(), + output.polynomial_size() + ); + + let gen_iter = generator + .par_fork_bsk_to_ggsw::( + output.input_lwe_dimension(), + output.decomposition_level_count(), + output.glwe_size(), + output.polynomial_size(), + ) + .unwrap(); + + output + .par_iter_mut() + .zip(input_lwe_secret_key.as_ref().par_iter()) + .zip(gen_iter) + .for_each(|((mut ggsw, &input_key_element), mut generator)| { + par_encrypt_ggsw_ciphertext( + output_glwe_secret_key, + &mut ggsw, + Plaintext(input_key_element), + noise_parameters, + &mut generator, + ); + }) +} + #[cfg(test)] mod test { use crate::core_crypto::algorithms::generate_lwe_bootstrap_key; @@ -173,3 +239,117 @@ mod test { test_refactored_bsk_equivalence::() } } + +#[cfg(feature = "__commons_parallel")] +#[cfg(test)] +mod parallel_test { + use crate::core_crypto::algorithms::{ + allocate_and_generate_new_binary_glwe_secret_key, + allocate_and_generate_new_binary_lwe_secret_key, generate_lwe_bootstrap_key, + par_generate_lwe_bootstrap_key, + }; + use crate::core_crypto::commons::crypto::secret::generators::{ + DeterministicSeeder, EncryptionRandomGenerator, + }; + use crate::core_crypto::commons::math::random::Seed; + use crate::core_crypto::commons::math::torus::UnsignedTorus; + use crate::core_crypto::commons::test_tools::new_secret_random_generator; + use crate::core_crypto::entities::LweBootstrapKey; + use crate::core_crypto::prelude::{ + DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize, + StandardDev, + }; + use concrete_csprng::generators::SoftwareRandomGenerator; + + fn test_refactored_bsk_parallel_gen_equivalence() { + for _ in 0..10 { + let lwe_dim = + LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let glwe_dim = + GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10)); + let poly_size = PolynomialSize( + crate::core_crypto::commons::test_tools::random_usize_between(5..10), + ); + let level = DecompositionLevelCount( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let base_log = DecompositionBaseLog( + crate::core_crypto::commons::test_tools::random_usize_between(2..5), + ); + let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + let deterministic_seeder_seed = + Seed(crate::core_crypto::commons::test_tools::any_usize() as u128); + + let mut secret_generator = new_secret_random_generator(); + let lwe_sk = + allocate_and_generate_new_binary_lwe_secret_key(lwe_dim, &mut secret_generator); + let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key( + glwe_dim, + poly_size, + &mut secret_generator, + ); + + let mut parallel_bsk = LweBootstrapKey::new( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + base_log, + level, + lwe_dim, + ); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + + par_generate_lwe_bootstrap_key( + &lwe_sk, + &glwe_sk, + &mut parallel_bsk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + let mut sequential_bsk = LweBootstrapKey::new( + T::ZERO, + glwe_dim.to_glwe_size(), + poly_size, + base_log, + level, + lwe_dim, + ); + + let mut encryption_generator = + EncryptionRandomGenerator::::new( + mask_seed, + &mut DeterministicSeeder::::new( + deterministic_seeder_seed, + ), + ); + + generate_lwe_bootstrap_key( + &lwe_sk, + &glwe_sk, + &mut sequential_bsk, + StandardDev::from_standard_dev(10.), + &mut encryption_generator, + ); + + assert_eq!(parallel_bsk.as_ref(), sequential_bsk.as_ref()); + } + } + + #[test] + fn test_refactored_bsk_parallel_gen_equivalence_u32() { + test_refactored_bsk_parallel_gen_equivalence::() + } + + #[test] + fn test_refactored_bsk_parallel_gen_equivalence_u64() { + test_refactored_bsk_parallel_gen_equivalence::() + } +} diff --git a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs index 334da46c6..1a25a2f18 100644 --- a/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs +++ b/tfhe/src/core_crypto/commons/traits/contiguous_entity_container.rs @@ -1,4 +1,6 @@ use super::create_from::*; +#[cfg(feature = "__commons_parallel")] +use rayon::prelude::*; type WrappingFunction<'data, Element, WrappingType> = fn( ( @@ -15,6 +17,15 @@ type WrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map< WrappingFunction<'data, Element, WrappingType>, >; +#[cfg(feature = "__commons_parallel")] +type ParallelWrappingLendingIterator<'data, Element, WrappingType> = rayon::iter::Map< + rayon::iter::Zip< + rayon::slice::Chunks<'data, Element>, + rayon::iter::RepeatN<>::Metadata>, + >, + WrappingFunction<'data, Element, WrappingType>, +>; + // This is required as at the moment it's not possible to reverse a zip containing a repeat, though // it is perfectly legal to zip a reversed repeat type RevWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map< @@ -40,6 +51,15 @@ type WrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map< WrappingFunctionMut<'data, Element, WrappingType>, >; +#[cfg(feature = "__commons_parallel")] +type ParallelWrappingLendingIteratorMut<'data, Element, WrappingType> = rayon::iter::Map< + rayon::iter::Zip< + rayon::slice::ChunksMut<'data, Element>, + rayon::iter::RepeatN<>::Metadata>, + >, + WrappingFunctionMut<'data, Element, WrappingType>, +>; + // This is required as at the moment it's not possible to reverse a zip containing a repeat, though // it is perfectly legal to zip a reversed repeat type RevWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map< @@ -121,6 +141,24 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> { Self::EntityView::<'_>::create_from(&self.as_ref()[start..stop], meta) } + + #[cfg(feature = "__commons_parallel")] + fn par_iter<'this>( + &'this self, + ) -> ParallelWrappingLendingIterator<'this, Self::Element, Self::EntityView<'this>> + where + Self::Element: Sync, + Self::EntityView<'this>: Send, + Self::EntityViewMetadata: Send, + { + let meta = self.get_entity_view_creation_metadata(); + let entity_view_pod_size = self.get_entity_view_pod_size(); + let entity_count = self.as_ref().len() / entity_view_pod_size; + self.as_ref() + .par_chunks(entity_view_pod_size) + .zip(rayon::iter::repeatn(meta, entity_count)) + .map(|(elt, meta)| Self::EntityView::<'this>::create_from(elt, meta)) + } } pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self::Element]> { @@ -188,4 +226,22 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self: Self::EntityMutView::<'_>::create_from(&mut self.as_mut()[start..stop], meta) } + + #[cfg(feature = "__commons_parallel")] + fn par_iter_mut<'this>( + &'this mut self, + ) -> ParallelWrappingLendingIteratorMut<'this, Self::Element, Self::EntityMutView<'this>> + where + Self::Element: Send + Sync, + Self::EntityMutView<'this>: Send, + Self::EntityViewMetadata: Send, + { + let meta = self.get_entity_view_creation_metadata(); + let entity_view_pod_size = self.get_entity_view_pod_size(); + let entity_count = self.as_ref().len() / entity_view_pod_size; + self.as_mut() + .par_chunks_mut(entity_view_pod_size) + .zip(rayon::iter::repeatn(meta, entity_count)) + .map(|(elt, meta)| Self::EntityMutView::<'this>::create_from(elt, meta)) + } }