From 75f05c0f3a563c7ed4ebf283d3698e49d5988747 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Wed, 8 Feb 2023 15:58:52 +0100 Subject: [PATCH] feat(core): add multi-bit BSK generation and PBS threaded implementation --- tfhe/Cargo.toml | 6 + tfhe/benches/core_crypto/dev_bench.rs | 310 ++++++ tfhe/benches/core_crypto/pbs_bench.rs | 6 +- .../algorithms/glwe_sample_extraction.rs | 5 +- .../lwe_multi_bit_bootstrap_key_conversion.rs | 84 ++ .../lwe_multi_bit_bootstrap_key_generation.rs | 442 ++++++++ ...we_multi_bit_programmable_bootstrapping.rs | 945 ++++++++++++++++++ ...we_private_functional_packing_keyswitch.rs | 5 +- .../lwe_programmable_bootstrapping.rs | 2 +- tfhe/src/core_crypto/algorithms/mod.rs | 6 + .../commons/generators/encryption.rs | 56 +- tfhe/src/core_crypto/commons/parameters.rs | 22 +- .../core_crypto/entities/ggsw_ciphertext.rs | 19 + .../entities/ggsw_ciphertext_list.rs | 11 + .../entities/lwe_multi_bit_bootstrap_key.rs | 470 +++++++++ tfhe/src/core_crypto/entities/mod.rs | 5 +- tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs | 12 +- tfhe/src/core_crypto/fft_impl/math/fft/mod.rs | 28 +- .../core_crypto/fft_impl/math/polynomial.rs | 20 + 19 files changed, 2433 insertions(+), 21 deletions(-) create mode 100644 tfhe/benches/core_crypto/dev_bench.rs create mode 100644 tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs create mode 100644 tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_generation.rs create mode 100644 tfhe/src/core_crypto/algorithms/lwe_multi_bit_programmable_bootstrapping.rs create mode 100644 tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs diff --git a/tfhe/Cargo.toml b/tfhe/Cargo.toml index 4fd25ecf5..faf8e535b 100644 --- a/tfhe/Cargo.toml +++ b/tfhe/Cargo.toml @@ -123,6 +123,12 @@ path = "benches/core_crypto/pbs_bench.rs" harness = false required-features = ["boolean", "shortint", "internal-keycache"] +[[bench]] +name = "dev-bench" +path = "benches/core_crypto/dev_bench.rs" +harness = false +required-features = ["boolean", "shortint", "internal-keycache"] + [[bench]] name = "boolean-bench" path = "benches/boolean/bench.rs" diff --git a/tfhe/benches/core_crypto/dev_bench.rs b/tfhe/benches/core_crypto/dev_bench.rs new file mode 100644 index 000000000..eb41d5231 --- /dev/null +++ b/tfhe/benches/core_crypto/dev_bench.rs @@ -0,0 +1,310 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use tfhe::core_crypto::prelude::*; + +criterion_group!( + boolean_like_pbs_group, + multi_bit_pbs::, + pbs::, + mem_optimized_pbs:: +); + +criterion_group!( + shortint_like_pbs_group, + multi_bit_pbs::, + pbs::, + mem_optimized_pbs:: +); + +criterion_main!(boolean_like_pbs_group, shortint_like_pbs_group); + +fn get_bench_params() -> ( + LweDimension, + StandardDev, + DecompositionBaseLog, + DecompositionLevelCount, + GlweDimension, + PolynomialSize, + LweBskGroupingFactor, + ThreadCount, +) { + if Scalar::BITS == 64 { + ( + LweDimension(742), + StandardDev(0.000007069849454709433), + DecompositionBaseLog(3), + DecompositionLevelCount(5), + GlweDimension(1), + PolynomialSize(1024), + LweBskGroupingFactor(2), + ThreadCount(5), + ) + } else if Scalar::BITS == 32 { + ( + LweDimension(778), + StandardDev(0.000003725679281679651), + DecompositionBaseLog(18), + DecompositionLevelCount(1), + GlweDimension(3), + PolynomialSize(512), + LweBskGroupingFactor(2), + ThreadCount(5), + ) + } else { + unreachable!() + } +} + +fn multi_bit_pbs + CastFrom + Sync>( + c: &mut Criterion, +) { + // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + // computations + // Define parameters for LweBootstrapKey creation + + let ( + mut input_lwe_dimension, + lwe_modular_std_dev, + decomp_base_log, + decomp_level_count, + glwe_dimension, + polynomial_size, + grouping_factor, + thread_count, + ) = get_bench_params::(); + + while input_lwe_dimension.0 % grouping_factor.0 != 0 { + input_lwe_dimension = LweDimension(input_lwe_dimension.0 + 1); + } + + // Create the PRNG + let mut seeder = new_seeder(); + let seeder = seeder.as_mut(); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), seeder); + let mut secret_generator = + SecretRandomGenerator::::new(seeder.seed()); + + // Create the LweSecretKey + let input_lwe_secret_key = + allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); + let output_glwe_secret_key: GlweSecretKeyOwned = + allocate_and_generate_new_binary_glwe_secret_key( + glwe_dimension, + polynomial_size, + &mut secret_generator, + ); + let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key(); + + let multi_bit_bsk = FourierLweMultiBitBootstrapKey::new( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_base_log, + decomp_level_count, + grouping_factor, + ); + + // Allocate a new LweCiphertext and encrypt our plaintext + let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext( + &input_lwe_secret_key, + Plaintext(Scalar::ZERO), + lwe_modular_std_dev, + &mut encryption_generator, + ); + + let accumulator = + GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size); + + // Allocate the LweCiphertext to store the result of the PBS + let mut out_pbs_ct = LweCiphertext::new( + Scalar::ZERO, + output_lwe_secret_key.lwe_dimension().to_lwe_size(), + ); + + let id = format!("Multi Bit PBS {}", Scalar::BITS); + #[allow(clippy::unit_arg)] + { + c.bench_function(&id, |b| { + b.iter(|| { + multi_bit_programmable_bootstrap_lwe_ciphertext( + &lwe_ciphertext_in, + &mut out_pbs_ct, + &accumulator.as_view(), + &multi_bit_bsk, + thread_count, + ); + black_box(&mut out_pbs_ct); + }) + }); + } +} + +fn pbs>(c: &mut Criterion) { + // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + // computations + // Define parameters for LweBootstrapKey creation + + let ( + input_lwe_dimension, + lwe_modular_std_dev, + decomp_base_log, + decomp_level_count, + glwe_dimension, + polynomial_size, + _, + _, + ) = get_bench_params::(); + + // Create the PRNG + let mut seeder = new_seeder(); + let seeder = seeder.as_mut(); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), seeder); + let mut secret_generator = + SecretRandomGenerator::::new(seeder.seed()); + + // Create the LweSecretKey + let input_lwe_secret_key = + allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); + let output_glwe_secret_key: GlweSecretKeyOwned = + allocate_and_generate_new_binary_glwe_secret_key( + glwe_dimension, + polynomial_size, + &mut secret_generator, + ); + let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key(); + + // Create the empty bootstrapping key in the Fourier domain + let fourier_bsk = FourierLweBootstrapKey::new( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_base_log, + decomp_level_count, + ); + + // Allocate a new LweCiphertext and encrypt our plaintext + let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext( + &input_lwe_secret_key, + Plaintext(Scalar::ZERO), + lwe_modular_std_dev, + &mut encryption_generator, + ); + + let accumulator = + GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size); + + // Allocate the LweCiphertext to store the result of the PBS + let mut out_pbs_ct = LweCiphertext::new( + Scalar::ZERO, + output_lwe_secret_key.lwe_dimension().to_lwe_size(), + ); + + let id = format!("PBS {}", Scalar::BITS); + { + c.bench_function(&id, |b| { + b.iter(|| { + programmable_bootstrap_lwe_ciphertext( + &lwe_ciphertext_in, + &mut out_pbs_ct, + &accumulator.as_view(), + &fourier_bsk, + ); + black_box(&mut out_pbs_ct); + }) + }); + } +} + +fn mem_optimized_pbs>(c: &mut Criterion) { + // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + // computations + // Define parameters for LweBootstrapKey creation + + let ( + input_lwe_dimension, + lwe_modular_std_dev, + decomp_base_log, + decomp_level_count, + glwe_dimension, + polynomial_size, + _, + _, + ) = get_bench_params::(); + + // Create the PRNG + let mut seeder = new_seeder(); + let seeder = seeder.as_mut(); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), seeder); + let mut secret_generator = + SecretRandomGenerator::::new(seeder.seed()); + + // Create the LweSecretKey + let input_lwe_secret_key = + allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); + let output_glwe_secret_key: GlweSecretKeyOwned = + allocate_and_generate_new_binary_glwe_secret_key( + glwe_dimension, + polynomial_size, + &mut secret_generator, + ); + let output_lwe_secret_key = output_glwe_secret_key.into_lwe_secret_key(); + + // Create the empty bootstrapping key in the Fourier domain + let fourier_bsk = FourierLweBootstrapKey::new( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_base_log, + decomp_level_count, + ); + + // Allocate a new LweCiphertext and encrypt our plaintext + let lwe_ciphertext_in = allocate_and_encrypt_new_lwe_ciphertext( + &input_lwe_secret_key, + Plaintext(Scalar::ZERO), + lwe_modular_std_dev, + &mut encryption_generator, + ); + + let accumulator = + GlweCiphertext::new(Scalar::ZERO, glwe_dimension.to_glwe_size(), polynomial_size); + // Allocate the LweCiphertext to store the result of the PBS + let mut out_pbs_ct = LweCiphertext::new( + Scalar::ZERO, + output_lwe_secret_key.lwe_dimension().to_lwe_size(), + ); + + let mut buffers = ComputationBuffers::new(); + + let fft = Fft::new(fourier_bsk.polynomial_size()); + let fft = fft.as_view(); + + buffers.resize( + programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::( + fourier_bsk.glwe_size(), + fourier_bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + + let id = format!("PBS mem-optimized {}", Scalar::BITS); + { + c.bench_function(&id, |b| { + b.iter(|| { + programmable_bootstrap_lwe_ciphertext_mem_optimized( + &lwe_ciphertext_in, + &mut out_pbs_ct, + &accumulator.as_view(), + &fourier_bsk, + fft, + buffers.stack(), + ); + black_box(&mut out_pbs_ct); + }) + }); + } +} diff --git a/tfhe/benches/core_crypto/pbs_bench.rs b/tfhe/benches/core_crypto/pbs_bench.rs index 62f8c7012..9a68ce72e 100644 --- a/tfhe/benches/core_crypto/pbs_bench.rs +++ b/tfhe/benches/core_crypto/pbs_bench.rs @@ -160,18 +160,18 @@ fn mem_optimized_pbs>(c: &mut Criterion) ); let id = format!("PBS_mem-optimized_{name}"); - #[allow(clippy::unit_arg)] { c.bench_function(&id, |b| { b.iter(|| { - black_box(programmable_bootstrap_lwe_ciphertext_mem_optimized( + programmable_bootstrap_lwe_ciphertext_mem_optimized( &lwe_ciphertext_in, &mut out_pbs_ct, &accumulator.as_view(), &fourier_bsk, fft, buffers.stack(), - )) + ); + black_box(&mut out_pbs_ct); }) }); } diff --git a/tfhe/src/core_crypto/algorithms/glwe_sample_extraction.rs b/tfhe/src/core_crypto/algorithms/glwe_sample_extraction.rs index d370a8c13..28ed67229 100644 --- a/tfhe/src/core_crypto/algorithms/glwe_sample_extraction.rs +++ b/tfhe/src/core_crypto/algorithms/glwe_sample_extraction.rs @@ -115,7 +115,10 @@ pub fn extract_lwe_sample_from_glwe_ciphertext( let opposite_count = input_glwe.polynomial_size().0 - nth.0 - 1; // We loop through the polynomials - for lwe_mask_poly in lwe_mask.as_mut().chunks_mut(input_glwe.polynomial_size().0) { + for lwe_mask_poly in lwe_mask + .as_mut() + .chunks_exact_mut(input_glwe.polynomial_size().0) + { // We reverse the polynomial lwe_mask_poly.reverse(); // We compute the opposite of the proper coefficients diff --git a/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs new file mode 100644 index 000000000..5b1487424 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_conversion.rs @@ -0,0 +1,84 @@ +//! Module containing primitives pertaining to the conversion of +//! [`standard LWE multi_bit bootstrap keys`](`LweMultiBitBootstrapKey`) to various +//! representations/numerical domains like the Fourier domain. + +use crate::core_crypto::commons::computation_buffers::ComputationBuffers; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::fft_impl::math::fft::{Fft, FftView}; +use concrete_fft::c64; +use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq}; + +/// Convert an [`LWE multi_bit bootstrap key`](`LweMultiBitBootstrapKey`) with standard +/// coefficients to the Fourier domain. +/// +/// See [`multi_bit_programmable_bootstrap_lwe_ciphertext`](`crate::core_crypto::algorithms::multi_bit_programmable_bootstrap_lwe_ciphertext`) for usage. +pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier( + input_bsk: &LweMultiBitBootstrapKey, + output_bsk: &mut FourierLweMultiBitBootstrapKey, +) where + Scalar: UnsignedTorus, + InputCont: Container, + OutputCont: ContainerMut, +{ + let mut buffers = ComputationBuffers::new(); + + let fft = Fft::new(input_bsk.polynomial_size()); + let fft = fft.as_view(); + + buffers.resize( + convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement(fft) + .unwrap() + .unaligned_bytes_required(), + ); + + let stack = buffers.stack(); + + convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized( + input_bsk, output_bsk, fft, stack, + ); +} + +/// Memory optimized version of [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier`]. +pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized< + Scalar, + InputCont, + OutputCont, +>( + input_bsk: &LweMultiBitBootstrapKey, + output_bsk: &mut FourierLweMultiBitBootstrapKey, + fft: FftView<'_>, + mut stack: DynStack<'_>, +) where + Scalar: UnsignedTorus, + InputCont: Container, + OutputCont: ContainerMut, +{ + let mut output_bsk_as_polynomial_list = output_bsk.as_mut_polynomial_list(); + let input_bsk_as_polynomial_list = input_bsk.as_polynomial_list(); + + assert_eq!( + output_bsk_as_polynomial_list.polynomial_count(), + input_bsk_as_polynomial_list.polynomial_count() + ); + + for (fourier_poly, coef_poly) in output_bsk_as_polynomial_list + .iter_mut() + .zip(input_bsk_as_polynomial_list.iter()) + { + // SAFETY: forward_as_torus doesn't write any uninitialized values into its output + fft.forward_as_torus( + unsafe { fourier_poly.into_uninit() }, + coef_poly, + stack.rb_mut(), + ); + } +} + +/// Return the required memory for +/// [`convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized`]. +pub fn convert_standard_lwe_multi_bit_bootstrap_key_to_fourier_mem_optimized_requirement( + fft: FftView<'_>, +) -> Result { + fft.forward_scratch() +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_generation.rs b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_generation.rs new file mode 100644 index 000000000..fd86a4645 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_bootstrap_key_generation.rs @@ -0,0 +1,442 @@ +//! Module containing primitives pertaining to the generation of +//! [`standard LWE multi_bit bootstrap keys`](`LweMultiBitBootstrapKey`). + +use crate::core_crypto::algorithms::*; +use crate::core_crypto::commons::dispersion::DispersionParameter; +use crate::core_crypto::commons::generators::EncryptionRandomGenerator; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use rayon::prelude::*; + +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweBootstrapKey creation +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433); +/// let output_lwe_dimension = LweDimension(2048); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let glwe_dimension = GlweDimension(1); +/// let polynomial_size = PolynomialSize(1024); +/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); +/// let grouping_factor = LweBskGroupingFactor(2); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let mut bsk = LweMultiBitBootstrapKey::new( +/// 0u64, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// decomp_base_log, +/// decomp_level_count, +/// input_lwe_dimension, +/// grouping_factor, +/// ); +/// +/// generate_lwe_multi_bit_bootstrap_key( +/// &input_lwe_secret_key, +/// &output_glwe_secret_key, +/// &mut bsk, +/// glwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); +/// +/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip( +/// input_lwe_secret_key +/// .as_ref() +/// .chunks_exact(grouping_factor.0), +/// ) { +/// for (bit_inversion_idx, ggsw) in ggsw_group.iter().enumerate() { +/// let mut key_bits_plaintext = 1u64; +/// for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() { +/// let bit_position = input_key_elements.len() - (bit_idx + 1); +/// let inversion_bit = (((bit_inversion_idx >> bit_position) & 1) ^ 1) as u64; +/// let key_bit = key_bit ^ inversion_bit; +/// key_bits_plaintext *= key_bit; +/// } +/// +/// let decrypted_ggsw = decrypt_constant_ggsw_ciphertext(&output_glwe_secret_key, &ggsw); +/// assert_eq!(decrypted_ggsw.0, key_bits_plaintext) +/// } +/// } +/// ``` +pub fn generate_lwe_multi_bit_bootstrap_key( + input_lwe_secret_key: &LweSecretKey, + output_glwe_secret_key: &GlweSecretKey, + output: &mut LweMultiBitBootstrapKey, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: UnsignedTorus + CastFrom, + InputKeyCont: Container, + OutputKeyCont: Container, + OutputCont: ContainerMut, + Gen: ByteRandomGenerator, +{ + assert!( + output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(), + "Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \ + Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.", + input_lwe_secret_key.lwe_dimension(), + output.input_lwe_dimension() + ); + + assert!( + output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(), + "Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.", + output_glwe_secret_key.glwe_dimension().to_glwe_size(), + output.glwe_size() + ); + + assert!( + output.polynomial_size() == output_glwe_secret_key.polynomial_size(), + "Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.", + output_glwe_secret_key.polynomial_size(), + output.polynomial_size() + ); + + let gen_iter = generator + .fork_multi_bit_bsk_to_ggsw_group::( + output.input_lwe_dimension(), + output.decomposition_level_count(), + output.glwe_size(), + output.polynomial_size(), + output.grouping_factor(), + ) + .unwrap(); + + let grouping_factor = output.grouping_factor(); + let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); + + for ((mut ggsw_group, input_key_elements), mut loop_generator) in output + .chunks_exact_mut(ggsw_per_multi_bit_element.0) + .zip( + input_lwe_secret_key + .as_ref() + .chunks_exact(grouping_factor.0), + ) + .zip(gen_iter) + { + let gen_iter = loop_generator.fork_n(ggsw_per_multi_bit_element.0).unwrap(); + for ((bit_inversion_idx, mut ggsw), mut inner_loop_generator) in + ggsw_group.iter_mut().enumerate().zip(gen_iter) + { + // Use the index of the ggsw as a way to know which bit to invert + let key_bits_plaintext = combine_key_bits(bit_inversion_idx, input_key_elements); + + encrypt_constant_ggsw_ciphertext( + output_glwe_secret_key, + &mut ggsw, + Plaintext(key_bits_plaintext), + noise_parameters, + &mut inner_loop_generator, + ); + } + } +} + +pub fn allocate_and_generate_new_lwe_multi_bit_bootstrap_key< + Scalar, + InputKeyCont, + OutputKeyCont, + Gen, +>( + input_lwe_secret_key: &LweSecretKey, + output_glwe_secret_key: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + noise_parameters: impl DispersionParameter, + generator: &mut EncryptionRandomGenerator, +) -> LweMultiBitBootstrapKeyOwned +where + Scalar: UnsignedTorus + CastFrom, + InputKeyCont: Container, + OutputKeyCont: Container, + Gen: ByteRandomGenerator, +{ + let mut bsk = LweMultiBitBootstrapKeyOwned::new( + Scalar::ZERO, + output_glwe_secret_key.glwe_dimension().to_glwe_size(), + output_glwe_secret_key.polynomial_size(), + decomp_base_log, + decomp_level_count, + input_lwe_secret_key.lwe_dimension(), + grouping_factor, + ); + + generate_lwe_multi_bit_bootstrap_key( + input_lwe_secret_key, + output_glwe_secret_key, + &mut bsk, + noise_parameters, + generator, + ); + + bsk +} + +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define parameters for LweBootstrapKey creation +/// let input_lwe_dimension = LweDimension(742); +/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433); +/// let output_lwe_dimension = LweDimension(2048); +/// let decomp_base_log = DecompositionBaseLog(3); +/// let decomp_level_count = DecompositionLevelCount(5); +/// let glwe_dimension = GlweDimension(1); +/// let polynomial_size = PolynomialSize(1024); +/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); +/// let grouping_factor = LweBskGroupingFactor(2); +/// +/// // Create the PRNG +/// let mut seeder = new_seeder(); +/// let seeder = seeder.as_mut(); +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create the LweSecretKey +/// let input_lwe_secret_key = +/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator); +/// let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( +/// glwe_dimension, +/// polynomial_size, +/// &mut secret_generator, +/// ); +/// +/// let mut bsk = LweMultiBitBootstrapKey::new( +/// 0u64, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// decomp_base_log, +/// decomp_level_count, +/// input_lwe_dimension, +/// grouping_factor, +/// ); +/// +/// par_generate_lwe_multi_bit_bootstrap_key( +/// &input_lwe_secret_key, +/// &output_glwe_secret_key, +/// &mut bsk, +/// glwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new( +/// input_lwe_dimension, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// decomp_base_log, +/// decomp_level_count, +/// grouping_factor, +/// ); +/// +/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk); +/// +/// let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); +/// +/// for (mut ggsw_group, input_key_elements) in bsk.chunks_exact(ggsw_per_multi_bit_element.0).zip( +/// input_lwe_secret_key +/// .as_ref() +/// .chunks_exact(grouping_factor.0), +/// ) { +/// for (bit_inversion_idx, ggsw) in ggsw_group.iter().enumerate() { +/// let mut key_bits_plaintext = 1u64; +/// for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() { +/// let bit_position = input_key_elements.len() - (bit_idx + 1); +/// let inversion_bit = (((bit_inversion_idx >> bit_position) & 1) ^ 1) as u64; +/// let key_bit = key_bit ^ inversion_bit; +/// key_bits_plaintext *= key_bit; +/// } +/// let decrypted_ggsw = decrypt_constant_ggsw_ciphertext(&output_glwe_secret_key, &ggsw); +/// assert_eq!(decrypted_ggsw.0, key_bits_plaintext) +/// } +/// } +/// ``` +pub fn par_generate_lwe_multi_bit_bootstrap_key< + Scalar, + InputKeyCont, + OutputKeyCont, + OutputCont, + Gen, +>( + input_lwe_secret_key: &LweSecretKey, + output_glwe_secret_key: &GlweSecretKey, + output: &mut LweMultiBitBootstrapKey, + noise_parameters: impl DispersionParameter + Sync, + generator: &mut EncryptionRandomGenerator, +) where + Scalar: UnsignedTorus + CastFrom + Sync + Send, + InputKeyCont: Container, + OutputKeyCont: Container + Sync, + OutputCont: ContainerMut, + Gen: ParallelByteRandomGenerator, +{ + assert!( + output.input_lwe_dimension() == input_lwe_secret_key.lwe_dimension(), + "Mismatched LweDimension between input LWE secret key and LWE bootstrap key. \ + Input LWE secret key LweDimension: {:?}, LWE bootstrap key input LweDimension {:?}.", + input_lwe_secret_key.lwe_dimension(), + output.input_lwe_dimension() + ); + + assert!( + output.glwe_size() == output_glwe_secret_key.glwe_dimension().to_glwe_size(), + "Mismatched GlweSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key GlweSize: {:?}, LWE bootstrap key GlweSize {:?}.", + output_glwe_secret_key.glwe_dimension().to_glwe_size(), + output.glwe_size() + ); + + assert!( + output.polynomial_size() == output_glwe_secret_key.polynomial_size(), + "Mismatched PolynomialSize between output GLWE secret key and LWE bootstrap key. \ + Output GLWE secret key PolynomialSize: {:?}, LWE bootstrap key PolynomialSize {:?}.", + output_glwe_secret_key.polynomial_size(), + output.polynomial_size() + ); + + let gen_iter = generator + .par_fork_multi_bit_bsk_to_ggsw_group::( + output.input_lwe_dimension(), + output.decomposition_level_count(), + output.glwe_size(), + output.polynomial_size(), + output.grouping_factor(), + ) + .unwrap(); + + let grouping_factor = output.grouping_factor(); + let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); + + output + .par_iter_mut() + .chunks(ggsw_per_multi_bit_element.0) + .zip( + input_lwe_secret_key + .as_ref() + .par_chunks_exact(grouping_factor.0), + ) + .zip(gen_iter) + .for_each( + |((mut ggsw_group, input_key_elements), mut loop_generator)| { + let gen_iter = loop_generator + .par_fork_n(ggsw_per_multi_bit_element.0) + .unwrap(); + ggsw_group + .par_iter_mut() + .enumerate() + .zip(gen_iter) + .for_each(|((bit_inversion_idx, ggsw), mut inner_loop_generator)| { + // Use the index of the ggsw as a way to know which bit to invert + let key_bits_plaintext = + combine_key_bits(bit_inversion_idx, input_key_elements); + + par_encrypt_constant_ggsw_ciphertext( + output_glwe_secret_key, + ggsw, + Plaintext(key_bits_plaintext), + noise_parameters, + &mut inner_loop_generator, + ); + }); + }, + ); +} + +fn combine_key_bits(bit_selector: usize, input_key_elements: &[Scalar]) -> Scalar +where + Scalar: UnsignedInteger + CastFrom, +{ + // Use a bit_selector (in practice the ggsw index) as a way to know which bit to invert, the + // counter goes from e.g. 0 to 4 or 00, 01, 10 and 11 in binary, we use those bits to know which + // key bit to invert in our product, also we invert the bit once more to be sure that the first + // term is the GGSW encrypting a constant polynomial (and not a monomial), allowing to copy it + // in the multi_bit PBS routine and computing polynomial products on the rest of the terms. + + // We compute products, initialize the combined key bits to 1 + let mut key_bits_plaintext = Scalar::ONE; + for (bit_idx, &key_bit) in input_key_elements.iter().enumerate() { + // Get the position of the bit we will check in bit_selector + let bit_position = input_key_elements.len() - (bit_idx + 1); + // Get the bit, invert it to have the first combined GGSW correspond to + // the constant polynomial, i.e. we generate + // first GGSW((1 - s_{i-1}) * (1 - s_i)) up + // to GGSW(s_{i-1} * s_{i}) + let inversion_bit: Scalar = Scalar::cast_from(((bit_selector >> bit_position) & 1) ^ 1); + // Invert the key_bit depending on the computed inversion_bit + let key_bit = key_bit ^ inversion_bit; + // Multiply the accumulator by the key_bit we need to combine it with + key_bits_plaintext = key_bits_plaintext.wrapping_mul(key_bit); + } + key_bits_plaintext +} + +pub fn par_allocate_and_generate_new_lwe_multi_bit_bootstrap_key< + Scalar, + InputKeyCont, + OutputKeyCont, + Gen, +>( + input_lwe_secret_key: &LweSecretKey, + output_glwe_secret_key: &GlweSecretKey, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + noise_parameters: impl DispersionParameter + Sync, + generator: &mut EncryptionRandomGenerator, +) -> LweMultiBitBootstrapKeyOwned +where + Scalar: UnsignedTorus + CastFrom + Sync + Send, + InputKeyCont: Container, + OutputKeyCont: Container + Sync, + Gen: ParallelByteRandomGenerator, +{ + let mut bsk = LweMultiBitBootstrapKeyOwned::new( + Scalar::ZERO, + output_glwe_secret_key.glwe_dimension().to_glwe_size(), + output_glwe_secret_key.polynomial_size(), + decomp_base_log, + decomp_level_count, + input_lwe_secret_key.lwe_dimension(), + grouping_factor, + ); + + par_generate_lwe_multi_bit_bootstrap_key( + input_lwe_secret_key, + output_glwe_secret_key, + &mut bsk, + noise_parameters, + generator, + ); + + bsk +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_multi_bit_programmable_bootstrapping.rs b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_programmable_bootstrapping.rs new file mode 100644 index 000000000..b006239f1 --- /dev/null +++ b/tfhe/src/core_crypto/algorithms/lwe_multi_bit_programmable_bootstrapping.rs @@ -0,0 +1,945 @@ +use crate::core_crypto::algorithms::extract_lwe_sample_from_glwe_ciphertext; +use crate::core_crypto::algorithms::polynomial_algorithms::*; +use crate::core_crypto::commons::computation_buffers::ComputationBuffers; +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::fft_impl::as_mut_uninit; +use crate::core_crypto::fft_impl::crypto::bootstrap::pbs_modulus_switch; +use crate::core_crypto::fft_impl::crypto::ggsw::{ + add_external_product_assign, add_external_product_assign_scratch, update_with_fmadd, +}; +use crate::core_crypto::fft_impl::math::fft::Fft; +use concrete_fft::c64; +use std::sync::{mpsc, Condvar, Mutex}; +use std::thread; + +/// Perform a blind rotation given an input [`LWE ciphertext`](`LweCiphertext`), modifying a look-up +/// table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE bootstrap +/// key`](`LweMultiBitBootstrapKey`) in the fourier domain. +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define the parameters for a 4 bits message able to hold the doubled 2 bits message +/// let small_lwe_dimension = LweDimension(742); +/// let glwe_dimension = GlweDimension(1); +/// let polynomial_size = PolynomialSize(2048); +/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433); +/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); +/// let pbs_base_log = DecompositionBaseLog(23); +/// let pbs_level = DecompositionLevelCount(1); +/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs +/// +/// // Request the best seeder possible, starting with hardware entropy sources and falling back to +/// // /dev/random on Unix systems if enabled via cargo features +/// let mut boxed_seeder = new_seeder(); +/// // Get a mutable reference to the seeder as a trait object from the Box returned by new_seeder +/// let seeder = boxed_seeder.as_mut(); +/// +/// // Create a generator which uses a CSPRNG to generate secret keys +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create a generator which uses two CSPRNGs to generate public masks and secret encryption +/// // noise +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// +/// println!("Generating keys..."); +/// +/// // Generate an LweSecretKey with binary coefficients +/// let small_lwe_sk = +/// LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator); +/// +/// // Generate a GlweSecretKey with binary coefficients +/// let glwe_sk = +/// GlweSecretKey::generate_new_binary(glwe_dimension, polynomial_size, &mut secret_generator); +/// +/// // Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey +/// let big_lwe_sk = glwe_sk.clone().into_lwe_secret_key(); +/// +/// let mut bsk = LweMultiBitBootstrapKey::new( +/// 0u64, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// pbs_base_log, +/// pbs_level, +/// small_lwe_dimension, +/// grouping_factor, +/// ); +/// +/// par_generate_lwe_multi_bit_bootstrap_key( +/// &small_lwe_sk, +/// &glwe_sk, +/// &mut bsk, +/// glwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new( +/// bsk.input_lwe_dimension(), +/// bsk.glwe_size(), +/// bsk.polynomial_size(), +/// bsk.decomposition_base_log(), +/// bsk.decomposition_level_count(), +/// bsk.grouping_factor(), +/// ); +/// +/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk); +/// +/// // We don't need the standard bootstrapping key anymore +/// drop(bsk); +/// +/// // Our 4 bits message space +/// let message_modulus = 1u64 << 4; +/// +/// // Our input message +/// let input_message = 3u64; +/// +/// // Delta used to encode 4 bits of message + a bit of padding on u64 +/// let delta = (1_u64 << 63) / message_modulus; +/// +/// // Apply our encoding +/// let plaintext = Plaintext(input_message * delta); +/// +/// // Allocate a new LweCiphertext and encrypt our plaintext +/// let lwe_ciphertext_in: LweCiphertextOwned = allocate_and_encrypt_new_lwe_ciphertext( +/// &small_lwe_sk, +/// plaintext, +/// lwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// // Now we will use a PBS to compute a multiplication by 2, it is NOT the recommended way of +/// // doing this operation in terms of performance as it's much more costly than a multiplication +/// // with a cleartext, however it resets the noise in a ciphertext to a nominal level and allows +/// // to evaluate arbitrary functions so depending on your use case it can be a better fit. +/// +/// // Here we will define a helper function to generate an accumulator for a PBS +/// fn generate_accumulator( +/// polynomial_size: PolynomialSize, +/// glwe_size: GlweSize, +/// message_modulus: usize, +/// delta: u64, +/// f: F, +/// ) -> GlweCiphertextOwned +/// where +/// F: Fn(u64) -> u64, +/// { +/// // N/(p/2) = size of each block, to correct noise from the input we introduce the notion of +/// // box, which manages redundancy to yield a denoised value for several noisy values around +/// // a true input value. +/// let box_size = polynomial_size.0 / message_modulus; +/// +/// // Create the accumulator +/// let mut accumulator_u64 = vec![0_u64; polynomial_size.0]; +/// +/// // Fill each box with the encoded denoised value +/// for i in 0..message_modulus { +/// let index = i * box_size; +/// accumulator_u64[index..index + box_size] +/// .iter_mut() +/// .for_each(|a| *a = f(i as u64) * delta); +/// } +/// +/// let half_box_size = box_size / 2; +/// +/// // Negate the first half_box_size coefficients to manage negacyclicity and rotate +/// for a_i in accumulator_u64[0..half_box_size].iter_mut() { +/// *a_i = (*a_i).wrapping_neg(); +/// } +/// +/// // Rotate the accumulator +/// accumulator_u64.rotate_left(half_box_size); +/// +/// let accumulator_plaintext = PlaintextList::from_container(accumulator_u64); +/// +/// let accumulator = +/// allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext); +/// +/// accumulator +/// } +/// +/// // Generate the accumulator for our multiplication by 2 using a simple closure +/// let mut accumulator: GlweCiphertextOwned = generate_accumulator( +/// polynomial_size, +/// glwe_dimension.to_glwe_size(), +/// message_modulus as usize, +/// delta, +/// |x: u64| 2 * x, +/// ); +/// +/// // Allocate the LweCiphertext to store the result of the PBS +/// let mut pbs_multiplication_ct = +/// LweCiphertext::new(0u64, big_lwe_sk.lwe_dimension().to_lwe_size()); +/// println!("Performing blind rotation..."); +/// // Use 4 threads for the multi-bit blind rotation for example +/// multi_bit_blind_rotate_assign( +/// &lwe_ciphertext_in, +/// &mut accumulator, +/// &multi_bit_bsk, +/// ThreadCount(4), +/// ); +/// println!("Performing sample extraction..."); +/// extract_lwe_sample_from_glwe_ciphertext( +/// &accumulator, +/// &mut pbs_multiplication_ct, +/// MonomialDegree(0), +/// ); +/// +/// // Decrypt the PBS multiplication result +/// let pbs_multipliation_plaintext: Plaintext = +/// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct); +/// +/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext +/// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to +/// // round the 5 MSB, 1 bit of padding plus our 4 bits of message +/// let signed_decomposer = +/// SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); +/// +/// // Round and remove our encoding +/// let pbs_multiplication_result: u64 = +/// signed_decomposer.closest_representable(pbs_multipliation_plaintext.0) / delta; +/// +/// println!("Checking result..."); +/// assert_eq!(6, pbs_multiplication_result); +/// println!( +/// "Mulitplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}" +/// ); +/// ``` +pub fn multi_bit_blind_rotate_assign( + input: &LweCiphertext, + accumulator: &mut GlweCiphertext, + multi_bit_bsk: &FourierLweMultiBitBootstrapKey, + thread_count: ThreadCount, +) where + // CastInto required for PBS modulus switch which returns a usize + Scalar: UnsignedTorus + CastInto + CastFrom + Sync, + InputCont: Container, + OutputCont: ContainerMut, + KeyCont: Container + Sync, +{ + assert_eq!( + input.lwe_size().to_lwe_dimension(), + multi_bit_bsk.input_lwe_dimension(), + "Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \ + FourierLweMultiBitBootstrapKey input LweDimension {:?}.", + input.lwe_size().to_lwe_dimension(), + multi_bit_bsk.input_lwe_dimension(), + ); + + assert_eq!( + accumulator.glwe_size(), + multi_bit_bsk.glwe_size(), + "Mimatched GlweSize. Accumulator GlweSize {:?}. \ + FourierLweMultiBitBootstrapKey GlweSize {:?}.", + accumulator.glwe_size(), + multi_bit_bsk.glwe_size(), + ); + + assert_eq!( + accumulator.polynomial_size(), + multi_bit_bsk.polynomial_size(), + "Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \ + FourierLweMultiBitBootstrapKey PolynomialSize {:?}.", + accumulator.polynomial_size(), + multi_bit_bsk.polynomial_size(), + ); + + let (lwe_mask, lwe_body) = input.get_mask_and_body(); + + // No way to chunk the result of ggsw_iter at the moment + let ggsw_vec: Vec<_> = multi_bit_bsk.ggsw_iter().collect(); + let mut work_queue = Vec::with_capacity(multi_bit_bsk.multi_bit_input_lwe_dimension().0); + + let grouping_factor = multi_bit_bsk.grouping_factor(); + let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); + + for (lwe_mask_elements, ggsw_group) in lwe_mask + .as_ref() + .chunks_exact(grouping_factor.0) + .zip(ggsw_vec.chunks_exact(ggsw_per_multi_bit_element.0)) + { + work_queue.push((lwe_mask_elements, ggsw_group)); + } + + assert!(work_queue.len() == lwe_mask.lwe_dimension().0 / grouping_factor.0); + + let work_queue = Mutex::new(work_queue); + + // Each producer thread works in a dedicated slot of the buffer + let thread_buffers: usize = thread_count.0; + + let lut_poly_size = accumulator.polynomial_size(); + let monomial_degree = pbs_modulus_switch( + *lwe_body.0, + lut_poly_size, + ModulusSwitchOffset(0), + LutCountLog(0), + ); + + // Modulus switching + accumulator + .as_mut_polynomial_list() + .iter_mut() + .for_each(|mut poly| { + polynomial_wrapping_monic_monomial_div_assign( + &mut poly, + MonomialDegree(monomial_degree), + ) + }); + + let fourier_multi_bit_ggsw_buffers = (0..thread_buffers) + .map(|_| { + ( + Mutex::new(false), + Condvar::new(), + Mutex::new(FourierGgswCiphertext::new( + multi_bit_bsk.glwe_size(), + multi_bit_bsk.polynomial_size(), + multi_bit_bsk.decomposition_base_log(), + multi_bit_bsk.decomposition_level_count(), + )), + ) + }) + .collect::>(); + + let (tx, rx) = mpsc::channel::(); + + let fft = Fft::new(multi_bit_bsk.polynomial_size()); + let fft = fft.as_view(); + thread::scope(|s| { + let produce_multi_bit_fourier_ggsw = |thread_id: usize, tx: mpsc::Sender| { + let mut buffers = ComputationBuffers::new(); + + buffers.resize(fft.forward_scratch().unwrap().unaligned_bytes_required()); + + let mut unit_polynomial = + Polynomial::new(Scalar::ZERO, multi_bit_bsk.polynomial_size()); + unit_polynomial.as_mut()[0] = Scalar::ONE; + let mut a_monomial = unit_polynomial.clone(); + let mut fourier_a_monomial = FourierPolynomial::new(multi_bit_bsk.polynomial_size()); + + let work_queue = &work_queue; + + let dest_idx = thread_id; + let (ready_for_consumer_lock, condvar, fourier_ggsw_buffer) = + &fourier_multi_bit_ggsw_buffers[dest_idx]; + + loop { + let maybe_work = { + let mut queue_lock = work_queue.lock().unwrap(); + queue_lock.pop() + }; + + let Some((lwe_mask_elements, ggsw_group)) = maybe_work else {break}; + let mut ready_for_consumer = ready_for_consumer_lock.lock().unwrap(); + + // Wait while the buffer is not ready for processing and wait on the condvar + // to get notified when we can start processing again + while *ready_for_consumer { + ready_for_consumer = condvar.wait(ready_for_consumer).unwrap(); + } + + let mut fourier_ggsw_buffer = fourier_ggsw_buffer.lock().unwrap(); + + let mut ggsw_group_iter = ggsw_group.iter(); + + // Keygen guarantees the first term is a constant term of the polynomial, no + // polynomial multiplication required + let ggsw_a_none = ggsw_group_iter.next().unwrap(); + + fourier_ggsw_buffer + .as_mut_view() + .data() + .copy_from_slice(ggsw_a_none.as_view().data()); + + let multi_bit_fourier_ggsw = + unsafe { as_mut_uninit(fourier_ggsw_buffer.as_mut_view().data()) }; + + for (ggsw_idx, fourier_ggsw) in ggsw_group_iter.enumerate() { + // We already processed the first ggsw, advance the index by 1 + let ggsw_idx = ggsw_idx + 1; + + // Select the proper mask elements to build the monomial degree depending on + // the order the GGSW were generated in, using the bits from mask_idx and + // ggsw_idx as selector bits + let mut monomial_degree = Scalar::ZERO; + for (mask_idx, &mask_element) in lwe_mask_elements.iter().enumerate() { + let mask_position = lwe_mask_elements.len() - (mask_idx + 1); + let selection_bit: Scalar = + Scalar::cast_from((ggsw_idx >> mask_position) & 1); + monomial_degree = + monomial_degree.wrapping_add(selection_bit.wrapping_mul(mask_element)); + } + + let switched_degree = pbs_modulus_switch( + monomial_degree, + lut_poly_size, + ModulusSwitchOffset(0), + LutCountLog(0), + ); + + a_monomial + .as_mut() + .copy_from_slice(unit_polynomial.as_ref()); + polynomial_wrapping_monic_monomial_mul_assign( + &mut a_monomial, + MonomialDegree(switched_degree), + ); + + fft.forward_as_integer( + unsafe { fourier_a_monomial.as_mut_view().into_uninit() }, + a_monomial.as_view(), + buffers.stack(), + ); + + unsafe { + update_with_fmadd( + multi_bit_fourier_ggsw, + fourier_ggsw.as_view().data(), + fourier_a_monomial.as_view().data, + false, + lut_poly_size.to_fourier_polynomial_size().0, + ); + } + } + + // Drop the lock before we wake other threads + drop(fourier_ggsw_buffer); + + *ready_for_consumer = true; + tx.send(dest_idx).unwrap(); + + // Wake threads waiting on the condvar + condvar.notify_all(); + } + }; + + let threads: Vec<_> = (0..thread_count.0) + .map(|id| { + let tx = tx.clone(); + s.spawn(move || produce_multi_bit_fourier_ggsw(id, tx)) + }) + .collect(); + + // We initialize ct0 for the successive external products + let ct0 = accumulator; + let mut ct1 = GlweCiphertext::new(Scalar::ZERO, ct0.glwe_size(), ct0.polynomial_size()); + let ct1 = &mut ct1; + + let mut buffers = ComputationBuffers::new(); + + buffers.resize( + add_external_product_assign_scratch::( + multi_bit_bsk.glwe_size(), + multi_bit_bsk.polynomial_size(), + fft, + ) + .unwrap() + .unaligned_bytes_required(), + ); + + let mut src_idx = 1usize; + + for _ in 0..multi_bit_bsk.multi_bit_input_lwe_dimension().0 { + src_idx ^= 1; + let idx = rx.recv().unwrap(); + let (ready_lock, condvar, multi_bit_fourier_ggsw) = + &fourier_multi_bit_ggsw_buffers[idx]; + + let (src_ct, mut dst_ct) = if src_idx == 0 { + (ct0.as_view(), ct1.as_mut_view()) + } else { + (ct1.as_view(), ct0.as_mut_view()) + }; + + dst_ct.as_mut().fill(Scalar::ZERO); + + let mut ready = ready_lock.lock().unwrap(); + assert!(*ready); + + let multi_bit_fourier_ggsw = multi_bit_fourier_ggsw.lock().unwrap(); + add_external_product_assign( + dst_ct, + multi_bit_fourier_ggsw.as_view(), + src_ct, + fft, + buffers.stack(), + ); + drop(multi_bit_fourier_ggsw); + + *ready = false; + // Wake a single producer thread sleeping on the condvar (only one will get to work + // anyways) + condvar.notify_one(); + } + + if src_idx == 0 { + ct0.as_mut().copy_from_slice(ct1.as_ref()); + } + + threads.into_iter().for_each(|t| t.join().unwrap()); + }); +} + +/// Perform a programmable bootsrap with given an input [`LWE ciphertext`](`LweCiphertext`), a +/// look-up table passed as a [`GLWE ciphertext`](`GlweCiphertext`) and an [`LWE multi-bit bootstrap +/// key`](`LweMultiBitBootstrapKey`) in the fourier domain. The result is written in the provided +/// output [`LWE ciphertext`](`LweCiphertext`). +/// +/// # Example +/// +/// ``` +/// use tfhe::core_crypto::prelude::*; +/// +/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct +/// // computations +/// // Define the parameters for a 4 bits message able to hold the doubled 2 bits message +/// let small_lwe_dimension = LweDimension(742); +/// let glwe_dimension = GlweDimension(1); +/// let polynomial_size = PolynomialSize(2048); +/// let lwe_modular_std_dev = StandardDev(0.000007069849454709433); +/// let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); +/// let pbs_base_log = DecompositionBaseLog(23); +/// let pbs_level = DecompositionLevelCount(1); +/// let grouping_factor = LweBskGroupingFactor(2); // Group bits in pairs +/// +/// // Request the best seeder possible, starting with hardware entropy sources and falling back to +/// // /dev/random on Unix systems if enabled via cargo features +/// let mut boxed_seeder = new_seeder(); +/// // Get a mutable reference to the seeder as a trait object from the Box returned by new_seeder +/// let seeder = boxed_seeder.as_mut(); +/// +/// // Create a generator which uses a CSPRNG to generate secret keys +/// let mut secret_generator = +/// SecretRandomGenerator::::new(seeder.seed()); +/// +/// // Create a generator which uses two CSPRNGs to generate public masks and secret encryption +/// // noise +/// let mut encryption_generator = +/// EncryptionRandomGenerator::::new(seeder.seed(), seeder); +/// +/// println!("Generating keys..."); +/// +/// // Generate an LweSecretKey with binary coefficients +/// let small_lwe_sk = +/// LweSecretKey::generate_new_binary(small_lwe_dimension, &mut secret_generator); +/// +/// // Generate a GlweSecretKey with binary coefficients +/// let glwe_sk = +/// GlweSecretKey::generate_new_binary(glwe_dimension, polynomial_size, &mut secret_generator); +/// +/// // Create a copy of the GlweSecretKey re-interpreted as an LweSecretKey +/// let big_lwe_sk = glwe_sk.clone().into_lwe_secret_key(); +/// +/// let mut bsk = LweMultiBitBootstrapKey::new( +/// 0u64, +/// glwe_dimension.to_glwe_size(), +/// polynomial_size, +/// pbs_base_log, +/// pbs_level, +/// small_lwe_dimension, +/// grouping_factor, +/// ); +/// +/// par_generate_lwe_multi_bit_bootstrap_key( +/// &small_lwe_sk, +/// &glwe_sk, +/// &mut bsk, +/// glwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new( +/// bsk.input_lwe_dimension(), +/// bsk.glwe_size(), +/// bsk.polynomial_size(), +/// bsk.decomposition_base_log(), +/// bsk.decomposition_level_count(), +/// bsk.grouping_factor(), +/// ); +/// +/// convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk); +/// +/// // We don't need the standard bootstrapping key anymore +/// drop(bsk); +/// +/// // Our 4 bits message space +/// let message_modulus = 1u64 << 4; +/// +/// // Our input message +/// let input_message = 3u64; +/// +/// // Delta used to encode 4 bits of message + a bit of padding on u64 +/// let delta = (1_u64 << 63) / message_modulus; +/// +/// // Apply our encoding +/// let plaintext = Plaintext(input_message * delta); +/// +/// // Allocate a new LweCiphertext and encrypt our plaintext +/// let lwe_ciphertext_in: LweCiphertextOwned = allocate_and_encrypt_new_lwe_ciphertext( +/// &small_lwe_sk, +/// plaintext, +/// lwe_modular_std_dev, +/// &mut encryption_generator, +/// ); +/// +/// // Now we will use a PBS to compute a multiplication by 2, it is NOT the recommended way of +/// // doing this operation in terms of performance as it's much more costly than a multiplication +/// // with a cleartext, however it resets the noise in a ciphertext to a nominal level and allows +/// // to evaluate arbitrary functions so depending on your use case it can be a better fit. +/// +/// // Here we will define a helper function to generate an accumulator for a PBS +/// fn generate_accumulator( +/// polynomial_size: PolynomialSize, +/// glwe_size: GlweSize, +/// message_modulus: usize, +/// delta: u64, +/// f: F, +/// ) -> GlweCiphertextOwned +/// where +/// F: Fn(u64) -> u64, +/// { +/// // N/(p/2) = size of each block, to correct noise from the input we introduce the notion of +/// // box, which manages redundancy to yield a denoised value for several noisy values around +/// // a true input value. +/// let box_size = polynomial_size.0 / message_modulus; +/// +/// // Create the accumulator +/// let mut accumulator_u64 = vec![0_u64; polynomial_size.0]; +/// +/// // Fill each box with the encoded denoised value +/// for i in 0..message_modulus { +/// let index = i * box_size; +/// accumulator_u64[index..index + box_size] +/// .iter_mut() +/// .for_each(|a| *a = f(i as u64) * delta); +/// } +/// +/// let half_box_size = box_size / 2; +/// +/// // Negate the first half_box_size coefficients to manage negacyclicity and rotate +/// for a_i in accumulator_u64[0..half_box_size].iter_mut() { +/// *a_i = (*a_i).wrapping_neg(); +/// } +/// +/// // Rotate the accumulator +/// accumulator_u64.rotate_left(half_box_size); +/// +/// let accumulator_plaintext = PlaintextList::from_container(accumulator_u64); +/// +/// let accumulator = +/// allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext); +/// +/// accumulator +/// } +/// +/// // Generate the accumulator for our multiplication by 2 using a simple closure +/// let accumulator: GlweCiphertextOwned = generate_accumulator( +/// polynomial_size, +/// glwe_dimension.to_glwe_size(), +/// message_modulus as usize, +/// delta, +/// |x: u64| 2 * x, +/// ); +/// +/// // Allocate the LweCiphertext to store the result of the PBS +/// let mut pbs_multiplication_ct = +/// LweCiphertext::new(0u64, big_lwe_sk.lwe_dimension().to_lwe_size()); +/// println!("Computing PBS..."); +/// // Use 4 threads to compute the multi-bit PBS +/// multi_bit_programmable_bootstrap_lwe_ciphertext( +/// &lwe_ciphertext_in, +/// &mut pbs_multiplication_ct, +/// &accumulator, +/// &multi_bit_bsk, +/// ThreadCount(4), +/// ); +/// +/// // Decrypt the PBS multiplication result +/// let pbs_multipliation_plaintext: Plaintext = +/// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct); +/// +/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext +/// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to +/// // round the 5 MSB, 1 bit of padding plus our 4 bits of message +/// let signed_decomposer = +/// SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); +/// +/// // Round and remove our encoding +/// let pbs_multiplication_result: u64 = +/// signed_decomposer.closest_representable(pbs_multipliation_plaintext.0) / delta; +/// +/// println!("Checking result..."); +/// assert_eq!(6, pbs_multiplication_result); +/// println!( +/// "Mulitplication via PBS result is correct! Expected 6, got {pbs_multiplication_result}" +/// ); +/// ``` +pub fn multi_bit_programmable_bootstrap_lwe_ciphertext< + Scalar, + InputCont, + OutputCont, + AccCont, + KeyCont, +>( + input: &LweCiphertext, + output: &mut LweCiphertext, + accumulator: &GlweCiphertext, + multi_bit_bsk: &FourierLweMultiBitBootstrapKey, + thread_count: ThreadCount, +) where + // CastInto required for PBS modulus switch which returns a usize + Scalar: UnsignedTorus + CastInto + CastFrom + Sync, + InputCont: Container, + OutputCont: ContainerMut, + AccCont: Container, + KeyCont: Container + Sync, +{ + assert_eq!( + input.lwe_size().to_lwe_dimension(), + multi_bit_bsk.input_lwe_dimension(), + "Mimatched input LweDimension. LweCiphertext input LweDimension {:?}. \ + FourierLweMultiBitBootstrapKey input LweDimension {:?}.", + input.lwe_size().to_lwe_dimension(), + multi_bit_bsk.input_lwe_dimension(), + ); + + assert_eq!( + output.lwe_size().to_lwe_dimension(), + multi_bit_bsk.output_lwe_dimension(), + "Mimatched output LweDimension. LweCiphertext output LweDimension {:?}. \ + FourierLweMultiBitBootstrapKey output LweDimension {:?}.", + output.lwe_size().to_lwe_dimension(), + multi_bit_bsk.output_lwe_dimension(), + ); + + assert_eq!( + accumulator.glwe_size(), + multi_bit_bsk.glwe_size(), + "Mimatched GlweSize. Accumulator GlweSize {:?}. \ + FourierLweMultiBitBootstrapKey GlweSize {:?}.", + accumulator.glwe_size(), + multi_bit_bsk.glwe_size(), + ); + + assert_eq!( + accumulator.polynomial_size(), + multi_bit_bsk.polynomial_size(), + "Mimatched PolynomialSize. Accumulator PolynomialSize {:?}. \ + FourierLweMultiBitBootstrapKey PolynomialSize {:?}.", + accumulator.polynomial_size(), + multi_bit_bsk.polynomial_size(), + ); + + let mut local_accumulator = GlweCiphertext::new( + Scalar::ZERO, + accumulator.glwe_size(), + accumulator.polynomial_size(), + ); + local_accumulator + .as_mut() + .copy_from_slice(accumulator.as_ref()); + + multi_bit_blind_rotate_assign(input, &mut local_accumulator, multi_bit_bsk, thread_count); + + extract_lwe_sample_from_glwe_ciphertext(&local_accumulator, output, MonomialDegree(0)); +} + +#[cfg(test)] +mod test { + use crate::core_crypto::prelude::*; + + fn multi_bit_pbs(grouping_factor: LweBskGroupingFactor, thread_count: ThreadCount) { + // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + // computations + // Define parameters for LweBootstrapKey creation + let mut input_lwe_dimension = LweDimension(742); + let lwe_modular_std_dev = StandardDev(0.000007069849454709433); + let decomp_base_log = DecompositionBaseLog(3); + let decomp_level_count = DecompositionLevelCount(5); + let glwe_dimension = GlweDimension(1); + let polynomial_size = PolynomialSize(1024); + let glwe_modular_std_dev = StandardDev(0.00000000000000029403601535432533); + + while input_lwe_dimension.0 % grouping_factor.0 != 0 { + input_lwe_dimension = LweDimension(input_lwe_dimension.0 + 1); + } + + // Create the PRNG + let mut seeder = new_seeder(); + let seeder = seeder.as_mut(); + let mut encryption_generator = + EncryptionRandomGenerator::::new(seeder.seed(), seeder); + let mut secret_generator = + SecretRandomGenerator::::new(seeder.seed()); + + // Create the LweSecretKey + let input_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key( + input_lwe_dimension, + &mut secret_generator, + ); + let output_glwe_secret_key = allocate_and_generate_new_binary_glwe_secret_key( + glwe_dimension, + polynomial_size, + &mut secret_generator, + ); + let output_lwe_secret_key = output_glwe_secret_key.clone().into_lwe_secret_key(); + + let mut bsk = LweMultiBitBootstrapKey::new( + 0u64, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_base_log, + decomp_level_count, + input_lwe_dimension, + grouping_factor, + ); + + par_generate_lwe_multi_bit_bootstrap_key( + &input_lwe_secret_key, + &output_glwe_secret_key, + &mut bsk, + glwe_modular_std_dev, + &mut encryption_generator, + ); + + let mut multi_bit_bsk = FourierLweMultiBitBootstrapKey::new( + input_lwe_dimension, + glwe_dimension.to_glwe_size(), + polynomial_size, + decomp_base_log, + decomp_level_count, + grouping_factor, + ); + + convert_standard_lwe_multi_bit_bootstrap_key_to_fourier(&bsk, &mut multi_bit_bsk); + + // Here we will define a helper function to generate an accumulator for a PBS + fn generate_accumulator( + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + message_modulus: usize, + delta: u64, + f: F, + ) -> GlweCiphertextOwned + where + F: Fn(u64) -> u64, + { + // N/(p/2) = size of each block, to correct noise from the input we introduce the + // notion of box, which manages redundancy to yield a denoised value + // for several noisy values around a true input value. + let box_size = polynomial_size.0 / message_modulus; + + // Create the accumulator + let mut accumulator_u64 = vec![0_u64; polynomial_size.0]; + + // Fill each box with the encoded denoised value + for i in 0..message_modulus { + let index = i * box_size; + accumulator_u64[index..index + box_size] + .iter_mut() + .for_each(|a| *a = f(i as u64) * delta); + } + + let half_box_size = box_size / 2; + + // Negate the first half_box_size coefficients to manage negacyclicity and rotate + for a_i in accumulator_u64[0..half_box_size].iter_mut() { + *a_i = (*a_i).wrapping_neg(); + } + + // Rotate the accumulator + accumulator_u64.rotate_left(half_box_size); + + let accumulator_plaintext = PlaintextList::from_container(accumulator_u64); + + allocate_and_trivially_encrypt_new_glwe_ciphertext(glwe_size, &accumulator_plaintext) + } + + // Our 4 bits message space + let message_modulus = 1u64 << 4; + + let f = |x: u64| (3 * x) % message_modulus; + + // Delta used to encode 4 bits of message + a bit of padding on u64 + let delta = (1_u64 << 63) / message_modulus; + + const NB_TESTS: usize = 10; + + for input_message in 0..message_modulus { + for _ in 0..NB_TESTS { + // Apply our encoding + let plaintext = Plaintext(input_message * delta); + + // Allocate a new LweCiphertext and encrypt our plaintext + let lwe_ciphertext_in: LweCiphertextOwned = + allocate_and_encrypt_new_lwe_ciphertext( + &input_lwe_secret_key, + plaintext, + lwe_modular_std_dev, + &mut encryption_generator, + ); + + let accumulator: GlweCiphertextOwned = generate_accumulator( + polynomial_size, + glwe_dimension.to_glwe_size(), + message_modulus as usize, + delta, + f, + ); + + // Allocate the LweCiphertext to store the result of the PBS + let mut out_pbs_ct = + LweCiphertext::new(0u64, output_lwe_secret_key.lwe_dimension().to_lwe_size()); + println!("Computing PBS..."); + multi_bit_programmable_bootstrap_lwe_ciphertext( + &lwe_ciphertext_in, + &mut out_pbs_ct, + &accumulator.as_view(), + &multi_bit_bsk, + thread_count, + ); + + // Decrypt the PBS result + let result_plaintext: Plaintext = + decrypt_lwe_ciphertext(&output_lwe_secret_key, &out_pbs_ct); + + // Create a SignedDecomposer to perform the rounding of the decrypted plaintext + // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating + // we want to round the 5 MSB, 1 bit of padding plus our 4 bits of + // message + let signed_decomposer = + SignedDecomposer::new(DecompositionBaseLog(5), DecompositionLevelCount(1)); + + // Round and remove our encoding + let result_cleartext: u64 = + signed_decomposer.closest_representable(result_plaintext.0) / delta; + + println!("Checking result..."); + assert_eq!( + f(input_message), + result_cleartext, + "in: {input_message}, expected: {}, out: {result_cleartext}", + f(input_message) + ); + } + } + } + + #[test] + fn multi_bit_pbs_test_factor_2_thread_5() { + multi_bit_pbs(LweBskGroupingFactor(2), ThreadCount(5)); + } + + #[test] + fn multi_bit_pbs_test_factor_3_thread_12() { + multi_bit_pbs(LweBskGroupingFactor(3), ThreadCount(12)); + } +} diff --git a/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch.rs b/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch.rs index 540a4b71d..30768e6bc 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_private_functional_packing_keyswitch.rs @@ -85,11 +85,12 @@ pub fn private_functional_keyswitch_lwe_ciphertext_list_and_pack_in_glwe_ciphert Scalar: UnsignedTorus, KeyCont: Container, InputCont: Container, - OutputCont: ContainerMut + Clone, + OutputCont: ContainerMut, { assert!(input.lwe_ciphertext_count().0 <= output.polynomial_size().0); output.as_mut().fill(Scalar::ZERO); - let mut buffer = output.clone(); + let mut buffer = + GlweCiphertext::new(Scalar::ZERO, output.glwe_size(), output.polynomial_size()); // for each ciphertext, call mono_key_switch for (degree, input_ciphertext) in input.iter().enumerate() { private_functional_keyswitch_lwe_ciphertext_into_glwe_ciphertext( diff --git a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping.rs b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping.rs index 2b8be0a84..c27cb8571 100644 --- a/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping.rs +++ b/tfhe/src/core_crypto/algorithms/lwe_programmable_bootstrapping.rs @@ -918,7 +918,7 @@ pub fn cmux_assign_mem_optimized_requirement( /// let pbs_multipliation_plaintext: Plaintext = /// decrypt_lwe_ciphertext(&big_lwe_sk, &pbs_multiplication_ct); /// -/// /// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext +/// // Create a SignedDecomposer to perform the rounding of the decrypted plaintext /// // We pass a DecompositionBaseLog of 5 and a DecompositionLevelCount of 1 indicating we want to /// // round the 5 MSB, 1 bit of padding plus our 4 bits of message /// let signed_decomposer = diff --git a/tfhe/src/core_crypto/algorithms/mod.rs b/tfhe/src/core_crypto/algorithms/mod.rs index 81a8f9622..bb5982120 100644 --- a/tfhe/src/core_crypto/algorithms/mod.rs +++ b/tfhe/src/core_crypto/algorithms/mod.rs @@ -13,6 +13,9 @@ pub mod lwe_encryption; pub mod lwe_keyswitch; pub mod lwe_keyswitch_key_generation; pub mod lwe_linear_algebra; +pub mod lwe_multi_bit_bootstrap_key_conversion; +pub mod lwe_multi_bit_bootstrap_key_generation; +pub mod lwe_multi_bit_programmable_bootstrapping; pub mod lwe_private_functional_packing_keyswitch; pub mod lwe_private_functional_packing_keyswitch_key_generation; pub mod lwe_programmable_bootstrapping; @@ -44,6 +47,9 @@ pub use lwe_encryption::*; pub use lwe_keyswitch::*; pub use lwe_keyswitch_key_generation::*; pub use lwe_linear_algebra::*; +pub use lwe_multi_bit_bootstrap_key_conversion::*; +pub use lwe_multi_bit_bootstrap_key_generation::*; +pub use lwe_multi_bit_programmable_bootstrapping::*; pub use lwe_private_functional_packing_keyswitch::*; pub use lwe_private_functional_packing_keyswitch_key_generation::*; pub use lwe_programmable_bootstrapping::*; diff --git a/tfhe/src/core_crypto/commons/generators/encryption.rs b/tfhe/src/core_crypto/commons/generators/encryption.rs index 30e165d33..39f516ae6 100644 --- a/tfhe/src/core_crypto/commons/generators/encryption.rs +++ b/tfhe/src/core_crypto/commons/generators/encryption.rs @@ -9,7 +9,7 @@ use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::UnsignedInteger; use crate::core_crypto::commons::parameters::{ DecompositionLevelCount, FunctionalPackingKeyswitchKeyCount, GlweDimension, GlweSize, - LweCiphertextCount, LweDimension, LweSize, PolynomialSize, + LweBskGroupingFactor, LweCiphertextCount, LweDimension, LweSize, PolynomialSize, }; use concrete_csprng::generators::ForkError; use rayon::prelude::*; @@ -45,6 +45,20 @@ impl EncryptionRandomGenerator { self.mask.remaining_bytes() } + pub(crate) fn fork_n( + &mut self, + n: usize, + ) -> Result>, ForkError> { + // We use ForkTooLarge here as what can fail is the conversion from u128 to usize + let mask_bytes = self.mask.remaining_bytes().ok_or(ForkError::ForkTooLarge)? / n; + let noise_bytes = self + .noise + .remaining_bytes() + .ok_or(ForkError::ForkTooLarge)? + / n; + self.try_fork(n, mask_bytes, noise_bytes) + } + // Forks the generator, when splitting a bootstrap key into ggsw ct. pub(crate) fn fork_bsk_to_ggsw( &mut self, @@ -58,6 +72,21 @@ impl EncryptionRandomGenerator { self.try_fork(lwe_dimension.0, mask_bytes, noise_bytes) } + // Forks the generator, when splitting a multi_bit bootstrap key into ggsw ciphertext groups. + pub(crate) fn fork_multi_bit_bsk_to_ggsw_group( + &mut self, + lwe_dimension: LweDimension, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + grouping_factor: LweBskGroupingFactor, + ) -> Result>, ForkError> { + let ggsw_count = grouping_factor.ggsw_per_multi_bit_element(); + let mask_bytes = ggsw_count.0 * mask_bytes_per_ggsw::(level, glwe_size, polynomial_size); + let noise_bytes = ggsw_count.0 * noise_bytes_per_ggsw(level, glwe_size, polynomial_size); + self.try_fork(lwe_dimension.0, mask_bytes, noise_bytes) + } + // Forks the generator, when splitting a ggsw into level matrices. pub(crate) fn fork_ggsw_to_ggsw_levels( &mut self, @@ -210,6 +239,15 @@ impl EncryptionRandomGenerator { } impl EncryptionRandomGenerator { + pub(crate) fn par_fork_n( + &mut self, + n: usize, + ) -> Result>, ForkError> { + let mask_bytes = self.mask.remaining_bytes().unwrap() / n; + let noise_bytes = self.noise.remaining_bytes().unwrap() / n; + self.par_try_fork(n, mask_bytes, noise_bytes) + } + // Forks the generator into a parallel iterator, when splitting a bootstrap key into ggsw ct. pub(crate) fn par_fork_bsk_to_ggsw( &mut self, @@ -220,7 +258,21 @@ impl EncryptionRandomGenerator { ) -> Result>, ForkError> { let mask_bytes = mask_bytes_per_ggsw::(level, glwe_size, polynomial_size); let noise_bytes = noise_bytes_per_ggsw(level, glwe_size, polynomial_size); - // panic!("{:?} {:?} {:?}", lwe_dimension.0, mask_bytes, noise_bytes); + self.par_try_fork(lwe_dimension.0, mask_bytes, noise_bytes) + } + + // Forks the generator, when splitting a multi_bit bootstrap key into ggsw ct. + pub(crate) fn par_fork_multi_bit_bsk_to_ggsw_group( + &mut self, + lwe_dimension: LweDimension, + level: DecompositionLevelCount, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + grouping_factor: LweBskGroupingFactor, + ) -> Result>, ForkError> { + let ggsw_count = grouping_factor.ggsw_per_multi_bit_element(); + let mask_bytes = ggsw_count.0 * mask_bytes_per_ggsw::(level, glwe_size, polynomial_size); + let noise_bytes = ggsw_count.0 * noise_bytes_per_ggsw(level, glwe_size, polynomial_size); self.par_try_fork(lwe_dimension.0, mask_bytes, noise_bytes) } diff --git a/tfhe/src/core_crypto/commons/parameters.rs b/tfhe/src/core_crypto/commons/parameters.rs index fde4070d9..922f9a5a8 100644 --- a/tfhe/src/core_crypto/commons/parameters.rs +++ b/tfhe/src/core_crypto/commons/parameters.rs @@ -102,7 +102,7 @@ impl GlweDimension { /// The number of coefficients of a polynomial. /// -/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns $N$. +/// Assuming a polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this new-type contains $N$. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct PolynomialSize(pub usize); @@ -124,7 +124,7 @@ impl PolynomialSize { /// The number of elements in the container of a fourier polynomial. /// -/// Assuming a standard polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this returns +/// Assuming a standard polynomial $a\_0 + a\_1X + /dots + a\_{N-1}X^{N-1}$, this new-type contains /// $\frac{N}{2}$. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct FourierPolynomialSize(pub usize); @@ -204,3 +204,21 @@ pub struct FunctionalPackingKeyswitchKeyCount(pub usize); /// The number of bits used for the mask coefficients and the body of a ciphertext #[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] pub struct CiphertextModulusLog(pub usize); + +/// The number of cpu execution thread to use +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct ThreadCount(pub usize); + +/// The number of key bits grouped together in the multi_bit PBS +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct LweBskGroupingFactor(pub usize); + +impl LweBskGroupingFactor { + pub fn ggsw_per_multi_bit_element(&self) -> GgswPerLweMultiBitBskElement { + GgswPerLweMultiBitBskElement(1 << self.0) + } +} + +/// The number of GGSW ciphertexts required per multi_bit BSK element +#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)] +pub struct GgswPerLweMultiBitBskElement(pub usize); diff --git a/tfhe/src/core_crypto/entities/ggsw_ciphertext.rs b/tfhe/src/core_crypto/entities/ggsw_ciphertext.rs index d12c7ac25..1dc757c26 100644 --- a/tfhe/src/core_crypto/entities/ggsw_ciphertext.rs +++ b/tfhe/src/core_crypto/entities/ggsw_ciphertext.rs @@ -192,6 +192,25 @@ pub fn ggsw_level_matrix_size(glwe_size: GlweSize, polynomial_size: PolynomialSi glwe_size.0 * glwe_size.0 * polynomial_size.0 } +/// Return the number of elements in a [`FourierGgswCiphertext`] given a [`GlweSize`], +/// [`FourierPolynomialSize`] and [`DecompositionLevelCount`]. +pub fn fourier_ggsw_ciphertext_size( + glwe_size: GlweSize, + fourier_polynomial_size: FourierPolynomialSize, + decomp_level_count: DecompositionLevelCount, +) -> usize { + decomp_level_count.0 * fourier_ggsw_level_matrix_size(glwe_size, fourier_polynomial_size) +} + +/// Return the number of elements in a [`FourierGgswLevelMatrix`] given a [`GlweSize`] and +/// [`FourierPolynomialSize`]. +pub fn fourier_ggsw_level_matrix_size( + glwe_size: GlweSize, + fourier_polynomial_size: FourierPolynomialSize, +) -> usize { + glwe_size.0 * glwe_size.0 * fourier_polynomial_size.0 +} + impl> GgswCiphertext { /// Create a [`GgswCiphertext`] from an existing container. /// diff --git a/tfhe/src/core_crypto/entities/ggsw_ciphertext_list.rs b/tfhe/src/core_crypto/entities/ggsw_ciphertext_list.rs index 97b53c053..18c0899af 100644 --- a/tfhe/src/core_crypto/entities/ggsw_ciphertext_list.rs +++ b/tfhe/src/core_crypto/entities/ggsw_ciphertext_list.rs @@ -163,6 +163,17 @@ impl> GgswCiphertextList { pub fn into_container(self) -> C { self.data } + + pub fn as_polynomial_list(&self) -> PolynomialListView<'_, Scalar> { + PolynomialList::from_container(self.as_ref(), self.polynomial_size()) + } +} + +impl> GgswCiphertextList { + pub fn as_mut_polynomial_list(&mut self) -> PolynomialListMutView<'_, Scalar> { + let polynomial_size = self.polynomial_size(); + PolynomialList::from_container(self.as_mut(), polynomial_size) + } } /// A [`GgswCiphertextList`] owning the memory for its own storage. diff --git a/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs b/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs new file mode 100644 index 000000000..caf7037ae --- /dev/null +++ b/tfhe/src/core_crypto/entities/lwe_multi_bit_bootstrap_key.rs @@ -0,0 +1,470 @@ +//! Module containing the definition of the [`LweMultiBitBootstrapKey`]. + +use crate::core_crypto::commons::parameters::*; +use crate::core_crypto::commons::traits::*; +use crate::core_crypto::entities::*; +use crate::core_crypto::fft_impl::math::fft::FourierPolynomialList; +use aligned_vec::{avec, ABox}; +use concrete_fft::c64; + +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub struct LweMultiBitBootstrapKey { + // An LweMultiBitBootstrapKey is literally a GgswCiphertextList, so we wrap a + // GgswCiphertextList and use Deref to have access to all the primitives of the + // GgswCiphertextList easily + ggsw_list: GgswCiphertextList, + grouping_factor: LweBskGroupingFactor, +} + +impl std::ops::Deref for LweMultiBitBootstrapKey { + type Target = GgswCiphertextList; + + fn deref(&self) -> &GgswCiphertextList { + &self.ggsw_list + } +} + +impl std::ops::DerefMut for LweMultiBitBootstrapKey { + fn deref_mut(&mut self) -> &mut GgswCiphertextList { + &mut self.ggsw_list + } +} + +impl> LweMultiBitBootstrapKey { + /// Create an [`LweMultiBitBootstrapKey`] from an existing container. + /// + /// # Note + /// + /// This function only wraps a container in the appropriate type. If you want to generate an LWE + /// bootstrap key you need to use + /// [`crate::core_crypto::algorithms::generate_lwe_multi_bit_bootstrap_key`] or its parallel + /// equivalent [`crate::core_crypto::algorithms::par_generate_lwe_multi_bit_bootstrap_key`] + /// using this key as output. + /// + /// This docstring exhibits [`LweMultiBitBootstrapKey`] primitives usage. + /// + /// ``` + /// use tfhe::core_crypto::prelude::*; + /// + /// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct + /// // computations + /// // Define parameters for LweMultiBitBootstrapKey creation + /// let glwe_size = GlweSize(2); + /// let polynomial_size = PolynomialSize(1024); + /// let decomp_base_log = DecompositionBaseLog(8); + /// let decomp_level_count = DecompositionLevelCount(3); + /// let input_lwe_dimension = LweDimension(600); + /// let grouping_factor = LweBskGroupingFactor(2); + /// + /// // Create a new LweMultiBitBootstrapKey + /// let bsk = LweMultiBitBootstrapKey::new( + /// 0u64, + /// glwe_size, + /// polynomial_size, + /// decomp_base_log, + /// decomp_level_count, + /// input_lwe_dimension, + /// grouping_factor, + /// ); + /// + /// // These methods are "inherited" from GgswCiphertextList and are accessed through the Deref + /// // trait + /// assert_eq!(bsk.glwe_size(), glwe_size); + /// assert_eq!(bsk.polynomial_size(), polynomial_size); + /// assert_eq!(bsk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(bsk.decomposition_level_count(), decomp_level_count); + /// + /// // These methods are specific to the LweMultiBitBootstrapKey + /// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!( + /// bsk.multi_bit_input_lwe_dimension(), + /// LweDimension(input_lwe_dimension.0 / grouping_factor.0) + /// ); + /// assert_eq!( + /// bsk.output_lwe_dimension().0, + /// glwe_size.to_glwe_dimension().0 * polynomial_size.0 + /// ); + /// assert_eq!(bsk.grouping_factor(), grouping_factor); + /// + /// // Demonstrate how to recover the allocated container + /// let underlying_container: Vec = bsk.into_container(); + /// + /// // Recreate a key using from_container + /// let bsk = LweMultiBitBootstrapKey::from_container( + /// underlying_container, + /// glwe_size, + /// polynomial_size, + /// decomp_base_log, + /// decomp_level_count, + /// grouping_factor, + /// ); + /// + /// assert_eq!(bsk.glwe_size(), glwe_size); + /// assert_eq!(bsk.polynomial_size(), polynomial_size); + /// assert_eq!(bsk.decomposition_base_log(), decomp_base_log); + /// assert_eq!(bsk.decomposition_level_count(), decomp_level_count); + /// assert_eq!(bsk.input_lwe_dimension(), input_lwe_dimension); + /// assert_eq!( + /// bsk.multi_bit_input_lwe_dimension(), + /// LweDimension(input_lwe_dimension.0 / grouping_factor.0) + /// ); + /// assert_eq!( + /// bsk.output_lwe_dimension().0, + /// glwe_size.to_glwe_dimension().0 * polynomial_size.0 + /// ); + /// assert_eq!(bsk.grouping_factor(), grouping_factor); + /// ``` + pub fn from_container( + container: C, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + ) -> LweMultiBitBootstrapKey { + let bsk = LweMultiBitBootstrapKey { + ggsw_list: GgswCiphertextList::from_container( + container, + glwe_size, + polynomial_size, + decomp_base_log, + decomp_level_count, + ), + grouping_factor, + }; + assert!( + bsk.ggsw_ciphertext_count().0 % grouping_factor.0 == 0, + "Number of GGSW ({}) in the bootstrap key needs to be a multiple of {}", + bsk.ggsw_ciphertext_count().0, + grouping_factor.0, + ); + bsk + } + + /// Return the [`LweDimension`] of the input [`LweSecretKey`]. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn input_lwe_dimension(&self) -> LweDimension { + let grouping_factor = self.grouping_factor; + let ggsw_per_multi_bit_element = grouping_factor.ggsw_per_multi_bit_element(); + LweDimension( + self.ggsw_ciphertext_count().0 * grouping_factor.0 / ggsw_per_multi_bit_element.0, + ) + } + + /// Return the [`LweDimension`] of the input [`LweSecretKey`] taking into consideration the + /// grouping factor. This essentially returns the input [`LweDimension`] divided by the grouping + /// factor. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn multi_bit_input_lwe_dimension(&self) -> LweDimension { + LweDimension(self.input_lwe_dimension().0 / self.grouping_factor.0) + } + + /// Return the [`LweDimension`] of the equivalent output [`LweSecretKey`]. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn output_lwe_dimension(&self) -> LweDimension { + LweDimension(self.glwe_size().to_glwe_dimension().0 * self.polynomial_size().0) + } + + /// Return the [`LweBskGroupingFactor`] of the current [`LweMultiBitBootstrapKey`]. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn grouping_factor(&self) -> LweBskGroupingFactor { + self.grouping_factor + } + + /// Consume the entity and return its underlying container. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn into_container(self) -> C { + self.ggsw_list.into_container() + } + + /// Return a view of the [`LweMultiBitBootstrapKey`]. This is useful if an algorithm takes a + /// view by value. + pub fn as_view(&self) -> LweMultiBitBootstrapKey<&'_ [Scalar]> { + LweMultiBitBootstrapKey::from_container( + self.as_ref(), + self.glwe_size(), + self.polynomial_size(), + self.decomposition_base_log(), + self.decomposition_level_count(), + self.grouping_factor(), + ) + } +} + +impl> LweMultiBitBootstrapKey { + /// Mutable variant of [`LweMultiBitBootstrapKey::as_view`]. + pub fn as_mut_view(&mut self) -> LweMultiBitBootstrapKey<&'_ mut [Scalar]> { + let glwe_size = self.glwe_size(); + let polynomial_size = self.polynomial_size(); + let decomp_base_log = self.decomposition_base_log(); + let decomp_level_count = self.decomposition_level_count(); + let grouping_factor = self.grouping_factor(); + LweMultiBitBootstrapKey::from_container( + self.as_mut(), + glwe_size, + polynomial_size, + decomp_base_log, + decomp_level_count, + grouping_factor, + ) + } +} + +/// An [`LweMultiBitBootstrapKey`] owning the memory for its own storage. +pub type LweMultiBitBootstrapKeyOwned = LweMultiBitBootstrapKey>; + +impl LweMultiBitBootstrapKeyOwned { + /// Allocate memory and create a new owned [`LweMultiBitBootstrapKey`]. + /// + /// # Note + /// + /// This function allocates a vector of the appropriate size and wraps it in the appropriate + /// type. If you want to generate an LWE bootstrap key you need to use + /// [`crate::core_crypto::algorithms::generate_lwe_bootstrap_key`] or its parallel + /// equivalent [`crate::core_crypto::algorithms::par_generate_lwe_bootstrap_key`] using this + /// key as output. + /// + /// See [`LweMultiBitBootstrapKey::from_container`] for usage. + pub fn new( + fill_with: Scalar, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomp_base_log: DecompositionBaseLog, + decomp_level_count: DecompositionLevelCount, + input_lwe_dimension: LweDimension, + grouping_factor: LweBskGroupingFactor, + ) -> LweMultiBitBootstrapKeyOwned { + assert!( + input_lwe_dimension.0 % grouping_factor.0 == 0, + "Multi Bit BSK requires input LWE dimension ({}) to be a multiple of {}", + input_lwe_dimension.0, + grouping_factor.0 + ); + // For two bits multi_bit together + let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0; + + LweMultiBitBootstrapKeyOwned { + ggsw_list: GgswCiphertextList::new( + fill_with, + glwe_size, + polynomial_size, + decomp_base_log, + decomp_level_count, + GgswCiphertextCount( + equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0, + ), + ), + grouping_factor, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct FourierLweMultiBitBootstrapKey> { + fourier: FourierPolynomialList, + input_lwe_dimension: LweDimension, + glwe_size: GlweSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, +} + +pub type FourierLweMultiBitBootstrapKeyOwned = FourierLweMultiBitBootstrapKey>; +pub type FourierLweMultiBitBootstrapKeyView<'a> = FourierLweMultiBitBootstrapKey<&'a [c64]>; +pub type FourierLweMultiBitBootstrapKeyMutView<'a> = FourierLweMultiBitBootstrapKey<&'a mut [c64]>; + +impl> FourierLweMultiBitBootstrapKey { + pub fn from_container( + data: C, + input_lwe_dimension: LweDimension, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + ) -> Self { + assert!( + input_lwe_dimension.0 % grouping_factor.0 == 0, + "Multi Bit BSK requires input LWE dimension to be a multiple of {}", + grouping_factor.0 + ); + let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0; + let ggsw_count = + equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0; + let expected_container_size = ggsw_count + * fourier_ggsw_ciphertext_size( + glwe_size, + polynomial_size.to_fourier_polynomial_size(), + decomposition_level_count, + ); + assert_eq!(data.container_len(), expected_container_size); + Self { + fourier: FourierPolynomialList { + data, + polynomial_size, + }, + input_lwe_dimension, + glwe_size, + decomposition_base_log, + decomposition_level_count, + grouping_factor, + } + } + + /// Return an iterator over the GGSW ciphertexts composing the key. + pub fn ggsw_iter( + &self, + ) -> impl DoubleEndedIterator> { + self.fourier + .data + .as_ref() + .chunks_exact(fourier_ggsw_ciphertext_size( + self.glwe_size, + self.fourier.polynomial_size.to_fourier_polynomial_size(), + self.decomposition_level_count, + )) + .map(move |slice| { + FourierGgswCiphertext::from_container( + slice, + self.glwe_size, + self.fourier.polynomial_size, + self.decomposition_base_log, + self.decomposition_level_count, + ) + }) + } + + pub fn input_lwe_dimension(&self) -> LweDimension { + self.input_lwe_dimension + } + + pub fn multi_bit_input_lwe_dimension(&self) -> LweDimension { + LweDimension(self.input_lwe_dimension().0 / self.grouping_factor.0) + } + + pub fn polynomial_size(&self) -> PolynomialSize { + self.fourier.polynomial_size + } + + pub fn glwe_size(&self) -> GlweSize { + self.glwe_size + } + + pub fn decomposition_base_log(&self) -> DecompositionBaseLog { + self.decomposition_base_log + } + + pub fn decomposition_level_count(&self) -> DecompositionLevelCount { + self.decomposition_level_count + } + + pub fn output_lwe_dimension(&self) -> LweDimension { + LweDimension((self.glwe_size.0 - 1) * self.polynomial_size().0) + } + + pub fn grouping_factor(&self) -> LweBskGroupingFactor { + self.grouping_factor + } + + pub fn data(self) -> C { + self.fourier.data + } + + pub fn as_view(&self) -> FourierLweMultiBitBootstrapKeyView<'_> { + FourierLweMultiBitBootstrapKeyView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_ref(), + polynomial_size: self.fourier.polynomial_size, + }, + input_lwe_dimension: self.input_lwe_dimension, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + grouping_factor: self.grouping_factor, + } + } + + pub fn as_mut_view(&mut self) -> FourierLweMultiBitBootstrapKeyMutView<'_> + where + C: AsMut<[c64]>, + { + FourierLweMultiBitBootstrapKeyMutView { + fourier: FourierPolynomialList { + data: self.fourier.data.as_mut(), + polynomial_size: self.fourier.polynomial_size, + }, + input_lwe_dimension: self.input_lwe_dimension, + glwe_size: self.glwe_size, + decomposition_base_log: self.decomposition_base_log, + decomposition_level_count: self.decomposition_level_count, + grouping_factor: self.grouping_factor, + } + } + + pub fn as_polynomial_list(&self) -> FourierPolynomialList<&'_ [c64]> { + FourierPolynomialList { + data: self.fourier.data.as_ref(), + polynomial_size: self.fourier.polynomial_size, + } + } + + pub fn as_mut_polynomial_list(&mut self) -> FourierPolynomialList<&'_ mut [c64]> + where + C: AsMut<[c64]>, + { + FourierPolynomialList { + data: self.fourier.data.as_mut(), + polynomial_size: self.fourier.polynomial_size, + } + } +} + +impl FourierLweMultiBitBootstrapKeyOwned { + pub fn new( + input_lwe_dimension: LweDimension, + glwe_size: GlweSize, + polynomial_size: PolynomialSize, + decomposition_base_log: DecompositionBaseLog, + decomposition_level_count: DecompositionLevelCount, + grouping_factor: LweBskGroupingFactor, + ) -> Self { + assert!( + input_lwe_dimension.0 % grouping_factor.0 == 0, + "Multi Bit BSK requires input LWE dimension ({}) to be a multiple of {}", + input_lwe_dimension.0, + grouping_factor.0 + ); + let equivalent_multi_bit_dimension = input_lwe_dimension.0 / grouping_factor.0; + let ggsw_count = + equivalent_multi_bit_dimension * grouping_factor.ggsw_per_multi_bit_element().0; + let container_size = ggsw_count + * fourier_ggsw_ciphertext_size( + glwe_size, + polynomial_size.to_fourier_polynomial_size(), + decomposition_level_count, + ); + + let boxed = avec![ + c64::default(); + container_size + ] + .into_boxed_slice(); + + Self { + fourier: FourierPolynomialList { + data: boxed, + polynomial_size, + }, + input_lwe_dimension, + glwe_size, + decomposition_base_log, + decomposition_level_count, + grouping_factor, + } + } +} diff --git a/tfhe/src/core_crypto/entities/mod.rs b/tfhe/src/core_crypto/entities/mod.rs index 76aedb2a3..57f16938e 100644 --- a/tfhe/src/core_crypto/entities/mod.rs +++ b/tfhe/src/core_crypto/entities/mod.rs @@ -14,6 +14,7 @@ pub mod lwe_bootstrap_key; pub mod lwe_ciphertext; pub mod lwe_ciphertext_list; pub mod lwe_keyswitch_key; +pub mod lwe_multi_bit_bootstrap_key; pub mod lwe_private_functional_packing_keyswitch_key; pub mod lwe_private_functional_packing_keyswitch_key_list; pub mod lwe_public_key; @@ -35,7 +36,8 @@ pub mod seeded_lwe_public_key; pub use crate::core_crypto::fft_impl::crypto::bootstrap::{ FourierLweBootstrapKey, FourierLweBootstrapKeyOwned, }; -pub use crate::core_crypto::fft_impl::crypto::ggsw::FourierGgswCiphertext; +pub use crate::core_crypto::fft_impl::crypto::ggsw::*; +pub use crate::core_crypto::fft_impl::math::polynomial::FourierPolynomial; pub use cleartext::*; pub use ggsw_ciphertext::*; pub use ggsw_ciphertext_list::*; @@ -47,6 +49,7 @@ pub use lwe_bootstrap_key::*; pub use lwe_ciphertext::*; pub use lwe_ciphertext_list::*; pub use lwe_keyswitch_key::*; +pub use lwe_multi_bit_bootstrap_key::*; pub use lwe_private_functional_packing_keyswitch_key::*; pub use lwe_private_functional_packing_keyswitch_key_list::*; pub use lwe_public_key::*; diff --git a/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs index af8e5122d..7f9f21f93 100644 --- a/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/crypto/ggsw.rs @@ -428,7 +428,7 @@ pub fn add_external_product_assign( unsafe { update_with_fmadd( output_fft_buffer, - ggsw_row, + ggsw_row.data(), fourier, is_output_uninit, fourier_poly_size, @@ -637,9 +637,9 @@ unsafe fn update_with_fmadd_scalar( /// /// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values. #[cfg_attr(__profiling, inline(never))] -unsafe fn update_with_fmadd( +pub(crate) unsafe fn update_with_fmadd( output_fft_buffer: &mut [MaybeUninit], - ggsw_row: FourierGgswLevelRowView, + lhs_polynomial_list: &[c64], fourier: &[c64], is_output_uninit: bool, fourier_poly_size: usize, @@ -665,10 +665,10 @@ unsafe fn update_with_fmadd( izip!( output_fft_buffer.into_chunks(fourier_poly_size), - ggsw_row.data.into_chunks(fourier_poly_size) + lhs_polynomial_list.into_chunks(fourier_poly_size) ) - .for_each(|(output_fourier, ggsw_poly)| { - ptr(output_fourier, ggsw_poly, fourier, is_output_uninit); + .for_each(|(output_fourier, poly_from_list)| { + ptr(output_fourier, poly_from_list, fourier, is_output_uninit); }); } diff --git a/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs b/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs index 88192ff5c..1ad7eaae7 100644 --- a/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs +++ b/tfhe/src/core_crypto/fft_impl/math/fft/mod.rs @@ -5,8 +5,8 @@ use super::polynomial::{ }; use crate::core_crypto::commons::math::torus::UnsignedTorus; use crate::core_crypto::commons::numeric::CastInto; -use crate::core_crypto::commons::parameters::PolynomialSize; -use crate::core_crypto::commons::traits::{Container, IntoContainerOwned}; +use crate::core_crypto::commons::parameters::{PolynomialCount, PolynomialSize}; +use crate::core_crypto::commons::traits::{Container, ContainerMut, IntoContainerOwned}; use crate::core_crypto::commons::utils::izip; use crate::core_crypto::entities::*; use aligned_vec::{avec, ABox}; @@ -370,7 +370,7 @@ impl<'a> FftView<'a> { self.plan .fft_scratch()? .try_and(StackReq::try_new_aligned::( - self.polynomial_size().0 / 2, + self.polynomial_size().to_fourier_polynomial_size().0, aligned_vec::CACHELINE_ALIGN, )?) } @@ -521,6 +521,28 @@ pub struct FourierPolynomialList> { pub polynomial_size: PolynomialSize, } +impl> FourierPolynomialList { + pub fn polynomial_count(&self) -> PolynomialCount { + PolynomialCount( + self.data.container_len() / self.polynomial_size.to_fourier_polynomial_size().0, + ) + } +} + +impl> FourierPolynomialList { + pub fn iter_mut( + &mut self, + ) -> impl DoubleEndedIterator> { + assert!( + self.data.container_len() % self.polynomial_size.to_fourier_polynomial_size().0 == 0 + ); + self.data + .as_mut() + .chunks_exact_mut(self.polynomial_size.to_fourier_polynomial_size().0) + .map(move |slice| FourierPolynomial { data: slice }) + } +} + impl> serde::Serialize for FourierPolynomialList { fn serialize(&self, serializer: S) -> Result { fn serialize_impl( diff --git a/tfhe/src/core_crypto/fft_impl/math/polynomial.rs b/tfhe/src/core_crypto/fft_impl/math/polynomial.rs index 414912f3e..1bc47defd 100644 --- a/tfhe/src/core_crypto/fft_impl/math/polynomial.rs +++ b/tfhe/src/core_crypto/fft_impl/math/polynomial.rs @@ -1,6 +1,8 @@ use super::super::as_mut_uninit; +use crate::core_crypto::commons::parameters::*; use crate::core_crypto::commons::traits::*; use crate::core_crypto::entities::Polynomial; +use aligned_vec::{avec, ABox}; use concrete_fft::c64; //-------------------------------------------------------------------------------- @@ -21,6 +23,20 @@ pub struct FourierPolynomial { pub type FourierPolynomialView<'a> = FourierPolynomial<&'a [c64]>; pub type FourierPolynomialMutView<'a> = FourierPolynomial<&'a mut [c64]>; +pub type FourierPolynomialOwned = FourierPolynomial>; + +impl FourierPolynomial> { + pub fn new(polynomial_size: PolynomialSize) -> FourierPolynomial> { + let boxed = avec![ + c64::default(); + polynomial_size.to_fourier_polynomial_size().0 + ] + .into_boxed_slice(); + + FourierPolynomial { data: boxed } + } +} + /// Polynomial in the standard domain, with possibly uninitialized coefficients. /// /// This is used for the Fourier transforms to avoid the cost of initializing the output buffer, @@ -54,6 +70,10 @@ impl> FourierPolynomial { data: self.data.as_mut(), } } + + pub fn polynomial_size(&self) -> PolynomialSize { + PolynomialSize(self.data.container_len() * 2) + } } impl<'a, Scalar> Polynomial<&'a mut [Scalar]> {