refactor(tfhe): add parallel bootstrap key generation

- add equivalence test between refactored sequential and parallel BSK
generation
This commit is contained in:
Arthur Meyre
2022-11-25 15:02:13 +01:00
committed by jborfila
parent b445e349a6
commit 4a0fb6b42e
3 changed files with 319 additions and 1 deletions

View File

@@ -3,10 +3,14 @@ use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
use crate::core_crypto::commons::math::random::ByteRandomGenerator;
#[cfg(feature = "__commons_parallel")]
use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::specification::dispersion::DispersionParameter;
#[cfg(feature = "__commons_parallel")]
use rayon::prelude::*;
pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
@@ -16,9 +20,9 @@ pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
Gen: ByteRandomGenerator,
KeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert!(
output.polynomial_size() == glwe_secret_key.polynomial_size(),
@@ -85,6 +89,84 @@ pub fn encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
}
}
#[cfg(feature = "__commons_parallel")]
pub fn par_encrypt_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
output: &mut GgswCiphertextBase<OutputCont>,
encoded: Plaintext<Scalar>,
noise_parameters: impl DispersionParameter + Sync,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus + Sync + Send,
KeyCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert!(
output.polynomial_size() == glwe_secret_key.polynomial_size(),
"Mismatch between polynomial sizes of output cipertexts and input secret key. \
Got {:?} in output, and {:?} in secret key.",
output.polynomial_size(),
glwe_secret_key.polynomial_size()
);
assert!(
output.glwe_size().to_glwe_dimension() == glwe_secret_key.glwe_dimension(),
"Mismatch between GlweDimension of output cipertexts and input secret key. \
Got {:?} in output, and {:?} in secret key.",
output.glwe_size().to_glwe_dimension(),
glwe_secret_key.glwe_dimension()
);
// Generators used to have same sequential and parallel key generation
let gen_iter = generator
.par_fork_ggsw_to_ggsw_levels::<Scalar>(
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
)
.expect("Failed to split generator into ggsw levels");
let output_glwe_size = output.glwe_size();
let output_polynomial_size = output.polynomial_size();
let decomp_base_log = output.decomposition_base_log();
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
let factor = encoded
.0
.wrapping_neg()
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = generator
.par_fork_ggsw_level_to_glwe::<Scalar>(output_glwe_size, output_polynomial_size)
.expect("Failed to split generator into glwe");
let last_row_index = level_matrix.glwe_size().0 - 1;
let sk_poly_list = glwe_secret_key.as_polynomial_list();
level_matrix
.as_mut_glwe_list()
.par_iter_mut()
.enumerate()
.zip(gen_iter)
.for_each(|((row_index, mut row_as_glwe), mut generator)| {
encrypt_ggsw_level_matrix_row(
glwe_secret_key,
(row_index, last_row_index),
factor,
&sk_poly_list,
&mut row_as_glwe,
noise_parameters,
&mut generator,
);
});
},
);
}
fn encrypt_ggsw_level_matrix_row<Scalar, KeyCont, InputCont, OutputCont, Gen>(
glwe_secret_key: &GlweSecretKeyBase<KeyCont>,
(row_index, last_row_index): (usize, usize),

View File

@@ -1,10 +1,14 @@
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::crypto::secret::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::math::random::ByteRandomGenerator;
#[cfg(feature = "__commons_parallel")]
use crate::core_crypto::commons::math::random::ParallelByteRandomGenerator;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::specification::dispersion::DispersionParameter;
#[cfg(feature = "__commons_parallel")]
use rayon::prelude::*;
pub fn generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCont, Gen>(
input_lwe_secret_key: &LweSecretKeyBase<InputKeyCont>,
@@ -67,6 +71,68 @@ pub fn generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCon
}
}
#[cfg(feature = "__commons_parallel")]
pub fn par_generate_lwe_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCont, Gen>(
input_lwe_secret_key: &LweSecretKeyBase<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKeyBase<OutputKeyCont>,
output: &mut LweBootstrapKeyBase<OutputCont>,
noise_parameters: impl DispersionParameter + Sync,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus + Sync + Send,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert!(
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
input_lwe_secret_key.lwe_dimension(),
output.input_lwe_dimension()
);
assert!(
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output.glwe_size()
);
assert!(
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
output_glwe_secret_key.polynomial_size(),
output.polynomial_size()
);
let gen_iter = generator
.par_fork_bsk_to_ggsw::<Scalar>(
output.input_lwe_dimension(),
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
)
.unwrap();
output
.par_iter_mut()
.zip(input_lwe_secret_key.as_ref().par_iter())
.zip(gen_iter)
.for_each(|((mut ggsw, &input_key_element), mut generator)| {
par_encrypt_ggsw_ciphertext(
output_glwe_secret_key,
&mut ggsw,
Plaintext(input_key_element),
noise_parameters,
&mut generator,
);
})
}
#[cfg(test)]
mod test {
use crate::core_crypto::algorithms::generate_lwe_bootstrap_key;
@@ -173,3 +239,117 @@ mod test {
test_refactored_bsk_equivalence::<u64>()
}
}
#[cfg(feature = "__commons_parallel")]
#[cfg(test)]
mod parallel_test {
use crate::core_crypto::algorithms::{
allocate_and_generate_new_binary_glwe_secret_key,
allocate_and_generate_new_binary_lwe_secret_key, generate_lwe_bootstrap_key,
par_generate_lwe_bootstrap_key,
};
use crate::core_crypto::commons::crypto::secret::generators::{
DeterministicSeeder, EncryptionRandomGenerator,
};
use crate::core_crypto::commons::math::random::Seed;
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::test_tools::new_secret_random_generator;
use crate::core_crypto::entities::LweBootstrapKey;
use crate::core_crypto::prelude::{
DecompositionBaseLog, DecompositionLevelCount, GlweDimension, LweDimension, PolynomialSize,
StandardDev,
};
use concrete_csprng::generators::SoftwareRandomGenerator;
fn test_refactored_bsk_parallel_gen_equivalence<T: UnsignedTorus + Send + Sync>() {
for _ in 0..10 {
let lwe_dim =
LweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
let glwe_dim =
GlweDimension(crate::core_crypto::commons::test_tools::random_usize_between(5..10));
let poly_size = PolynomialSize(
crate::core_crypto::commons::test_tools::random_usize_between(5..10),
);
let level = DecompositionLevelCount(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let base_log = DecompositionBaseLog(
crate::core_crypto::commons::test_tools::random_usize_between(2..5),
);
let mask_seed = Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
let deterministic_seeder_seed =
Seed(crate::core_crypto::commons::test_tools::any_usize() as u128);
let mut secret_generator = new_secret_random_generator();
let lwe_sk =
allocate_and_generate_new_binary_lwe_secret_key(lwe_dim, &mut secret_generator);
let glwe_sk = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dim,
poly_size,
&mut secret_generator,
);
let mut parallel_bsk = LweBootstrapKey::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
);
let mut encryption_generator =
EncryptionRandomGenerator::<SoftwareRandomGenerator>::new(
mask_seed,
&mut DeterministicSeeder::<SoftwareRandomGenerator>::new(
deterministic_seeder_seed,
),
);
par_generate_lwe_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut parallel_bsk,
StandardDev::from_standard_dev(10.),
&mut encryption_generator,
);
let mut sequential_bsk = LweBootstrapKey::new(
T::ZERO,
glwe_dim.to_glwe_size(),
poly_size,
base_log,
level,
lwe_dim,
);
let mut encryption_generator =
EncryptionRandomGenerator::<SoftwareRandomGenerator>::new(
mask_seed,
&mut DeterministicSeeder::<SoftwareRandomGenerator>::new(
deterministic_seeder_seed,
),
);
generate_lwe_bootstrap_key(
&lwe_sk,
&glwe_sk,
&mut sequential_bsk,
StandardDev::from_standard_dev(10.),
&mut encryption_generator,
);
assert_eq!(parallel_bsk.as_ref(), sequential_bsk.as_ref());
}
}
#[test]
fn test_refactored_bsk_parallel_gen_equivalence_u32() {
test_refactored_bsk_parallel_gen_equivalence::<u32>()
}
#[test]
fn test_refactored_bsk_parallel_gen_equivalence_u64() {
test_refactored_bsk_parallel_gen_equivalence::<u64>()
}
}

