mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
refactor(tfhe): add parallel bootstrap key generation
- add equivalence test between refactored sequential and parallel BSK generation
This commit is contained in:
@@ -3,10 +3,14 @@ use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
|
||||
use crate::core_crypto::commons::math::random::ByteRandomGenerator;
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator;
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::specification::dispersion::DispersionParameter;
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
|
||||
@@ -16,9 +20,9 @@ pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
Gen: ByteRandomGenerator,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == glwe_secret_key.polynomial_size(),
|
||||
@@ -85,6 +89,84 @@ pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
pub fn par_encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
|
||||
output: &mut GgswCiphertextBase<OutputCont>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter + Sync,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Sync + Send,
|
||||
KeyCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == glwe_secret_key.polynomial_size(),
|
||||
"Mismatch between polynomial sizes of output cipertexts and input secret key. \
|
||||
Got {:?} in output, and {:?} in secret key.",
|
||||
output.polynomial_size(),
|
||||
glwe_secret_key.polynomial_size()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.glwe_size().to_glwe_dimension() == glwe_secret_key.glwe_dimension(),
|
||||
"Mismatch between GlweDimension of output cipertexts and input secret key. \
|
||||
Got {:?} in output, and {:?} in secret key.",
|
||||
output.glwe_size().to_glwe_dimension(),
|
||||
glwe_secret_key.glwe_dimension()
|
||||
);
|
||||
|
||||
// Generators used to have same sequential and parallel key generation
|
||||
let gen_iter = generator
|
||||
.par_fork_ggsw_to_ggsw_levels::<Scalar>(
|
||||
output.decomposition_level_count(),
|
||||
output.glwe_size(),
|
||||
output.polynomial_size(),
|
||||
)
|
||||
.expect("Failed to split generator into ggsw levels");
|
||||
|
||||
let output_glwe_size = output.glwe_size();
|
||||
let output_polynomial_size = output.polynomial_size();
|
||||
let decomp_base_log = output.decomposition_base_log();
|
||||
|
||||
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|
||||
|(level_index, (mut level_matrix, mut generator))| {
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
let factor = encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = generator
|
||||
.par_fork_ggsw_level_to_glwe::<Scalar>(output_glwe_size, output_polynomial_size)
|
||||
.expect("Failed to split generator into glwe");
|
||||
|
||||
let last_row_index = level_matrix.glwe_size().0 - 1;
|
||||
let sk_poly_list = glwe_secret_key.as_polynomial_list();
|
||||
|
||||
level_matrix
|
||||
.as_mut_glwe_list()
|
||||
.par_iter_mut()
|
||||
.enumerate()
|
||||
.zip(gen_iter)
|
||||
.for_each(|((row_index, mut row_as_glwe), mut generator)| {
|
||||
encrypt_ggsw_level_matrix_row(
|
||||
glwe_secret_key,
|
||||
(row_index, last_row_index),
|
||||
factor,
|
||||
&sk_poly_list,
|
||||
&mut row_as_glwe,
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
);
|
||||
});
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn encrypt_ggsw_level_matrix_row<Scalar, KeyCont, InputCont, OutputCont, Gen>(
|
||||
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
|
||||
(row_index, last_row_index): (usize, usize),
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::random::ByteRandomGenerator;
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator;
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::specification::dispersion::DispersionParameter;
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
pub fn generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCont, Gen>(
|
||||
input_lwe_secret_key: &LweSecretKeyBase<InputKeyCont>,
|
||||
@@ -67,6 +71,68 @@ pub fn generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCon
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
pub fn par_generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCont, Gen>(
|
||||
input_lwe_secret_key: &LweSecretKeyBase<InputKeyCont>,
|
||||
output_glwe_secret_key: &GlweSecretKeyBase<OutputKeyCont>,
|
||||
output: &mut LweBootstrapKeyBase<OutputCont>,
|
||||
noise_parameters: impl DispersionParameter + Sync,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus + Sync + Send,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar> + Sync,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ParallelByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
|
||||
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
|
||||
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
|
||||
input_lwe_secret_key.lwe_dimension(),
|
||||
output.input_lwe_dimension()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
|
||||
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
|
||||
output.glwe_size()
|
||||
);
|
||||
|
||||
assert!(
|
||||
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
|
||||
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
|
||||
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
|
||||
output_glwe_secret_key.polynomial_size(),
|
||||
output.polynomial_size()
|
||||
);
|
||||
|
||||
let gen_iter = generator
|
||||
.par_fork_bsk_to_ggsw::<Scalar>(
|
||||
output.input_lwe_dimension(),
|
||||
output.decomposition_level_count(),
|
||||
output.glwe_size(),
|
||||
output.polynomial_size(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
output
|
||||
.par_iter_mut()
|
||||
.zip(input_lwe_secret_key.as_ref().par_iter())
|
||||
.zip(gen_iter)
|
||||
.for_each(|((mut ggsw, &input_key_element), mut generator)| {
|
||||
par_encrypt_ggsw_ciphertext(
|
||||
output_glwe_secret_key,
|
||||
&mut ggsw,
|
||||
Plaintext(input_key_element),
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
);
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::core_crypto::algorithms::generate_lwe_bootstrap_key;
|
||||
@@ -173,3 +239,117 @@ mod test {
|
||||
test_refactored_bsk_equivalence::<u64>()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
#[cfg(test)]
|
||||
mod parallel_test {
|
||||
use crate::core_crypto::algorithms::{
|
||||
allocate_and_generate_new_binary_glwe_secret_key,
|
||||
allocate_and_generate_new_binary_lwe_secret_key, generate_lwe_bootstrap_key,
|
||||
par_generate_lwe_bootstrap_key,
|
||||
};
|
||||
use crate::core_crypto::commons::crypto::secret::generators::{
|
||||
DeterministicSeeder, EncryptionRandomGenerator,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::Seed;
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
|
||||
use crate::core_crypto::entities::LweBootstrapKey;
|
||||
use crate::core_crypto::prelude::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
|
||||
StandardDev,
|
||||
};
|
||||
use concrete_csprng::generators::SoftwareRandomGenerator;
|
||||
|
||||
fn test_refactored_bsk_parallel_gen_equivalence<T: UnsignedTorus + Send + Sync>() {
|
||||
for _ in 0..10 {
|
||||
let lwe_dim =
|
||||
LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
|
||||
let glwe_dim =
|
||||
GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
|
||||
let poly_size = PolynomialSize(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(5..10),
|
||||
);
|
||||
let level = DecompositionLevelCount(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
|
||||
);
|
||||
let base_log = DecompositionBaseLog(
|
||||
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
|
||||
);
|
||||
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
|
||||
let deterministic_seeder_seed =
|
||||
Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
|
||||
|
||||
let mut secret_generator = new_secret_random_generator();
|
||||
let lwe_sk =
|
||||
allocate_and_generate_new_binary_lwe_secret_key(lwe_dim, &mut secret_generator);
|
||||
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
|
||||
glwe_dim,
|
||||
poly_size,
|
||||
&mut secret_generator,
|
||||
);
|
||||
|
||||
let mut parallel_bsk = LweBootstrapKey::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
);
|
||||
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<SoftwareRandomGenerator>::new(
|
||||
mask_seed,
|
||||
&mut DeterministicSeeder::<SoftwareRandomGenerator>::new(
|
||||
deterministic_seeder_seed,
|
||||
),
|
||||
);
|
||||
|
||||
par_generate_lwe_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut parallel_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
let mut sequential_bsk = LweBootstrapKey::new(
|
||||
T::ZERO,
|
||||
glwe_dim.to_glwe_size(),
|
||||
poly_size,
|
||||
base_log,
|
||||
level,
|
||||
lwe_dim,
|
||||
);
|
||||
|
||||
let mut encryption_generator =
|
||||
EncryptionRandomGenerator::<SoftwareRandomGenerator>::new(
|
||||
mask_seed,
|
||||
&mut DeterministicSeeder::<SoftwareRandomGenerator>::new(
|
||||
deterministic_seeder_seed,
|
||||
),
|
||||
);
|
||||
|
||||
generate_lwe_bootstrap_key(
|
||||
&lwe_sk,
|
||||
&glwe_sk,
|
||||
&mut sequential_bsk,
|
||||
StandardDev::from_standard_dev(10.),
|
||||
&mut encryption_generator,
|
||||
);
|
||||
|
||||
assert_eq!(parallel_bsk.as_ref(), sequential_bsk.as_ref());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_refactored_bsk_parallel_gen_equivalence_u32() {
|
||||
test_refactored_bsk_parallel_gen_equivalence::<u32>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_refactored_bsk_parallel_gen_equivalence_u64() {
|
||||
test_refactored_bsk_parallel_gen_equivalence::<u64>()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
use super::create_from::*;
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
use rayon::prelude::*;
|
||||
|
||||
type WrappingFunction<'data, Element, WrappingType> = fn(
|
||||
(
|
||||
@@ -15,6 +17,15 @@ type WrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map<
|
||||
WrappingFunction<'data, Element, WrappingType>,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
type ParallelWrappingLendingIterator<'data, Element, WrappingType> = rayon::iter::Map<
|
||||
rayon::iter::Zip<
|
||||
rayon::slice::Chunks<'data, Element>,
|
||||
rayon::iter::RepeatN<<WrappingType as CreateFrom<&'data [Element]>>::Metadata>,
|
||||
>,
|
||||
WrappingFunction<'data, Element, WrappingType>,
|
||||
>;
|
||||
|
||||
// This is required as at the moment it's not possible to reverse a zip containing a repeat, though
|
||||
// it is perfectly legal to zip a reversed repeat
|
||||
type RevWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map<
|
||||
@@ -40,6 +51,15 @@ type WrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map<
|
||||
WrappingFunctionMut<'data, Element, WrappingType>,
|
||||
>;
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
type ParallelWrappingLendingIteratorMut<'data, Element, WrappingType> = rayon::iter::Map<
|
||||
rayon::iter::Zip<
|
||||
rayon::slice::ChunksMut<'data, Element>,
|
||||
rayon::iter::RepeatN<<WrappingType as CreateFrom<&'data mut [Element]>>::Metadata>,
|
||||
>,
|
||||
WrappingFunctionMut<'data, Element, WrappingType>,
|
||||
>;
|
||||
|
||||
// This is required as at the moment it's not possible to reverse a zip containing a repeat, though
|
||||
// it is perfectly legal to zip a reversed repeat
|
||||
type RevWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map<
|
||||
@@ -121,6 +141,24 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> {
|
||||
|
||||
Self::EntityView::<'_>::create_from(&self.as_ref()[start..stop], meta)
|
||||
}
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
fn par_iter<'this>(
|
||||
&'this self,
|
||||
) -> ParallelWrappingLendingIterator<'this, Self::Element, Self::EntityView<'this>>
|
||||
where
|
||||
Self::Element: Sync,
|
||||
Self::EntityView<'this>: Send,
|
||||
Self::EntityViewMetadata: Send,
|
||||
{
|
||||
let meta = self.get_entity_view_creation_metadata();
|
||||
let entity_view_pod_size = self.get_entity_view_pod_size();
|
||||
let entity_count = self.as_ref().len() / entity_view_pod_size;
|
||||
self.as_ref()
|
||||
.par_chunks(entity_view_pod_size)
|
||||
.zip(rayon::iter::repeatn(meta, entity_count))
|
||||
.map(|(elt, meta)| Self::EntityView::<'this>::create_from(elt, meta))
|
||||
}
|
||||
}
|
||||
|
||||
pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self::Element]> {
|
||||
@@ -188,4 +226,22 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self:
|
||||
|
||||
Self::EntityMutView::<'_>::create_from(&mut self.as_mut()[start..stop], meta)
|
||||
}
|
||||
|
||||
#[cfg(feature = "__commons_parallel")]
|
||||
fn par_iter_mut<'this>(
|
||||
&'this mut self,
|
||||
) -> ParallelWrappingLendingIteratorMut<'this, Self::Element, Self::EntityMutView<'this>>
|
||||
where
|
||||
Self::Element: Send + Sync,
|
||||
Self::EntityMutView<'this>: Send,
|
||||
Self::EntityViewMetadata: Send,
|
||||
{
|
||||
let meta = self.get_entity_view_creation_metadata();
|
||||
let entity_view_pod_size = self.get_entity_view_pod_size();
|
||||
let entity_count = self.as_ref().len() / entity_view_pod_size;
|
||||
self.as_mut()
|
||||
.par_chunks_mut(entity_view_pod_size)
|
||||
.zip(rayon::iter::repeatn(meta, entity_count))
|
||||
.map(|(elt, meta)| Self::EntityMutView::<'this>::create_from(elt, meta))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user