feat(core): add multi-bit BSK generation and PBS threaded implementation

This commit is contained in:
Arthur Meyre
2023-02-08 15:58:52 +01:00
parent bf6f699e8c
commit 75f05c0f3a
19 changed files with 2433 additions and 21 deletions

View File

@@ -123,6 +123,12 @@ path = "benches/core_crypto/pbs_bench.rs"
harness = false
required-features = ["boolean", "shortint", "internal-keycache"]
[[bench]]
name = "dev-bench"
path = "benches/core_crypto/dev_bench.rs"
harness = false
required-features = ["boolean", "shortint", "internal-keycache"]
[[bench]]
name = "boolean-bench"
path = "benches/boolean/bench.rs"

View File

@@ -0,0 +1,310 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use tfhe::core_crypto::prelude::*;
criterion_group!(
boolean_like_pbs_group,
multi_bit_pbs::<u32>,
pbs::<u32>,
mem_optimized_pbs::<u32>
);
criterion_group!(
shortint_like_pbs_group,
multi_bit_pbs::<u64>,
pbs::<u64>,
mem_optimized_pbs::<u64>
);
criterion_main!(boolean_like_pbs_group, shortint_like_pbs_group);
fn get_bench_params<Scalar: Numeric>() -> (
LweDimension,
StandardDev,
DecompositionBaseLog,
DecompositionLevelCount,
GlweDimension,
PolynomialSize,
LweBskGroupingFactor,
ThreadCount,
) {
if Scalar::BITS == 64 {
(
LweDimension(742),
StandardDev(0.000007069849454709433),
DecompositionBaseLog(3),
DecompositionLevelCount(5),
GlweDimension(1),
PolynomialSize(1024),
LweBskGroupingFactor(2),
ThreadCount(5),
)
} else if Scalar::BITS == 32 {
(
LweDimension(778),
StandardDev(0.000003725679281679651),
DecompositionBaseLog(18),
DecompositionLevelCount(1),
GlweDimension(3),
PolynomialSize(512),
LweBskGroupingFactor(2),
ThreadCount(5),
)
} else {
unreachable!()
}
}
fn multi_bit_pbs<Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync>(
c: &mut Criterion,
) {
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
// computations
// Define parameters for LweBootstrapKey creation
let (
mut input_lwe_dimension,
lwe_modular_std_dev,
decomp_base_log,
decomp_level_count,
glwe_dimension,
polynomial_size,
grouping_factor,
thread_count,
) = get_bench_params::<Scalar>();
while input_lwe_dimension.0 % grouping_factor.0 != 0 {
input_lwe_dimension = LweDimension(input_lwe_dimension.0 + 1);
}
// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
// Create the LweSecretKey
let input_lwe_secret_key =
allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
let output_glwe_secret_key: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key();
let multi_bit_bsk = FourierLweMultiBitBootstrapKey::new(
input_lwe_dimension,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
grouping_factor,
);
// Allocate a new LweCiphertext and encrypt our plaintext
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
Plaintext(Scalar::ZERO),
lwe_modular_std_dev,
&mut encryption_generator,
);
let accumulator =
GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size);
// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertext::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
);
let id = format!("Multi Bit PBS {}", Scalar::BITS);
#[allow(clippy::unit_arg)]
{
c.bench_function(&id, |b| {
b.iter(|| {
multi_bit_programmable_bootstrap_lwe_ciphertext(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&multi_bit_bsk,
thread_count,
);
black_box(&mut out_pbs_ct);
})
});
}
}
fn pbs<Scalar: UnsignedTorus + CastInto<usize>>(c: &mut Criterion) {
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
// computations
// Define parameters for LweBootstrapKey creation
let (
input_lwe_dimension,
lwe_modular_std_dev,
decomp_base_log,
decomp_level_count,
glwe_dimension,
polynomial_size,
_,
_,
) = get_bench_params::<Scalar>();
// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
// Create the LweSecretKey
let input_lwe_secret_key =
allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
let output_glwe_secret_key: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key();
// Create the empty bootstrapping key in the Fourier domain
let fourier_bsk = FourierLweBootstrapKey::new(
input_lwe_dimension,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
);
// Allocate a new LweCiphertext and encrypt our plaintext
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
Plaintext(Scalar::ZERO),
lwe_modular_std_dev,
&mut encryption_generator,
);
let accumulator =
GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size);
// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertext::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
);
let id = format!("PBS {}", Scalar::BITS);
{
c.bench_function(&id, |b| {
b.iter(|| {
programmable_bootstrap_lwe_ciphertext(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&fourier_bsk,
);
black_box(&mut out_pbs_ct);
})
});
}
}
fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize>>(c: &mut Criterion) {
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
// computations
// Define parameters for LweBootstrapKey creation
let (
input_lwe_dimension,
lwe_modular_std_dev,
decomp_base_log,
decomp_level_count,
glwe_dimension,
polynomial_size,
_,
_,
) = get_bench_params::<Scalar>();
// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
// Create the LweSecretKey
let input_lwe_secret_key =
allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
let output_glwe_secret_key: GlweSecretKeyOwned<Scalar> =
allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key();
// Create the empty bootstrapping key in the Fourier domain
let fourier_bsk = FourierLweBootstrapKey::new(
input_lwe_dimension,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
);
// Allocate a new LweCiphertext and encrypt our plaintext
let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
Plaintext(Scalar::ZERO),
lwe_modular_std_dev,
&mut encryption_generator,
);
let accumulator =
GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size);
// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct = LweCiphertext::new(
Scalar::ZERO,
output_lwe_secret_key.lwe_dimension().to_lwe_size(),
);
let mut buffers = ComputationBuffers::new();
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<Scalar>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let id = format!("PBS mem-optimized {}", Scalar::BITS);
{
c.bench_function(&id, |b| {
b.iter(|| {
programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&fourier_bsk,
fft,
buffers.stack(),
);
black_box(&mut out_pbs_ct);
})
});
}
}

View File

@@ -160,18 +160,18 @@ fn mem_optimized_pbs<Scalar: UnsignedTorus + CastInto<usize>>(c: &mut Criterion)
);
let id = format!("PBS_mem-optimized_{name}");
#[allow(clippy::unit_arg)]
{
c.bench_function(&id, |b| {
b.iter(|| {
black_box(programmable_bootstrap_lwe_ciphertext_mem_optimized(
programmable_bootstrap_lwe_ciphertext_mem_optimized(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&fourier_bsk,
fft,
buffers.stack(),
))
);
black_box(&mut out_pbs_ct);
})
});
}

View File

