refactor(shortint): factorize PBS code

This commit is contained in:
Mayeul@Zama
2024-03-07 11:16:44 +01:00
committed by mayeul-zama
parent 13f7adec66
commit 7e723f1ec2
2 changed files with 124 additions and 252 deletions

View File

@@ -33,6 +33,7 @@ use crate::core_crypto::commons::parameters::{
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
use crate::core_crypto::prelude::ComputationBuffers;
use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel};
use crate::shortint::client_key::ClientKey;
use crate::shortint::engine::{
@@ -738,16 +739,49 @@ impl ServerKey {
}
pub fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
match self.pbs_order {
PBSOrder::KeyswitchBootstrap => {
// This updates the ciphertext degree
self.keyswitch_programmable_bootstrap_assign(ct, acc);
if ct.is_trivial() {
self.trivial_pbs_assign(ct, acc);
return;
}
ShortintEngine::with_thread_local_mut(|engine| {
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
match self.pbs_order {
PBSOrder::KeyswitchBootstrap => {
keyswitch_lwe_ciphertext(
&self.key_switching_key,
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_ks,
);
apply_programmable_bootstrap(
&self.bootstrapping_key,
&ciphertext_buffers.buffer_lwe_after_ks,
&mut ct.ct,
acc,
buffers,
);
}
PBSOrder::BootstrapKeyswitch => {
apply_programmable_bootstrap(
&self.bootstrapping_key,
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_pbs,
acc,
buffers,
);
keyswitch_lwe_ciphertext(
&self.key_switching_key,
&ciphertext_buffers.buffer_lwe_after_pbs,
&mut ct.ct,
);
}
}
PBSOrder::BootstrapKeyswitch => {
// This updates the ciphertext degree
self.programmable_bootstrap_keyswitch_assign(ct, acc);
}
};
});
ct.degree = acc.degree;
ct.set_noise_level(NoiseLevel::NOMINAL);
}
/// Compute a keyswitch and programmable bootstrap applying several functions on an input
@@ -1111,6 +1145,11 @@ impl ServerKey {
}
fn trivial_pbs_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
#[cfg(feature = "pbs-stats")]
// We want to count trivial PBS in simulator mode
// In the non trivial case, this increment is done in the `apply_blind_rotate` function
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
assert_eq!(ct.noise_level(), NoiseLevel::ZERO);
let modulus_sup = self.message_modulus.0 * self.carry_modulus.0;
let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0) as u64;
@@ -1183,163 +1222,6 @@ impl ServerKey {
outputs
}
pub(crate) fn keyswitch_programmable_bootstrap_assign(
&self,
ct: &mut Ciphertext,
acc: &LookupTableOwned,
) {
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
if ct.is_trivial() {
self.trivial_pbs_assign(ct, acc);
return;
}
ShortintEngine::with_thread_local_mut(|engine| {
// Compute the programmable bootstrapping with fixed test polynomial
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
// Compute a key switch
keyswitch_lwe_ciphertext(
&self.key_switching_key,
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_ks,
);
match &self.bootstrapping_key {
ShortintBootstrappingKey::Classic(fourier_bsk) => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
// Compute a bootstrap
programmable_bootstrap_lwe_ciphertext_mem_optimized(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut ct.ct,
&acc.acc,
fourier_bsk,
fft,
stack,
);
}
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution,
} => {
if *deterministic_execution {
multi_bit_deterministic_programmable_bootstrap_lwe_ciphertext(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut ct.ct,
&acc.acc,
fourier_bsk,
*thread_count,
);
} else {
multi_bit_programmable_bootstrap_lwe_ciphertext(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut ct.ct,
&acc.acc,
fourier_bsk,
*thread_count,
);
}
}
};
});
ct.degree = acc.degree;
ct.set_noise_level(NoiseLevel::NOMINAL);
}
pub(crate) fn programmable_bootstrap_keyswitch_assign(
&self,
ct: &mut Ciphertext,
acc: &LookupTableOwned,
) {
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
if ct.is_trivial() {
self.trivial_pbs_assign(ct, acc);
return;
}
ShortintEngine::with_thread_local_mut(|engine| {
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
match &self.bootstrapping_key {
ShortintBootstrappingKey::Classic(fourier_bsk) => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
// Compute a bootstrap
programmable_bootstrap_lwe_ciphertext_mem_optimized(
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_pbs,
&acc.acc,
fourier_bsk,
fft,
stack,
);
}
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution,
} => {
if *deterministic_execution {
multi_bit_deterministic_programmable_bootstrap_lwe_ciphertext(
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_pbs,
&acc.acc,
fourier_bsk,
*thread_count,
);
} else {
multi_bit_programmable_bootstrap_lwe_ciphertext(
&ct.ct,
&mut ciphertext_buffers.buffer_lwe_after_pbs,
&acc.acc,
fourier_bsk,
*thread_count,
);
}
}
};
// Compute a key switch
keyswitch_lwe_ciphertext(
&self.key_switching_key,
&ciphertext_buffers.buffer_lwe_after_pbs,
&mut ct.ct,
);
});
ct.degree = acc.degree;
ct.set_noise_level(NoiseLevel::NOMINAL);
}
pub(crate) fn keyswitch_programmable_bootstrap_many_lut(
&self,
ct: &Ciphertext,
@@ -1362,52 +1244,12 @@ impl ServerKey {
&mut ciphertext_buffers.buffer_lwe_after_ks,
);
match &self.bootstrapping_key {
ShortintBootstrappingKey::Classic(fourier_bsk) => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
// Compute the blind rotation
blind_rotate_assign_mem_optimized(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut acc,
fourier_bsk,
fft,
stack,
);
}
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution,
} => {
if *deterministic_execution {
multi_bit_deterministic_blind_rotate_assign(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut acc,
fourier_bsk,
*thread_count,
);
} else {
multi_bit_blind_rotate_assign(
&ciphertext_buffers.buffer_lwe_after_ks,
&mut acc,
fourier_bsk,
*thread_count,
);
}
}
};
apply_blind_rotate(
&self.bootstrapping_key,
&ciphertext_buffers.buffer_lwe_after_ks.as_view(),
&mut acc,
buffers,
);
});
// The accumulator has been rotated, we can now proceed with the various sample extractions
@@ -1447,41 +1289,7 @@ impl ServerKey {
// Compute the programmable bootstrapping with fixed test polynomial
let (_, buffers) = engine.get_buffers(self);
match &self.bootstrapping_key {
ShortintBootstrappingKey::Classic(fourier_bsk) => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
// Compute the blind rotation
blind_rotate_assign_mem_optimized(&ct.ct, &mut acc, fourier_bsk, fft, stack);
}
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution,
} => {
if *deterministic_execution {
multi_bit_deterministic_blind_rotate_assign(
&ct.ct,
&mut acc,
fourier_bsk,
*thread_count,
);
} else {
multi_bit_blind_rotate_assign(&ct.ct, &mut acc, fourier_bsk, *thread_count);
}
}
};
apply_blind_rotate(&self.bootstrapping_key, &ct.ct, &mut acc, buffers);
});
// The accumulator has been rotated, we can now proceed with the various sample extractions
@@ -1685,3 +1493,70 @@ impl ServerKey {
.min_by_key(|op| op.number_of_pbs())
}
}
pub(crate) fn apply_blind_rotate<Scalar, InputCont, OutputCont>(
bootstrapping_key: &ShortintBootstrappingKey,
in_buffer: &LweCiphertext<InputCont>,
acc: &mut GlweCiphertext<OutputCont>,
buffers: &mut ComputationBuffers,
) where
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
#[cfg(feature = "pbs-stats")]
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
match bootstrapping_key {
ShortintBootstrappingKey::Classic(fourier_bsk) => {
let fft = Fft::new(fourier_bsk.polynomial_size());
let fft = fft.as_view();
buffers.resize(
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
fourier_bsk.glwe_size(),
fourier_bsk.polynomial_size(),
fft,
)
.unwrap()
.unaligned_bytes_required(),
);
let stack = buffers.stack();
// Compute the blind rotation
blind_rotate_assign_mem_optimized(in_buffer, acc, fourier_bsk, fft, stack);
}
ShortintBootstrappingKey::MultiBit {
fourier_bsk,
thread_count,
deterministic_execution,
} => {
if *deterministic_execution {
multi_bit_deterministic_blind_rotate_assign(
in_buffer,
acc,
fourier_bsk,
*thread_count,
);
} else {
multi_bit_blind_rotate_assign(in_buffer, acc, fourier_bsk, *thread_count);
}
}
};
}
pub(crate) fn apply_programmable_bootstrap<InputCont, OutputCont>(
bootstrapping_key: &ShortintBootstrappingKey,
in_buffer: &LweCiphertext<InputCont>,
out_buffer: &mut LweCiphertext<OutputCont>,
acc: &LookupTableOwned,
buffers: &mut ComputationBuffers,
) where
InputCont: Container<Element = u64>,
OutputCont: ContainerMut<Element = u64>,
{
let mut glwe_out = acc.acc.clone();
apply_blind_rotate(bootstrapping_key, in_buffer, &mut glwe_out, buffers);
extract_lwe_sample_from_glwe_ciphertext(&glwe_out, out_buffer, MonomialDegree(0));
}

View File

@@ -891,11 +891,8 @@ impl WopbsKey {
) -> Ciphertext {
let extracted_bits = self.extract_bits(delta_log, ct_in, nb_bit_to_extract);
let ciphertext_list = self.circuit_bootstrap_with_bits(
&extracted_bits.as_view(),
&lut.lut(),
LweCiphertextCount(1),
);
let ciphertext_list =
self.circuit_bootstrap_with_bits(&extracted_bits, &lut.lut(), LweCiphertextCount(1));
// Here the output list contains a single ciphertext, we can consume the container to
// convert it to a single ciphertext