View File

@@ -1,4 +1,6 @@
use super::create_from::*;
#[cfg(feature = "__commons_parallel")]
use rayon::prelude::*;
type WrappingFunction<'data, Element, WrappingType> = fn(
(
@@ -15,6 +17,15 @@ type WrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map<
WrappingFunction<'data, Element, WrappingType>,
>;
#[cfg(feature = "__commons_parallel")]
type ParallelWrappingLendingIterator<'data, Element, WrappingType> = rayon::iter::Map<
rayon::iter::Zip<
rayon::slice::Chunks<'data, Element>,
rayon::iter::RepeatN<<WrappingType as CreateFrom<&'data [Element]>>::Metadata>,
>,
WrappingFunction<'data, Element, WrappingType>,
>;
// This is required as at the moment it's not possible to reverse a zip containing a repeat, though
// it is perfectly legal to zip a reversed repeat
type RevWrappingLendingIterator<'data, Element, WrappingType> = std::iter::Map<
@@ -40,6 +51,15 @@ type WrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map<
WrappingFunctionMut<'data, Element, WrappingType>,
>;
#[cfg(feature = "__commons_parallel")]
type ParallelWrappingLendingIteratorMut<'data, Element, WrappingType> = rayon::iter::Map<
rayon::iter::Zip<
rayon::slice::ChunksMut<'data, Element>,
rayon::iter::RepeatN<<WrappingType as CreateFrom<&'data mut [Element]>>::Metadata>,
>,
WrappingFunctionMut<'data, Element, WrappingType>,
>;
// This is required as at the moment it's not possible to reverse a zip containing a repeat, though
// it is perfectly legal to zip a reversed repeat
type RevWrappingLendingIteratorMut<'data, Element, WrappingType> = std::iter::Map<
@@ -121,6 +141,24 @@ pub trait ContiguousEntityContainer: AsRef<[Self::Element]> {
Self::EntityView::<'_>::create_from(&self.as_ref()[start..stop], meta)
}
#[cfg(feature = "__commons_parallel")]
fn par_iter<'this>(
&'this self,
) -> ParallelWrappingLendingIterator<'this, Self::Element, Self::EntityView<'this>>
where
Self::Element: Sync,
Self::EntityView<'this>: Send,
Self::EntityViewMetadata: Send,
{
let meta = self.get_entity_view_creation_metadata();
let entity_view_pod_size = self.get_entity_view_pod_size();
let entity_count = self.as_ref().len() / entity_view_pod_size;
self.as_ref()
.par_chunks(entity_view_pod_size)
.zip(rayon::iter::repeatn(meta, entity_count))
.map(|(elt, meta)| Self::EntityView::<'this>::create_from(elt, meta))
}
}
pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self::Element]> {
@@ -188,4 +226,22 @@ pub trait ContiguousEntityContainerMut: ContiguousEntityContainer + AsMut<[Self:
Self::EntityMutView::<'_>::create_from(&mut self.as_mut()[start..stop], meta)
}
#[cfg(feature = "__commons_parallel")]
fn par_iter_mut<'this>(
&'this mut self,
) -> ParallelWrappingLendingIteratorMut<'this, Self::Element, Self::EntityMutView<'this>>
where
Self::Element: Send + Sync,
Self::EntityMutView<'this>: Send,
Self::EntityViewMetadata: Send,
{
let meta = self.get_entity_view_creation_metadata();
let entity_view_pod_size = self.get_entity_view_pod_size();
let entity_count = self.as_ref().len() / entity_view_pod_size;
self.as_mut()
.par_chunks_mut(entity_view_pod_size)
.zip(rayon::iter::repeatn(meta, entity_count))
.map(|(elt, meta)| Self::EntityMutView::<'this>::create_from(elt, meta))
}
}