@@ -115,7 +115,10 @@ pub fn extract_lwe_sample_from_glwe_ciphertext<Scalar, InputCont, OutputCont>(
let opposite_count = input_glwe.polynomial_size().0 - nth.0 - 1;
// We loop through the polynomials
for lwe_mask_poly in lwe_mask.as_mut().chunks_mut(input_glwe.polynomial_size().0) {
for lwe_mask_poly in lwe_mask
.as_mut()
.chunks_exact_mut(input_glwe.polynomial_size().0)
{
// We reverse the polynomial
lwe_mask_poly.reverse();
// We compute the opposite of the proper coefficients

View File

@@ -0,0 +1,84 @@
//! Module containing primitives pertaining to the conversion of
//! [`standard LWE multi_bit bootstrap keys`](`LweMultiBitBootstrapKey`) to various
//! representations/numerical domains like the Fourier domain.
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::math::fft::{Fft, FftView};
use concrete_fft::c64;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
/// Convert an [`LWE multi_bit bootstrap key`](`LweMultiBitBootstrapKey`) with standard
/// coefficients to the Fourier domain.
///
/// See [`multi_bit_programmable_bootstrap_lwe_ciphertext`](`crate::core_crypto::algorithms::multi_bit_programmable_bootstrap_lwe_ciphertext`) for usage.
pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier<Scalar, InputCont, OutputCont>(
input_bsk: &LweMultiBitBootstrapKey<InputCont>,
output_bsk: &mut FourierLweMultiBitBootstrapKey<OutputCont>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = c64>,
{
let mut buffers = ComputationBuffers::new();
let fft = Fft::new(input_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(fft)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized(
input_bsk, output_bsk, fft, stack,
);
}
/// Memory optimized version of [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier`].
pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized<
Scalar,
InputCont,
OutputCont,
>(
input_bsk: &LweMultiBitBootstrapKey<InputCont>,
output_bsk: &mut FourierLweMultiBitBootstrapKey<OutputCont>,
fft: FftView<'_>,
mut stack: DynStack<'_>,
) where
Scalar: UnsignedTorus,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = c64>,
{
let mut output_bsk_as_polynomial_list = output_bsk.as_mut_polynomial_list();
let input_bsk_as_polynomial_list = input_bsk.as_polynomial_list();
assert_eq!(
output_bsk_as_polynomial_list.polynomial_count(),
input_bsk_as_polynomial_list.polynomial_count()
);
for (fourier_poly, coef_poly) in output_bsk_as_polynomial_list
.iter_mut()
.zip(input_bsk_as_polynomial_list.iter())
{
// SAFETY: forward_as_torus doesn't write any uninitialized values into its output
fft.forward_as_torus(
unsafe { fourier_poly.into_uninit() },
coef_poly,
stack.rb_mut(),
);
}
}
/// Return the required memory for
/// [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized`].
pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(
fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
fft.forward_scratch()
}

View File

@@ -0,0 +1,442 @@
//! Module containing primitives pertaining to the generation of
//! [`standard LWE multi_bit bootstrap keys`](`LweMultiBitBootstrapKey`).
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::dispersion::DispersionParameter;
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweBootstrapKey creation
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let glwe_dimension = GlweDimension(1);
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let grouping_factor = LweBskGroupingFactor(2);
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
/// glwe_dimension,
/// polynomial_size,
/// &mut secret_generator,
/// );
///
/// let mut bsk = LweMultiBitBootstrapKey::new(
/// 0u64,
/// glwe_dimension.to_glwe_size(),
/// polynomial_size,
/// decomp_base_log,
/// decomp_level_count,
/// input_lwe_dimension,
/// grouping_factor,
/// );
///
/// generate_lwe_multi_bit_bootstrap_key(
/// &input_lwe_secret_key,
/// &output_glwe_secret_key,
/// &mut bsk,
/// glwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
/// ) {
/// for (bit_inversion_idx, ggsw) in ggsw_group.iter().enumerate() {
/// let mut key_bits_plaintext = 1u64;
/// for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() {
/// let bit_position = input_key_elements.len() - (bit_idx + 1);
/// let inversion_bit = (((bit_inversion_idx >> bit_position) & 1) ^ 1) as u64;
/// let key_bit = key_bit ^ inversion_bit;
/// key_bits_plaintext *= key_bit;
/// }
///
/// let decrypted_ggsw = decrypt_constant_ggsw_ciphertext(&output_glwe_secret_key, &ggsw);
/// assert_eq!(decrypted_ggsw.0, key_bits_plaintext)
/// }
/// }
/// ```
pub fn generate_lwe_multi_bit_bootstrap_key<Scalar, InputKeyCont, OutputKeyCont, OutputCont, Gen>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
output: &mut LweMultiBitBootstrapKey<OutputCont>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus + CastFrom<usize>,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert!(
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
input_lwe_secret_key.lwe_dimension(),
output.input_lwe_dimension()
);
assert!(
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output.glwe_size()
);
assert!(
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
output_glwe_secret_key.polynomial_size(),
output.polynomial_size()
);
let gen_iter = generator
.fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
output.input_lwe_dimension(),
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
output.grouping_factor(),
)
.unwrap();
let grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
for ((mut ggsw_group, input_key_elements), mut loop_generator) in output
.chunks_exact_mut(ggsw_per_multi_bit_element.0)
.zip(
input_lwe_secret_key
.as_ref()
.chunks_exact(grouping_factor.0),
)
.zip(gen_iter)
{
let gen_iter = loop_generator.fork_n(ggsw_per_multi_bit_element.0).unwrap();
for ((bit_inversion_idx, mut ggsw), mut inner_loop_generator) in
ggsw_group.iter_mut().enumerate().zip(gen_iter)
{
// Use the index of the ggsw as a way to know which bit to invert
let key_bits_plaintext = combine_key_bits(bit_inversion_idx, input_key_elements);
encrypt_constant_ggsw_ciphertext(
output_glwe_secret_key,
&mut ggsw,
Plaintext(key_bits_plaintext),
noise_parameters,
&mut inner_loop_generator,
);
}
}
}
pub fn allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
Gen,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) -> LweMultiBitBootstrapKeyOwned<Scalar>
where
Scalar: UnsignedTorus + CastFrom<usize>,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
Gen: ByteRandomGenerator,
{
let mut bsk = LweMultiBitBootstrapKeyOwned::new(
Scalar::ZERO,
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output_glwe_secret_key.polynomial_size(),
decomp_base_log,
decomp_level_count,
input_lwe_secret_key.lwe_dimension(),
grouping_factor,
);
generate_lwe_multi_bit_bootstrap_key(
input_lwe_secret_key,
output_glwe_secret_key,
&mut bsk,
noise_parameters,
generator,
);
bsk
}
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweBootstrapKey creation
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let glwe_dimension = GlweDimension(1);
/// let polynomial_size = PolynomialSize(1024);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let grouping_factor = LweBskGroupingFactor(2);
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
/// glwe_dimension,
/// polynomial_size,
/// &mut secret_generator,
/// );
///
/// let mut bsk = LweMultiBitBootstrapKey::new(
/// 0u64,
/// glwe_dimension.to_glwe_size(),
/// polynomial_size,
/// decomp_base_log,
/// decomp_level_count,
/// input_lwe_dimension,
/// grouping_factor,
/// );
///
/// par_generate_lwe_multi_bit_bootstrap_key(
/// &input_lwe_secret_key,
/// &output_glwe_secret_key,
/// &mut bsk,
/// glwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new(
/// input_lwe_dimension,
/// glwe_dimension.to_glwe_size(),
/// polynomial_size,
/// decomp_base_log,
/// decomp_level_count,
/// grouping_factor,
/// );
///
/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
///
/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
///
/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip(
/// input_lwe_secret_key
/// .as_ref()
/// .chunks_exact(grouping_factor.0),
/// ) {
/// for (bit_inversion_idx, ggsw) in ggsw_group.iter().enumerate() {
/// let mut key_bits_plaintext = 1u64;
/// for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() {
/// let bit_position = input_key_elements.len() - (bit_idx + 1);
/// let inversion_bit = (((bit_inversion_idx >> bit_position) & 1) ^ 1) as u64;
/// let key_bit = key_bit ^ inversion_bit;
/// key_bits_plaintext *= key_bit;
/// }
/// let decrypted_ggsw = decrypt_constant_ggsw_ciphertext(&output_glwe_secret_key, &ggsw);
/// assert_eq!(decrypted_ggsw.0, key_bits_plaintext)
/// }
/// }
/// ```
pub fn par_generate_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
OutputCont,
Gen,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
output: &mut LweMultiBitBootstrapKey<OutputCont>,
noise_parameters: impl DispersionParameter + Sync,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar> + Sync,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ParallelByteRandomGenerator,
{
assert!(
output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(),
"Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \
Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.",
input_lwe_secret_key.lwe_dimension(),
output.input_lwe_dimension()
);
assert!(
output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(),
"Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.",
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output.glwe_size()
);
assert!(
output.polynomial_size() == output_glwe_secret_key.polynomial_size(),
"Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \
Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.",
output_glwe_secret_key.polynomial_size(),
output.polynomial_size()
);
let gen_iter = generator
.par_fork_multi_bit_bsk_to_ggsw_group::<Scalar>(
output.input_lwe_dimension(),
output.decomposition_level_count(),
output.glwe_size(),
output.polynomial_size(),
output.grouping_factor(),
)
.unwrap();
let grouping_factor = output.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
output
.par_iter_mut()
.chunks(ggsw_per_multi_bit_element.0)
.zip(
input_lwe_secret_key
.as_ref()
.par_chunks_exact(grouping_factor.0),
)
.zip(gen_iter)
.for_each(
|((mut ggsw_group, input_key_elements), mut loop_generator)| {
let gen_iter = loop_generator
.par_fork_n(ggsw_per_multi_bit_element.0)
.unwrap();
ggsw_group
.par_iter_mut()
.enumerate()
.zip(gen_iter)
.for_each(|((bit_inversion_idx, ggsw), mut inner_loop_generator)| {
// Use the index of the ggsw as a way to know which bit to invert
let key_bits_plaintext =
combine_key_bits(bit_inversion_idx, input_key_elements);
par_encrypt_constant_ggsw_ciphertext(
output_glwe_secret_key,
ggsw,
Plaintext(key_bits_plaintext),
noise_parameters,
&mut inner_loop_generator,
);
});
},
);
}
fn combine_key_bits<Scalar>(bit_selector: usize, input_key_elements: &[Scalar]) -> Scalar
where
Scalar: UnsignedInteger + CastFrom<usize>,
{
// Use a bit_selector (in practice the ggsw index) as a way to know which bit to invert, the
// counter goes from e.g. 0 to 4 or 00, 01, 10 and 11 in binary, we use those bits to know which
// key bit to invert in our product, also we invert the bit once more to be sure that the first
// term is the GGSW encrypting a constant polynomial (and not a monomial), allowing to copy it
// in the multi_bit PBS routine and computing polynomial products on the rest of the terms.
// We compute products, initialize the combined key bits to 1
let mut key_bits_plaintext = Scalar::ONE;
for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() {
// Get the position of the bit we will check in bit_selector
let bit_position = input_key_elements.len() - (bit_idx + 1);
// Get the bit, invert it to have the first combined GGSW correspond to
// the constant polynomial, i.e. we generate
// first GGSW((1 - s_{i-1}) * (1 - s_i)) up
// to GGSW(s_{i-1} * s_{i})
let inversion_bit: Scalar = Scalar::cast_from(((bit_selector >> bit_position) & 1) ^ 1);
// Invert the key_bit depending on the computed inversion_bit
let key_bit = key_bit ^ inversion_bit;
// Multiply the accumulator by the key_bit we need to combine it with
key_bits_plaintext = key_bits_plaintext.wrapping_mul(key_bit);
}
key_bits_plaintext
}
pub fn par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key<
Scalar,
InputKeyCont,
OutputKeyCont,
Gen,
>(
input_lwe_secret_key: &LweSecretKey<InputKeyCont>,
output_glwe_secret_key: &GlweSecretKey<OutputKeyCont>,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
noise_parameters: impl DispersionParameter + Sync,
generator: &mut EncryptionRandomGenerator<Gen>,
) -> LweMultiBitBootstrapKeyOwned<Scalar>
where
Scalar: UnsignedTorus + CastFrom<usize> + Sync + Send,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar> + Sync,
Gen: ParallelByteRandomGenerator,
{
let mut bsk = LweMultiBitBootstrapKeyOwned::new(
Scalar::ZERO,
output_glwe_secret_key.glwe_dimension().to_glwe_size(),
output_glwe_secret_key.polynomial_size(),
decomp_base_log,
decomp_level_count,
input_lwe_secret_key.lwe_dimension(),
grouping_factor,
);
par_generate_lwe_multi_bit_bootstrap_key(
input_lwe_secret_key,
output_glwe_secret_key,
&mut bsk,
noise_parameters,
generator,
);
bsk
}

View File

@@ -0,0 +1,945 @@
use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext;
use crate::core_crypto::algorithms::polynomial_algorithms::*;
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::as_mut_uninit;
use crate::core_crypto::fft_impl::crypto::bootstrap::pbs_modulus_switch;
use crate::core_crypto::fft_impl::crypto::ggsw::{
add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd,
};
use crate::core_crypto::fft_impl::math::fft::Fft;
use concrete_fft::c64;
use std::sync::{mpsc, Condvar, Mutex};
use std::thread;
/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up
/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain.
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define the parameters for a 4 bits message able to hold the doubled 2 bits message
/// let small_lwe_dimension = LweDimension(742);
/// let glwe_dimension = GlweDimension(1);
/// let polynomial_size = PolynomialSize(2048);
/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
/// // /dev/random on Unix systems if enabled via cargo features
/// let mut boxed_seeder = new_seeder();
/// // Get a mutable reference to the seeder as a trait object from the Box returned by new_seeder
/// let seeder = boxed_seeder.as_mut();
///
/// // Create a generator which uses a CSPRNG to generate secret keys
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create a generator which uses two CSPRNGs to generate public masks and secret encryption
/// // noise
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
///
/// println!("Generating keys...");
///
/// // Generate an LweSecretKey with binary coefficients
/// let small_lwe_sk =
/// LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator);
///
/// // Generate a GlweSecretKey with binary coefficients
/// let glwe_sk =
/// GlweSecretKey::generate_new_binary(glwe_dimension, polynomial_size, &mut secret_generator);
///
/// // Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey
/// let big_lwe_sk = glwe_sk.clone().into_lwe_secret_key();
///
/// let mut bsk = LweMultiBitBootstrapKey::new(
/// 0u64,
/// glwe_dimension.to_glwe_size(),
/// polynomial_size,
/// pbs_base_log,
/// pbs_level,
/// small_lwe_dimension,
/// grouping_factor,
/// );
///
/// par_generate_lwe_multi_bit_bootstrap_key(
/// &small_lwe_sk,
/// &glwe_sk,
/// &mut bsk,
/// glwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new(
/// bsk.input_lwe_dimension(),
/// bsk.glwe_size(),
/// bsk.polynomial_size(),
/// bsk.decomposition_base_log(),
/// bsk.decomposition_level_count(),
/// bsk.grouping_factor(),
/// );
///
/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
///
/// // We don't need the standard bootstrapping key anymore
/// drop(bsk);
///
/// // Our 4 bits message space
/// let message_modulus = 1u64 << 4;
///
/// // Our input message
/// let input_message = 3u64;
///
/// // Delta used to encode 4 bits of message + a bit of padding on u64
/// let delta = (1_u64 << 63) / message_modulus;
///
/// // Apply our encoding
/// let plaintext = Plaintext(input_message * delta);
///
/// // Allocate a new LweCiphertext and encrypt our plaintext
/// let lwe_ciphertext_in: LweCiphertextOwned<u64> = allocate_and_encrypt_new_lwe_ciphertext(
/// &small_lwe_sk,
/// plaintext,
/// lwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// // Now we will use a PBS to compute a multiplication by 2, it is NOT the recommended way of
/// // doing this operation in terms of performance as it's much more costly than a multiplication
/// // with a cleartext, however it resets the noise in a ciphertext to a nominal level and allows
/// // to evaluate arbitrary functions so depending on your use case it can be a better fit.
///
/// // Here we will define a helper function to generate an accumulator for a PBS
/// fn generate_accumulator<F>(
/// polynomial_size: PolynomialSize,
/// glwe_size: GlweSize,
/// message_modulus: usize,
/// delta: u64,
/// f: F,
/// ) -> GlweCiphertextOwned<u64>
/// where
/// F: Fn(u64) -> u64,
/// {
/// // N/(p/2) = size of each block, to correct noise from the input we introduce the notion of
/// // box, which manages redundancy to yield a denoised value for several noisy values around
/// // a true input value.
/// let box_size = polynomial_size.0 / message_modulus;
///
/// // Create the accumulator
/// let mut accumulator_u64 = vec![0_u64; polynomial_size.0];
///
/// // Fill each box with the encoded denoised value
/// for i in 0..message_modulus {
/// let index = i * box_size;
/// accumulator_u64[index..index + box_size]
/// .iter_mut()
/// .for_each(|a| *a = f(i as u64) * delta);
/// }
///
/// let half_box_size = box_size / 2;
///
/// // Negate the first half_box_size coefficients to manage negacyclicity and rotate
/// for a_i in accumulator_u64[0..half_box_size].iter_mut() {
/// *a_i = (*a_i).wrapping_neg();
/// }
///
/// // Rotate the accumulator
/// accumulator_u64.rotate_left(half_box_size);
///
/// let accumulator_plaintext = PlaintextList::from_container(accumulator_u64);
///
/// let accumulator =
/// allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext);
///
/// accumulator
/// }
///
/// // Generate the accumulator for our multiplication by 2 using a simple closure
/// let mut accumulator: GlweCiphertextOwned<u64> = generate_accumulator(
/// polynomial_size,
/// glwe_dimension.to_glwe_size(),
/// message_modulus as usize,
/// delta,
/// |x: u64| 2 * x,
/// );
///
/// // Allocate the LweCiphertext to store the result of the PBS
/// let mut pbs_multiplication_ct =
/// LweCiphertext::new(0u64, big_lwe_sk.lwe_dimension().to_lwe_size());
/// println!("Performing blind rotation...");
/// // Use 4 threads for the multi-bit blind rotation for example
/// multi_bit_blind_rotate_assign(
/// &lwe_ciphertext_in,
/// &mut accumulator,
/// &multi_bit_bsk,
/// ThreadCount(4),
/// );
/// println!("Performing sample extraction...");
/// extract_lwe_sample_from_glwe_ciphertext(
/// &accumulator,
/// &mut pbs_multiplication_ct,
/// MonomialDegree(0),
/// );
///
/// // Decrypt the PBS multiplication result
/// let pbs_multipliation_plaintext: Plaintext<u64> =
/// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct);
///
/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext
/// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to
/// // round the 5 MSB, 1 bit of padding plus our 4 bits of message
/// let signed_decomposer =
/// SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1));
///
/// // Round and remove our encoding
/// let pbs_multiplication_result: u64 =
/// signed_decomposer.closest_representable(pbs_multipliation_plaintext.0) / delta;
///
/// println!("Checking result...");
/// assert_eq!(6, pbs_multiplication_result);
/// println!(
/// "Mulitplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
input: &LweCiphertext<InputCont>,
accumulator: &mut GlweCiphertext<OutputCont>,
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
KeyCont: Container<Element = c64> + Sync,
{
assert_eq!(
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
);
assert_eq!(
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
);
assert_eq!(
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
);
let (lwe_mask, lwe_body) = input.get_mask_and_body();
// No way to chunk the result of ggsw_iter at the moment
let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect();
let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0);
let grouping_factor = multi_bit_bsk.grouping_factor();
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
for (lwe_mask_elements, ggsw_group) in lwe_mask
.as_ref()
.chunks_exact(grouping_factor.0)
.zip(ggsw_vec.chunks_exact(ggsw_per_multi_bit_element.0))
{
work_queue.push((lwe_mask_elements, ggsw_group));
}
assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0);
let work_queue = Mutex::new(work_queue);
// Each producer thread works in a dedicated slot of the buffer
let thread_buffers: usize = thread_count.0;
let lut_poly_size = accumulator.polynomial_size();
let monomial_degree = pbs_modulus_switch(
*lwe_body.0,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
// Modulus switching
accumulator
.as_mut_polynomial_list()
.iter_mut()
.for_each(|mut poly| {
polynomial_wrapping_monic_monomial_div_assign(
&mut poly,
MonomialDegree(monomial_degree),
)
});
let fourier_multi_bit_ggsw_buffers = (0..thread_buffers)
.map(|_| {
(
Mutex::new(false),
Condvar::new(),
Mutex::new(FourierGgswCiphertext::new(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
multi_bit_bsk.decomposition_base_log(),
multi_bit_bsk.decomposition_level_count(),
)),
)
})
.collect::<Vec<_>>();
let (tx, rx) = mpsc::channel::<usize>();
let fft = Fft::new(multi_bit_bsk.polynomial_size());
let fft = fft.as_view();
thread::scope(|s| {
let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender<usize>| {
let mut buffers = ComputationBuffers::new();
buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required());
let mut unit_polynomial =
Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size());
unit_polynomial.as_mut()[0] = Scalar::ONE;
let mut a_monomial = unit_polynomial.clone();
let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size());
let work_queue = &work_queue;
let dest_idx = thread_id;
let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) =
&fourier_multi_bit_ggsw_buffers[dest_idx];
loop {
let maybe_work = {
let mut queue_lock = work_queue.lock().unwrap();
queue_lock.pop()
};
let Some((lwe_mask_elements, ggsw_group)) = maybe_work else {break};
let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap();
// Wait while the buffer is not ready for processing and wait on the condvar
// to get notified when we can start processing again
while *ready_for_consumer {
ready_for_consumer = condvar.wait(ready_for_consumer).unwrap();
}
let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap();
let mut ggsw_group_iter = ggsw_group.iter();
// Keygen guarantees the first term is a constant term of the polynomial, no
// polynomial multiplication required
let ggsw_a_none = ggsw_group_iter.next().unwrap();
fourier_ggsw_buffer
.as_mut_view()
.data()
.copy_from_slice(ggsw_a_none.as_view().data());
let multi_bit_fourier_ggsw =
unsafe { as_mut_uninit(fourier_ggsw_buffer.as_mut_view().data()) };
for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() {
// We already processed the first ggsw, advance the index by 1
let ggsw_idx = ggsw_idx + 1;
// Select the proper mask elements to build the monomial degree depending on
// the order the GGSW were generated in, using the bits from mask_idx and
// ggsw_idx as selector bits
let mut monomial_degree = Scalar::ZERO;
for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() {
let mask_position = lwe_mask_elements.len() - (mask_idx + 1);
let selection_bit: Scalar =
Scalar::cast_from((ggsw_idx >> mask_position) & 1);
monomial_degree =
monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element));
}
let switched_degree = pbs_modulus_switch(
monomial_degree,
lut_poly_size,
ModulusSwitchOffset(0),
LutCountLog(0),
);
a_monomial
.as_mut()
.copy_from_slice(unit_polynomial.as_ref());
polynomial_wrapping_monic_monomial_mul_assign(
&mut a_monomial,
MonomialDegree(switched_degree),
);
fft.forward_as_integer(
unsafe { fourier_a_monomial.as_mut_view().into_uninit() },
a_monomial.as_view(),
buffers.stack(),
);
unsafe {
update_with_fmadd(
multi_bit_fourier_ggsw,
fourier_ggsw.as_view().data(),
fourier_a_monomial.as_view().data,
false,
lut_poly_size.to_fourier_polynomial_size().0,
);
}
}
// Drop the lock before we wake other threads
drop(fourier_ggsw_buffer);
*ready_for_consumer = true;
tx.send(dest_idx).unwrap();
// Wake threads waiting on the condvar
condvar.notify_all();
}
};
let threads: Vec<_> = (0..thread_count.0)
.map(|id| {
let tx = tx.clone();
s.spawn(move || produce_multi_bit_fourier_ggsw(id, tx))
})
.collect();
// We initialize ct0 for the successive external products
let ct0 = accumulator;
let mut ct1 = GlweCiphertext::new(Scalar::ZERO, ct0.glwe_size(), ct0.polynomial_size());
let ct1 = &mut ct1;
let mut buffers = ComputationBuffers::new();
buffers.resize(
add_external_product_assign_scratch::<Scalar>(
multi_bit_bsk.glwe_size(),
multi_bit_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let mut src_idx = 1usize;
for _ in 0..multi_bit_bsk.multi_bit_input_lwe_dimension().0 {
src_idx ^= 1;
let idx = rx.recv().unwrap();
let (ready_lock, condvar, multi_bit_fourier_ggsw) =
&fourier_multi_bit_ggsw_buffers[idx];
let (src_ct, mut dst_ct) = if src_idx == 0 {
(ct0.as_view(), ct1.as_mut_view())
} else {
(ct1.as_view(), ct0.as_mut_view())
};
dst_ct.as_mut().fill(Scalar::ZERO);
let mut ready = ready_lock.lock().unwrap();
assert!(*ready);
let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap();
add_external_product_assign(
dst_ct,
multi_bit_fourier_ggsw.as_view(),
src_ct,
fft,
buffers.stack(),
);
drop(multi_bit_fourier_ggsw);
*ready = false;
// Wake a single producer thread sleeping on the condvar (only one will get to work
// anyways)
condvar.notify_one();
}
if src_idx == 0 {
ct0.as_mut().copy_from_slice(ct1.as_ref());
}
threads.into_iter().for_each(|t| t.join().unwrap());
});
}
/// Perform a programmable bootsrap with given an input [`LWE ciphertext`](`LweCiphertext`), a
/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE multi-bit bootstrap
/// key`](`LweMultiBitBootstrapKey`) in the fourier domain. The result is written in the provided
/// output [`LWE ciphertext`](`LweCiphertext`).
///
/// # Example
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define the parameters for a 4 bits message able to hold the doubled 2 bits message
/// let small_lwe_dimension = LweDimension(742);
/// let glwe_dimension = GlweDimension(1);
/// let polynomial_size = PolynomialSize(2048);
/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
/// let pbs_base_log = DecompositionBaseLog(23);
/// let pbs_level = DecompositionLevelCount(1);
/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs
///
/// // Request the best seeder possible, starting with hardware entropy sources and falling back to
/// // /dev/random on Unix systems if enabled via cargo features
/// let mut boxed_seeder = new_seeder();
/// // Get a mutable reference to the seeder as a trait object from the Box returned by new_seeder
/// let seeder = boxed_seeder.as_mut();
///
/// // Create a generator which uses a CSPRNG to generate secret keys
/// let mut secret_generator =
/// SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
///
/// // Create a generator which uses two CSPRNGs to generate public masks and secret encryption
/// // noise
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
///
/// println!("Generating keys...");
///
/// // Generate an LweSecretKey with binary coefficients
/// let small_lwe_sk =
/// LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator);
///
/// // Generate a GlweSecretKey with binary coefficients
/// let glwe_sk =
/// GlweSecretKey::generate_new_binary(glwe_dimension, polynomial_size, &mut secret_generator);
///
/// // Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey
/// let big_lwe_sk = glwe_sk.clone().into_lwe_secret_key();
///
/// let mut bsk = LweMultiBitBootstrapKey::new(
/// 0u64,
/// glwe_dimension.to_glwe_size(),
/// polynomial_size,
/// pbs_base_log,
/// pbs_level,
/// small_lwe_dimension,
/// grouping_factor,
/// );
///
/// par_generate_lwe_multi_bit_bootstrap_key(
/// &small_lwe_sk,
/// &glwe_sk,
/// &mut bsk,
/// glwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new(
/// bsk.input_lwe_dimension(),
/// bsk.glwe_size(),
/// bsk.polynomial_size(),
/// bsk.decomposition_base_log(),
/// bsk.decomposition_level_count(),
/// bsk.grouping_factor(),
/// );
///
/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
///
/// // We don't need the standard bootstrapping key anymore
/// drop(bsk);
///
/// // Our 4 bits message space
/// let message_modulus = 1u64 << 4;
///
/// // Our input message
/// let input_message = 3u64;
///
/// // Delta used to encode 4 bits of message + a bit of padding on u64
/// let delta = (1_u64 << 63) / message_modulus;
///
/// // Apply our encoding
/// let plaintext = Plaintext(input_message * delta);
///
/// // Allocate a new LweCiphertext and encrypt our plaintext
/// let lwe_ciphertext_in: LweCiphertextOwned<u64> = allocate_and_encrypt_new_lwe_ciphertext(
/// &small_lwe_sk,
/// plaintext,
/// lwe_modular_std_dev,
/// &mut encryption_generator,
/// );
///
/// // Now we will use a PBS to compute a multiplication by 2, it is NOT the recommended way of
/// // doing this operation in terms of performance as it's much more costly than a multiplication
/// // with a cleartext, however it resets the noise in a ciphertext to a nominal level and allows
/// // to evaluate arbitrary functions so depending on your use case it can be a better fit.
///
/// // Here we will define a helper function to generate an accumulator for a PBS
/// fn generate_accumulator<F>(
/// polynomial_size: PolynomialSize,
/// glwe_size: GlweSize,
/// message_modulus: usize,
/// delta: u64,
/// f: F,
/// ) -> GlweCiphertextOwned<u64>
/// where
/// F: Fn(u64) -> u64,
/// {
/// // N/(p/2) = size of each block, to correct noise from the input we introduce the notion of
/// // box, which manages redundancy to yield a denoised value for several noisy values around
/// // a true input value.
/// let box_size = polynomial_size.0 / message_modulus;
///
/// // Create the accumulator
/// let mut accumulator_u64 = vec![0_u64; polynomial_size.0];
///
/// // Fill each box with the encoded denoised value
/// for i in 0..message_modulus {
/// let index = i * box_size;
/// accumulator_u64[index..index + box_size]
/// .iter_mut()
/// .for_each(|a| *a = f(i as u64) * delta);
/// }
///
/// let half_box_size = box_size / 2;
///
/// // Negate the first half_box_size coefficients to manage negacyclicity and rotate
/// for a_i in accumulator_u64[0..half_box_size].iter_mut() {
/// *a_i = (*a_i).wrapping_neg();
/// }
///
/// // Rotate the accumulator
/// accumulator_u64.rotate_left(half_box_size);
///
/// let accumulator_plaintext = PlaintextList::from_container(accumulator_u64);
///
/// let accumulator =
/// allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext);
///
/// accumulator
/// }
///
/// // Generate the accumulator for our multiplication by 2 using a simple closure
/// let accumulator: GlweCiphertextOwned<u64> = generate_accumulator(
/// polynomial_size,
/// glwe_dimension.to_glwe_size(),
/// message_modulus as usize,
/// delta,
/// |x: u64| 2 * x,
/// );
///
/// // Allocate the LweCiphertext to store the result of the PBS
/// let mut pbs_multiplication_ct =
/// LweCiphertext::new(0u64, big_lwe_sk.lwe_dimension().to_lwe_size());
/// println!("Computing PBS...");
/// // Use 4 threads to compute the multi-bit PBS
/// multi_bit_programmable_bootstrap_lwe_ciphertext(
/// &lwe_ciphertext_in,
/// &mut pbs_multiplication_ct,
/// &accumulator,
/// &multi_bit_bsk,
/// ThreadCount(4),
/// );
///
/// // Decrypt the PBS multiplication result
/// let pbs_multipliation_plaintext: Plaintext<u64> =
/// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct);
///
/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext
/// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to
/// // round the 5 MSB, 1 bit of padding plus our 4 bits of message
/// let signed_decomposer =
/// SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1));
///
/// // Round and remove our encoding
/// let pbs_multiplication_result: u64 =
/// signed_decomposer.closest_representable(pbs_multipliation_plaintext.0) / delta;
///
/// println!("Checking result...");
/// assert_eq!(6, pbs_multiplication_result);
/// println!(
/// "Mulitplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}"
/// );
/// ```
pub fn multi_bit_programmable_bootstrap_lwe_ciphertext<
Scalar,
InputCont,
OutputCont,
AccCont,
KeyCont,
>(
input: &LweCiphertext<InputCont>,
output: &mut LweCiphertext<OutputCont>,
accumulator: &GlweCiphertext<AccCont>,
multi_bit_bsk: &FourierLweMultiBitBootstrapKey<KeyCont>,
thread_count: ThreadCount,
) where
// CastInto required for PBS modulus switch which returns a usize
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
AccCont: Container<Element = Scalar>,
KeyCont: Container<Element = c64> + Sync,
{
assert_eq!(
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
"Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \
FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
input.lwe_size().to_lwe_dimension(),
multi_bit_bsk.input_lwe_dimension(),
);
assert_eq!(
output.lwe_size().to_lwe_dimension(),
multi_bit_bsk.output_lwe_dimension(),
"Mimatched output LweDimension. LweCiphertext output LweDimension {:?}. \
FourierLweMultiBitBootstrapKey output LweDimension {:?}.",
output.lwe_size().to_lwe_dimension(),
multi_bit_bsk.output_lwe_dimension(),
);
assert_eq!(
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
"Mimatched GlweSize. Accumulator GlweSize {:?}. \
FourierLweMultiBitBootstrapKey GlweSize {:?}.",
accumulator.glwe_size(),
multi_bit_bsk.glwe_size(),
);
assert_eq!(
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
"Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \
FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
accumulator.polynomial_size(),
multi_bit_bsk.polynomial_size(),
);
let mut local_accumulator = GlweCiphertext::new(
Scalar::ZERO,
accumulator.glwe_size(),
accumulator.polynomial_size(),
);
local_accumulator
.as_mut()
.copy_from_slice(accumulator.as_ref());
multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count);
extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0));
}
#[cfg(test)]
mod test {
use crate::core_crypto::prelude::*;
fn multi_bit_pbs(grouping_factor: LweBskGroupingFactor, thread_count: ThreadCount) {
// DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
// computations
// Define parameters for LweBootstrapKey creation
let mut input_lwe_dimension = LweDimension(742);
let lwe_modular_std_dev = StandardDev(0.000007069849454709433);
let decomp_base_log = DecompositionBaseLog(3);
let decomp_level_count = DecompositionLevelCount(5);
let glwe_dimension = GlweDimension(1);
let polynomial_size = PolynomialSize(1024);
let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533);
while input_lwe_dimension.0 % grouping_factor.0 != 0 {
input_lwe_dimension = LweDimension(input_lwe_dimension.0 + 1);
}
// Create the PRNG
let mut seeder = new_seeder();
let seeder = seeder.as_mut();
let mut encryption_generator =
EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed(), seeder);
let mut secret_generator =
SecretRandomGenerator::<ActivatedRandomGenerator>::new(seeder.seed());
// Create the LweSecretKey
let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
input_lwe_dimension,
&mut secret_generator,
);
let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key(
glwe_dimension,
polynomial_size,
&mut secret_generator,
);
let output_lwe_secret_key = output_glwe_secret_key.clone().into_lwe_secret_key();
let mut bsk = LweMultiBitBootstrapKey::new(
0u64,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
input_lwe_dimension,
grouping_factor,
);
par_generate_lwe_multi_bit_bootstrap_key(
&input_lwe_secret_key,
&output_glwe_secret_key,
&mut bsk,
glwe_modular_std_dev,
&mut encryption_generator,
);
let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new(
input_lwe_dimension,
glwe_dimension.to_glwe_size(),
polynomial_size,
decomp_base_log,
decomp_level_count,
grouping_factor,
);
convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk);
// Here we will define a helper function to generate an accumulator for a PBS
fn generate_accumulator<F>(
polynomial_size: PolynomialSize,
glwe_size: GlweSize,
message_modulus: usize,
delta: u64,
f: F,
) -> GlweCiphertextOwned<u64>
where
F: Fn(u64) -> u64,
{
// N/(p/2) = size of each block, to correct noise from the input we introduce the
// notion of box, which manages redundancy to yield a denoised value
// for several noisy values around a true input value.
let box_size = polynomial_size.0 / message_modulus;
// Create the accumulator
let mut accumulator_u64 = vec![0_u64; polynomial_size.0];
// Fill each box with the encoded denoised value
for i in 0..message_modulus {
let index = i * box_size;
accumulator_u64[index..index + box_size]
.iter_mut()
.for_each(|a| *a = f(i as u64) * delta);
}
let half_box_size = box_size / 2;
// Negate the first half_box_size coefficients to manage negacyclicity and rotate
for a_i in accumulator_u64[0..half_box_size].iter_mut() {
*a_i = (*a_i).wrapping_neg();
}
// Rotate the accumulator
accumulator_u64.rotate_left(half_box_size);
let accumulator_plaintext = PlaintextList::from_container(accumulator_u64);
allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext)
}
// Our 4 bits message space
let message_modulus = 1u64 << 4;
let f = |x: u64| (3 * x) % message_modulus;
// Delta used to encode 4 bits of message + a bit of padding on u64
let delta = (1_u64 << 63) / message_modulus;
const NB_TESTS: usize = 10;
for input_message in 0..message_modulus {
for _ in 0..NB_TESTS {
// Apply our encoding
let plaintext = Plaintext(input_message * delta);
// Allocate a new LweCiphertext and encrypt our plaintext
let lwe_ciphertext_in: LweCiphertextOwned<u64> =
allocate_and_encrypt_new_lwe_ciphertext(
&input_lwe_secret_key,
plaintext,
lwe_modular_std_dev,
&mut encryption_generator,
);
let accumulator: GlweCiphertextOwned<u64> = generate_accumulator(
polynomial_size,
glwe_dimension.to_glwe_size(),
message_modulus as usize,
delta,
f,
);
// Allocate the LweCiphertext to store the result of the PBS
let mut out_pbs_ct =
LweCiphertext::new(0u64, output_lwe_secret_key.lwe_dimension().to_lwe_size());
println!("Computing PBS...");
multi_bit_programmable_bootstrap_lwe_ciphertext(
&lwe_ciphertext_in,
&mut out_pbs_ct,
&accumulator.as_view(),
&multi_bit_bsk,
thread_count,
);
// Decrypt the PBS result
let result_plaintext: Plaintext<u64> =
decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct);
// Create a SignedDecomposer to perform the rounding of the decrypted plaintext
// We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating
// we want to round the 5 MSB, 1 bit of padding plus our 4 bits of
// message
let signed_decomposer =
SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1));
// Round and remove our encoding
let result_cleartext: u64 =
signed_decomposer.closest_representable(result_plaintext.0) / delta;
println!("Checking result...");
assert_eq!(
f(input_message),
result_cleartext,
"in: {input_message}, expected: {}, out: {result_cleartext}",
f(input_message)
);
}
}
}
#[test]
fn multi_bit_pbs_test_factor_2_thread_5() {
multi_bit_pbs(LweBskGroupingFactor(2), ThreadCount(5));
}
#[test]
fn multi_bit_pbs_test_factor_3_thread_12() {
multi_bit_pbs(LweBskGroupingFactor(3), ThreadCount(12));
}
}

View File

@@ -85,11 +85,12 @@ pub fn private_functional_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphert
Scalar: UnsignedTorus,
KeyCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar> + Clone,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(input.lwe_ciphertext_count().0 <= output.polynomial_size().0);
output.as_mut().fill(Scalar::ZERO);
let mut buffer = output.clone();
let mut buffer =
GlweCiphertext::new(Scalar::ZERO, output.glwe_size(), output.polynomial_size());
// for each ciphertext, call mono_key_switch
for (degree, input_ciphertext) in input.iter().enumerate() {
private_functional_keyswitch_lwe_ciphertext_into_glwe_ciphertext(

View File

@@ -918,7 +918,7 @@ pub fn cmux_assign_mem_optimized_requirement<Scalar>(
/// let pbs_multipliation_plaintext: Plaintext<u64> =
/// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct);
///
/// /// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext
/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext
/// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to
/// // round the 5 MSB, 1 bit of padding plus our 4 bits of message
/// let signed_decomposer =

View File

@@ -13,6 +13,9 @@ pub mod lwe_encryption;
pub mod lwe_keyswitch;
pub mod lwe_keyswitch_key_generation;
pub mod lwe_linear_algebra;
pub mod lwe_multi_bit_bootstrap_key_conversion;
pub mod lwe_multi_bit_bootstrap_key_generation;
pub mod lwe_multi_bit_programmable_bootstrapping;
pub mod lwe_private_functional_packing_keyswitch;
pub mod lwe_private_functional_packing_keyswitch_key_generation;
pub mod lwe_programmable_bootstrapping;
@@ -44,6 +47,9 @@ pub use lwe_encryption::*;
pub use lwe_keyswitch::*;
pub use lwe_keyswitch_key_generation::*;
pub use lwe_linear_algebra::*;
pub use lwe_multi_bit_bootstrap_key_conversion::*;
pub use lwe_multi_bit_bootstrap_key_generation::*;
pub use lwe_multi_bit_programmable_bootstrapping::*;
pub use lwe_private_functional_packing_keyswitch::*;
pub use lwe_private_functional_packing_keyswitch_key_generation::*;
pub use lwe_programmable_bootstrapping::*;

View File

@@ -9,7 +9,7 @@ use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::parameters::{
DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweDimension, GlweSize,
LweCiphertextCount, LweDimension, LweSize, PolynomialSize,
LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweSize, PolynomialSize,
};
use concrete_csprng::generators::ForkError;
use rayon::prelude::*;
@@ -45,6 +45,20 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
self.mask.remaining_bytes()
}
pub(crate) fn fork_n(
&mut self,
n: usize,
) -> Result<impl Iterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
// We use ForkTooLarge here as what can fail is the conversion from u128 to usize
let mask_bytes = self.mask.remaining_bytes().ok_or(ForkError::ForkTooLarge)? / n;
let noise_bytes = self
.noise
.remaining_bytes()
.ok_or(ForkError::ForkTooLarge)?
/ n;
self.try_fork(n, mask_bytes, noise_bytes)
}
// Forks the generator, when splitting a bootstrap key into ggsw ct.
pub(crate) fn fork_bsk_to_ggsw<T: UnsignedInteger>(
&mut self,
@@ -58,6 +72,21 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
self.try_fork(lwe_dimension.0, mask_bytes, noise_bytes)
}
// Forks the generator, when splitting a multi_bit bootstrap key into ggsw ciphertext groups.
pub(crate) fn fork_multi_bit_bsk_to_ggsw_group<T: UnsignedInteger>(
&mut self,
lwe_dimension: LweDimension,
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl Iterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let mask_bytes = ggsw_count.0 * mask_bytes_per_ggsw::<T>(level, glwe_size, polynomial_size);
let noise_bytes = ggsw_count.0 * noise_bytes_per_ggsw(level, glwe_size, polynomial_size);
self.try_fork(lwe_dimension.0, mask_bytes, noise_bytes)
}
// Forks the generator, when splitting a ggsw into level matrices.
pub(crate) fn fork_ggsw_to_ggsw_levels<T: UnsignedInteger>(
&mut self,
@@ -210,6 +239,15 @@ impl<G: ByteRandomGenerator> EncryptionRandomGenerator<G> {
}
impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
pub(crate) fn par_fork_n(
&mut self,
n: usize,
) -> Result<impl IndexedParallelIterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let mask_bytes = self.mask.remaining_bytes().unwrap() / n;
let noise_bytes = self.noise.remaining_bytes().unwrap() / n;
self.par_try_fork(n, mask_bytes, noise_bytes)
}
// Forks the generator into a parallel iterator, when splitting a bootstrap key into ggsw ct.
pub(crate) fn par_fork_bsk_to_ggsw<T: UnsignedInteger>(
&mut self,
@@ -220,7 +258,21 @@ impl<G: ParallelByteRandomGenerator> EncryptionRandomGenerator<G> {
) -> Result<impl IndexedParallelIterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let mask_bytes = mask_bytes_per_ggsw::<T>(level, glwe_size, polynomial_size);
let noise_bytes = noise_bytes_per_ggsw(level, glwe_size, polynomial_size);
// panic!("{:?} {:?} {:?}", lwe_dimension.0, mask_bytes, noise_bytes);
self.par_try_fork(lwe_dimension.0, mask_bytes, noise_bytes)
}
// Forks the generator, when splitting a multi_bit bootstrap key into ggsw ct.
pub(crate) fn par_fork_multi_bit_bsk_to_ggsw_group<T: UnsignedInteger>(
&mut self,
lwe_dimension: LweDimension,
level: DecompositionLevelCount,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
grouping_factor: LweBskGroupingFactor,
) -> Result<impl IndexedParallelIterator<Item = EncryptionRandomGenerator<G>>, ForkError> {
let ggsw_count = grouping_factor.ggsw_per_multi_bit_element();
let mask_bytes = ggsw_count.0 * mask_bytes_per_ggsw::<T>(level, glwe_size, polynomial_size);
let noise_bytes = ggsw_count.0 * noise_bytes_per_ggsw(level, glwe_size, polynomial_size);
self.par_try_fork(lwe_dimension.0, mask_bytes, noise_bytes)
}

View File

@@ -102,7 +102,7 @@ impl GlweDimension {
/// The number of coefficients of a polynomial.
///
/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns $N$.
/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this new-type contains $N$.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct PolynomialSize(pub usize);
@@ -124,7 +124,7 @@ impl PolynomialSize {
/// The number of elements in the container of a fourier polynomial.
///
/// Assuming a standard polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns
/// Assuming a standard polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this new-type contains
/// $\frac{N}{2}$.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct FourierPolynomialSize(pub usize);
@@ -204,3 +204,21 @@ pub struct FunctionalPackingKeyswitchKeyCount(pub usize);
/// The number of bits used for the mask coefficients and the body of a ciphertext
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub struct CiphertextModulusLog(pub usize);
/// The number of cpu execution thread to use
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub struct ThreadCount(pub usize);
/// The number of key bits grouped together in the multi_bit PBS
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub struct LweBskGroupingFactor(pub usize);
impl LweBskGroupingFactor {
pub fn ggsw_per_multi_bit_element(&self) -> GgswPerLweMultiBitBskElement {
GgswPerLweMultiBitBskElement(1 << self.0)
}
}
/// The number of GGSW ciphertexts required per multi_bit BSK element
#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
pub struct GgswPerLweMultiBitBskElement(pub usize);

View File

@@ -192,6 +192,25 @@ pub fn ggsw_level_matrix_size(glwe_size: GlweSize, polynomial_size: PolynomialSi
glwe_size.0 * glwe_size.0 * polynomial_size.0
}
/// Return the number of elements in a [`FourierGgswCiphertext`] given a [`GlweSize`],
/// [`FourierPolynomialSize`] and [`DecompositionLevelCount`].
pub fn fourier_ggsw_ciphertext_size(
glwe_size: GlweSize,
fourier_polynomial_size: FourierPolynomialSize,
decomp_level_count: DecompositionLevelCount,
) -> usize {
decomp_level_count.0 * fourier_ggsw_level_matrix_size(glwe_size, fourier_polynomial_size)
}
/// Return the number of elements in a [`FourierGgswLevelMatrix`] given a [`GlweSize`] and
/// [`FourierPolynomialSize`].
pub fn fourier_ggsw_level_matrix_size(
glwe_size: GlweSize,
fourier_polynomial_size: FourierPolynomialSize,
) -> usize {
glwe_size.0 * glwe_size.0 * fourier_polynomial_size.0
}
impl<Scalar, C: Container<Element = Scalar>> GgswCiphertext<C> {
/// Create a [`GgswCiphertext`] from an existing container.
///

View File

@@ -163,6 +163,17 @@ impl<Scalar, C: Container<Element = Scalar>> GgswCiphertextList<C> {
pub fn into_container(self) -> C {
self.data
}
pub fn as_polynomial_list(&self) -> PolynomialListView<'_, Scalar> {
PolynomialList::from_container(self.as_ref(), self.polynomial_size())
}
}
impl<Scalar, C: ContainerMut<Element = Scalar>> GgswCiphertextList<C> {
pub fn as_mut_polynomial_list(&mut self) -> PolynomialListMutView<'_, Scalar> {
let polynomial_size = self.polynomial_size();
PolynomialList::from_container(self.as_mut(), polynomial_size)
}
}
/// A [`GgswCiphertextList`] owning the memory for its own storage.

View File

@@ -0,0 +1,470 @@
//! Module containing the definition of the [`LweMultiBitBootstrapKey`].
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::math::fft::FourierPolynomialList;
use aligned_vec::{avec, ABox};
use concrete_fft::c64;
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct LweMultiBitBootstrapKey<C: Container> {
// An LweMultiBitBootstrapKey is literally a GgswCiphertextList, so we wrap a
// GgswCiphertextList and use Deref to have access to all the primitives of the
// GgswCiphertextList easily
ggsw_list: GgswCiphertextList<C>,
grouping_factor: LweBskGroupingFactor,
}
impl<C: Container> std::ops::Deref for LweMultiBitBootstrapKey<C> {
type Target = GgswCiphertextList<C>;
fn deref(&self) -> &GgswCiphertextList<C> {
&self.ggsw_list
}
}
impl<C: ContainerMut> std::ops::DerefMut for LweMultiBitBootstrapKey<C> {
fn deref_mut(&mut self) -> &mut GgswCiphertextList<C> {
&mut self.ggsw_list
}
}
impl<Scalar, C: Container<Element = Scalar>> LweMultiBitBootstrapKey<C> {
/// Create an [`LweMultiBitBootstrapKey`] from an existing container.
///
/// # Note
///
/// This function only wraps a container in the appropriate type. If you want to generate an LWE
/// bootstrap key you need to use
/// [`crate::core_crypto::algorithms::generate_lwe_multi_bit_bootstrap_key`] or its parallel
/// equivalent [`crate::core_crypto::algorithms::par_generate_lwe_multi_bit_bootstrap_key`]
/// using this key as output.
///
/// This docstring exhibits [`LweMultiBitBootstrapKey`] primitives usage.
///
/// ```
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweMultiBitBootstrapKey creation
/// let glwe_size = GlweSize(2);
/// let polynomial_size = PolynomialSize(1024);
/// let decomp_base_log = DecompositionBaseLog(8);
/// let decomp_level_count = DecompositionLevelCount(3);
/// let input_lwe_dimension = LweDimension(600);
/// let grouping_factor = LweBskGroupingFactor(2);
///
/// // Create a new LweMultiBitBootstrapKey
/// let bsk = LweMultiBitBootstrapKey::new(
/// 0u64,
/// glwe_size,
/// polynomial_size,
/// decomp_base_log,
/// decomp_level_count,
/// input_lwe_dimension,
/// grouping_factor,
/// );
///
/// // These methods are "inherited" from GgswCiphertextList and are accessed through the Deref
/// // trait
/// assert_eq!(bsk.glwe_size(), glwe_size);
/// assert_eq!(bsk.polynomial_size(), polynomial_size);
/// assert_eq!(bsk.decomposition_base_log(), decomp_base_log);
/// assert_eq!(bsk.decomposition_level_count(), decomp_level_count);
///
/// // These methods are specific to the LweMultiBitBootstrapKey
/// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension);
/// assert_eq!(
/// bsk.multi_bit_input_lwe_dimension(),
/// LweDimension(input_lwe_dimension.0 / grouping_factor.0)
/// );
/// assert_eq!(
/// bsk.output_lwe_dimension().0,
/// glwe_size.to_glwe_dimension().0 * polynomial_size.0
/// );
/// assert_eq!(bsk.grouping_factor(), grouping_factor);
///
/// // Demonstrate how to recover the allocated container
/// let underlying_container: Vec<u64> = bsk.into_container();
///
/// // Recreate a key using from_container
/// let bsk = LweMultiBitBootstrapKey::from_container(
/// underlying_container,
/// glwe_size,
/// polynomial_size,
/// decomp_base_log,
/// decomp_level_count,
/// grouping_factor,
/// );
///
/// assert_eq!(bsk.glwe_size(), glwe_size);
/// assert_eq!(bsk.polynomial_size(), polynomial_size);
/// assert_eq!(bsk.decomposition_base_log(), decomp_base_log);
/// assert_eq!(bsk.decomposition_level_count(), decomp_level_count);
/// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension);
/// assert_eq!(
/// bsk.multi_bit_input_lwe_dimension(),
/// LweDimension(input_lwe_dimension.0 / grouping_factor.0)
/// );
/// assert_eq!(
/// bsk.output_lwe_dimension().0,
/// glwe_size.to_glwe_dimension().0 * polynomial_size.0
/// );
/// assert_eq!(bsk.grouping_factor(), grouping_factor);
/// ```
pub fn from_container(
container: C,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
) -> LweMultiBitBootstrapKey<C> {
let bsk = LweMultiBitBootstrapKey {
ggsw_list: GgswCiphertextList::from_container(
container,
glwe_size,
polynomial_size,
decomp_base_log,
decomp_level_count,
),
grouping_factor,
};
assert!(
bsk.ggsw_ciphertext_count().0 % grouping_factor.0 == 0,
"Number of GGSW ({}) in the bootstrap key needs to be a multiple of {}",
bsk.ggsw_ciphertext_count().0,
grouping_factor.0,
);
bsk
}
/// Return the [`LweDimension`] of the input [`LweSecretKey`].
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn input_lwe_dimension(&self) -> LweDimension {
let grouping_factor = self.grouping_factor;
let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element();
LweDimension(
self.ggsw_ciphertext_count().0 * grouping_factor.0 / ggsw_per_multi_bit_element.0,
)
}
/// Return the [`LweDimension`] of the input [`LweSecretKey`] taking into consideration the
/// grouping factor. This essentially returns the input [`LweDimension`] divided by the grouping
/// factor.
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn multi_bit_input_lwe_dimension(&self) -> LweDimension {
LweDimension(self.input_lwe_dimension().0 / self.grouping_factor.0)
}
/// Return the [`LweDimension`] of the equivalent output [`LweSecretKey`].
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn output_lwe_dimension(&self) -> LweDimension {
LweDimension(self.glwe_size().to_glwe_dimension().0 * self.polynomial_size().0)
}
/// Return the [`LweBskGroupingFactor`] of the current [`LweMultiBitBootstrapKey`].
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn grouping_factor(&self) -> LweBskGroupingFactor {
self.grouping_factor
}
/// Consume the entity and return its underlying container.
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn into_container(self) -> C {
self.ggsw_list.into_container()
}
/// Return a view of the [`LweMultiBitBootstrapKey`]. This is useful if an algorithm takes a
/// view by value.
pub fn as_view(&self) -> LweMultiBitBootstrapKey<&'_ [Scalar]> {
LweMultiBitBootstrapKey::from_container(
self.as_ref(),
self.glwe_size(),
self.polynomial_size(),
self.decomposition_base_log(),
self.decomposition_level_count(),
self.grouping_factor(),
)
}
}
impl<Scalar, C: ContainerMut<Element = Scalar>> LweMultiBitBootstrapKey<C> {
/// Mutable variant of [`LweMultiBitBootstrapKey::as_view`].
pub fn as_mut_view(&mut self) -> LweMultiBitBootstrapKey<&'_ mut [Scalar]> {
let glwe_size = self.glwe_size();
let polynomial_size = self.polynomial_size();
let decomp_base_log = self.decomposition_base_log();
let decomp_level_count = self.decomposition_level_count();
let grouping_factor = self.grouping_factor();
LweMultiBitBootstrapKey::from_container(
self.as_mut(),
glwe_size,
polynomial_size,
decomp_base_log,
decomp_level_count,
grouping_factor,
)
}
}
/// An [`LweMultiBitBootstrapKey`] owning the memory for its own storage.
pub type LweMultiBitBootstrapKeyOwned<Scalar> = LweMultiBitBootstrapKey<Vec<Scalar>>;
impl<Scalar: Copy> LweMultiBitBootstrapKeyOwned<Scalar> {
/// Allocate memory and create a new owned [`LweMultiBitBootstrapKey`].
///
/// # Note
///
/// This function allocates a vector of the appropriate size and wraps it in the appropriate
/// type. If you want to generate an LWE bootstrap key you need to use
/// [`crate::core_crypto::algorithms::generate_lwe_bootstrap_key`] or its parallel
/// equivalent [`crate::core_crypto::algorithms::par_generate_lwe_bootstrap_key`] using this
/// key as output.
///
/// See [`LweMultiBitBootstrapKey::from_container`] for usage.
pub fn new(
fill_with: Scalar,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
decomp_base_log: DecompositionBaseLog,
decomp_level_count: DecompositionLevelCount,
input_lwe_dimension: LweDimension,
grouping_factor: LweBskGroupingFactor,
) -> LweMultiBitBootstrapKeyOwned<Scalar> {
assert!(
input_lwe_dimension.0 % grouping_factor.0 == 0,
"Multi Bit BSK requires input LWE dimension ({}) to be a multiple of {}",
input_lwe_dimension.0,
grouping_factor.0
);
// For two bits multi_bit together
let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0;
LweMultiBitBootstrapKeyOwned {
ggsw_list: GgswCiphertextList::new(
fill_with,
glwe_size,
polynomial_size,
decomp_base_log,
decomp_level_count,
GgswCiphertextCount(
equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0,
),
),
grouping_factor,
}
}
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FourierLweMultiBitBootstrapKey<C: Container<Element = c64>> {
fourier: FourierPolynomialList<C>,
input_lwe_dimension: LweDimension,
glwe_size: GlweSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
}
pub type FourierLweMultiBitBootstrapKeyOwned = FourierLweMultiBitBootstrapKey<ABox<[c64]>>;
pub type FourierLweMultiBitBootstrapKeyView<'a> = FourierLweMultiBitBootstrapKey<&'a [c64]>;
pub type FourierLweMultiBitBootstrapKeyMutView<'a> = FourierLweMultiBitBootstrapKey<&'a mut [c64]>;
impl<C: Container<Element = c64>> FourierLweMultiBitBootstrapKey<C> {
pub fn from_container(
data: C,
input_lwe_dimension: LweDimension,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
) -> Self {
assert!(
input_lwe_dimension.0 % grouping_factor.0 == 0,
"Multi Bit BSK requires input LWE dimension to be a multiple of {}",
grouping_factor.0
);
let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0;
let ggsw_count =
equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0;
let expected_container_size = ggsw_count
* fourier_ggsw_ciphertext_size(
glwe_size,
polynomial_size.to_fourier_polynomial_size(),
decomposition_level_count,
);
assert_eq!(data.container_len(), expected_container_size);
Self {
fourier: FourierPolynomialList {
data,
polynomial_size,
},
input_lwe_dimension,
glwe_size,
decomposition_base_log,
decomposition_level_count,
grouping_factor,
}
}
/// Return an iterator over the GGSW ciphertexts composing the key.
pub fn ggsw_iter(
&self,
) -> impl DoubleEndedIterator<Item = FourierGgswCiphertext<&'_ [C::Element]>> {
self.fourier
.data
.as_ref()
.chunks_exact(fourier_ggsw_ciphertext_size(
self.glwe_size,
self.fourier.polynomial_size.to_fourier_polynomial_size(),
self.decomposition_level_count,
))
.map(move |slice| {
FourierGgswCiphertext::from_container(
slice,
self.glwe_size,
self.fourier.polynomial_size,
self.decomposition_base_log,
self.decomposition_level_count,
)
})
}
pub fn input_lwe_dimension(&self) -> LweDimension {
self.input_lwe_dimension
}
pub fn multi_bit_input_lwe_dimension(&self) -> LweDimension {
LweDimension(self.input_lwe_dimension().0 / self.grouping_factor.0)
}
pub fn polynomial_size(&self) -> PolynomialSize {
self.fourier.polynomial_size
}
pub fn glwe_size(&self) -> GlweSize {
self.glwe_size
}
pub fn decomposition_base_log(&self) -> DecompositionBaseLog {
self.decomposition_base_log
}
pub fn decomposition_level_count(&self) -> DecompositionLevelCount {
self.decomposition_level_count
}
pub fn output_lwe_dimension(&self) -> LweDimension {
LweDimension((self.glwe_size.0 - 1) * self.polynomial_size().0)
}
pub fn grouping_factor(&self) -> LweBskGroupingFactor {
self.grouping_factor
}
pub fn data(self) -> C {
self.fourier.data
}
pub fn as_view(&self) -> FourierLweMultiBitBootstrapKeyView<'_> {
FourierLweMultiBitBootstrapKeyView {
fourier: FourierPolynomialList {
data: self.fourier.data.as_ref(),
polynomial_size: self.fourier.polynomial_size,
},
input_lwe_dimension: self.input_lwe_dimension,
glwe_size: self.glwe_size,
decomposition_base_log: self.decomposition_base_log,
decomposition_level_count: self.decomposition_level_count,
grouping_factor: self.grouping_factor,
}
}
pub fn as_mut_view(&mut self) -> FourierLweMultiBitBootstrapKeyMutView<'_>
where
C: AsMut<[c64]>,
{
FourierLweMultiBitBootstrapKeyMutView {
fourier: FourierPolynomialList {
data: self.fourier.data.as_mut(),
polynomial_size: self.fourier.polynomial_size,
},
input_lwe_dimension: self.input_lwe_dimension,
glwe_size: self.glwe_size,
decomposition_base_log: self.decomposition_base_log,
decomposition_level_count: self.decomposition_level_count,
grouping_factor: self.grouping_factor,
}
}
pub fn as_polynomial_list(&self) -> FourierPolynomialList<&'_ [c64]> {
FourierPolynomialList {
data: self.fourier.data.as_ref(),
polynomial_size: self.fourier.polynomial_size,
}
}
pub fn as_mut_polynomial_list(&mut self) -> FourierPolynomialList<&'_ mut [c64]>
where
C: AsMut<[c64]>,
{
FourierPolynomialList {
data: self.fourier.data.as_mut(),
polynomial_size: self.fourier.polynomial_size,
}
}
}
impl FourierLweMultiBitBootstrapKeyOwned {
pub fn new(
input_lwe_dimension: LweDimension,
glwe_size: GlweSize,
polynomial_size: PolynomialSize,
decomposition_base_log: DecompositionBaseLog,
decomposition_level_count: DecompositionLevelCount,
grouping_factor: LweBskGroupingFactor,
) -> Self {
assert!(
input_lwe_dimension.0 % grouping_factor.0 == 0,
"Multi Bit BSK requires input LWE dimension ({}) to be a multiple of {}",
input_lwe_dimension.0,
grouping_factor.0
);
let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0;
let ggsw_count =
equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0;
let container_size = ggsw_count
* fourier_ggsw_ciphertext_size(
glwe_size,
polynomial_size.to_fourier_polynomial_size(),
decomposition_level_count,
);
let boxed = avec![
c64::default();
container_size
]
.into_boxed_slice();
Self {
fourier: FourierPolynomialList {
data: boxed,
polynomial_size,
},
input_lwe_dimension,
glwe_size,
decomposition_base_log,
decomposition_level_count,
grouping_factor,
}
}
}

View File

@@ -14,6 +14,7 @@ pub mod lwe_bootstrap_key;
pub mod lwe_ciphertext;
pub mod lwe_ciphertext_list;
pub mod lwe_keyswitch_key;
pub mod lwe_multi_bit_bootstrap_key;
pub mod lwe_private_functional_packing_keyswitch_key;
pub mod lwe_private_functional_packing_keyswitch_key_list;
pub mod lwe_public_key;
@@ -35,7 +36,8 @@ pub mod seeded_lwe_public_key;
pub use crate::core_crypto::fft_impl::crypto::bootstrap::{
FourierLweBootstrapKey, FourierLweBootstrapKeyOwned,
};
pub use crate::core_crypto::fft_impl::crypto::ggsw::FourierGgswCiphertext;
pub use crate::core_crypto::fft_impl::crypto::ggsw::*;
pub use crate::core_crypto::fft_impl::math::polynomial::FourierPolynomial;
pub use cleartext::*;
pub use ggsw_ciphertext::*;
pub use ggsw_ciphertext_list::*;
@@ -47,6 +49,7 @@ pub use lwe_bootstrap_key::*;
pub use lwe_ciphertext::*;
pub use lwe_ciphertext_list::*;
pub use lwe_keyswitch_key::*;
pub use lwe_multi_bit_bootstrap_key::*;
pub use lwe_private_functional_packing_keyswitch_key::*;
pub use lwe_private_functional_packing_keyswitch_key_list::*;
pub use lwe_public_key::*;

View File

@@ -428,7 +428,7 @@ pub fn add_external_product_assign<Scalar, InputGlweCont>(
unsafe {
update_with_fmadd(
output_fft_buffer,
ggsw_row,
ggsw_row.data(),
fourier,
is_output_uninit,
fourier_poly_size,
@@ -637,9 +637,9 @@ unsafe fn update_with_fmadd_scalar(
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
#[cfg_attr(__profiling, inline(never))]
unsafe fn update_with_fmadd(
pub(crate) unsafe fn update_with_fmadd(
output_fft_buffer: &mut [MaybeUninit<c64>],
ggsw_row: FourierGgswLevelRowView,
lhs_polynomial_list: &[c64],
fourier: &[c64],
is_output_uninit: bool,
fourier_poly_size: usize,
@@ -665,10 +665,10 @@ unsafe fn update_with_fmadd(
izip!(
output_fft_buffer.into_chunks(fourier_poly_size),
ggsw_row.data.into_chunks(fourier_poly_size)
lhs_polynomial_list.into_chunks(fourier_poly_size)
)
.for_each(|(output_fourier, ggsw_poly)| {
ptr(output_fourier, ggsw_poly, fourier, is_output_uninit);
.for_each(|(output_fourier, poly_from_list)| {
ptr(output_fourier, poly_from_list, fourier, is_output_uninit);
});
}

View File

@@ -5,8 +5,8 @@ use super::polynomial::{
};
use crate::core_crypto::commons::math::torus::UnsignedTorus;
use crate::core_crypto::commons::numeric::CastInto;
use crate::core_crypto::commons::parameters::PolynomialSize;
use crate::core_crypto::commons::traits::{Container, IntoContainerOwned};
use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize};
use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainerOwned};
use crate::core_crypto::commons::utils::izip;
use crate::core_crypto::entities::*;
use aligned_vec::{avec, ABox};
@@ -370,7 +370,7 @@ impl<'a> FftView<'a> {
self.plan
.fft_scratch()?
.try_and(StackReq::try_new_aligned::<c64>(
self.polynomial_size().0 / 2,
self.polynomial_size().to_fourier_polynomial_size().0,
aligned_vec::CACHELINE_ALIGN,
)?)
}
@@ -521,6 +521,28 @@ pub struct FourierPolynomialList<C: Container<Element = c64>> {
pub polynomial_size: PolynomialSize,
}
impl<C: Container<Element = c64>> FourierPolynomialList<C> {
pub fn polynomial_count(&self) -> PolynomialCount {
PolynomialCount(
self.data.container_len() / self.polynomial_size.to_fourier_polynomial_size().0,
)
}
}
impl<C: ContainerMut<Element = c64>> FourierPolynomialList<C> {
pub fn iter_mut(
&mut self,
) -> impl DoubleEndedIterator<Item = FourierPolynomial<&'_ mut [c64]>> {
assert!(
self.data.container_len() % self.polynomial_size.to_fourier_polynomial_size().0 == 0
);
self.data
.as_mut()
.chunks_exact_mut(self.polynomial_size.to_fourier_polynomial_size().0)
.map(move |slice| FourierPolynomial { data: slice })
}
}
impl<C: Container<Element = c64>> serde::Serialize for FourierPolynomialList<C> {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
fn serialize_impl<S: serde::Serializer>(

View File

@@ -1,6 +1,8 @@
use super::super::as_mut_uninit;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::Polynomial;
use aligned_vec::{avec, ABox};
use concrete_fft::c64;
//--------------------------------------------------------------------------------
@@ -21,6 +23,20 @@ pub struct FourierPolynomial<C: Container> {
pub type FourierPolynomialView<'a> = FourierPolynomial<&'a [c64]>;
pub type FourierPolynomialMutView<'a> = FourierPolynomial<&'a mut [c64]>;
pub type FourierPolynomialOwned = FourierPolynomial<ABox<[c64]>>;
impl FourierPolynomial<ABox<[c64]>> {
pub fn new(polynomial_size: PolynomialSize) -> FourierPolynomial<ABox<[c64]>> {
let boxed = avec![
c64::default();
polynomial_size.to_fourier_polynomial_size().0
]
.into_boxed_slice();
FourierPolynomial { data: boxed }
}
}
/// Polynomial in the standard domain, with possibly uninitialized coefficients.
///
/// This is used for the Fourier transforms to avoid the cost of initializing the output buffer,
@@ -54,6 +70,10 @@ impl<C: Container<Element = c64>> FourierPolynomial<C> {
data: self.data.as_mut(),
}
}
pub fn polynomial_size(&self) -> PolynomialSize {
PolynomialSize(self.data.container_len() * 2)
}
}
impl<'a, Scalar> Polynomial<&'a mut [Scalar]> {