mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-10 15:18:33 -05:00
refactor(shortint): factorize PBS code
This commit is contained in:
@@ -33,6 +33,7 @@ use crate::core_crypto::commons::parameters::{
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use crate::core_crypto::fft_impl::fft64::math::fft::Fft;
|
||||
use crate::core_crypto::prelude::ComputationBuffers;
|
||||
use crate::shortint::ciphertext::{Ciphertext, Degree, MaxDegree, MaxNoiseLevel, NoiseLevel};
|
||||
use crate::shortint::client_key::ClientKey;
|
||||
use crate::shortint::engine::{
|
||||
@@ -738,16 +739,49 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
pub fn apply_lookup_table_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
|
||||
match self.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => {
|
||||
// This updates the ciphertext degree
|
||||
self.keyswitch_programmable_bootstrap_assign(ct, acc);
|
||||
if ct.is_trivial() {
|
||||
self.trivial_pbs_assign(ct, acc);
|
||||
return;
|
||||
}
|
||||
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
|
||||
match self.pbs_order {
|
||||
PBSOrder::KeyswitchBootstrap => {
|
||||
keyswitch_lwe_ciphertext(
|
||||
&self.key_switching_key,
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_ks,
|
||||
);
|
||||
|
||||
apply_programmable_bootstrap(
|
||||
&self.bootstrapping_key,
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut ct.ct,
|
||||
acc,
|
||||
buffers,
|
||||
);
|
||||
}
|
||||
PBSOrder::BootstrapKeyswitch => {
|
||||
apply_programmable_bootstrap(
|
||||
&self.bootstrapping_key,
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
acc,
|
||||
buffers,
|
||||
);
|
||||
|
||||
keyswitch_lwe_ciphertext(
|
||||
&self.key_switching_key,
|
||||
&ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
&mut ct.ct,
|
||||
);
|
||||
}
|
||||
}
|
||||
PBSOrder::BootstrapKeyswitch => {
|
||||
// This updates the ciphertext degree
|
||||
self.programmable_bootstrap_keyswitch_assign(ct, acc);
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
ct.degree = acc.degree;
|
||||
ct.set_noise_level(NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
/// Compute a keyswitch and programmable bootstrap applying several functions on an input
|
||||
@@ -1111,6 +1145,11 @@ impl ServerKey {
|
||||
}
|
||||
|
||||
fn trivial_pbs_assign(&self, ct: &mut Ciphertext, acc: &LookupTableOwned) {
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
// We want to count trivial PBS in simulator mode
|
||||
// In the non trivial case, this increment is done in the `apply_blind_rotate` function
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
assert_eq!(ct.noise_level(), NoiseLevel::ZERO);
|
||||
let modulus_sup = self.message_modulus.0 * self.carry_modulus.0;
|
||||
let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0) as u64;
|
||||
@@ -1183,163 +1222,6 @@ impl ServerKey {
|
||||
outputs
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_programmable_bootstrap_assign(
|
||||
&self,
|
||||
ct: &mut Ciphertext,
|
||||
acc: &LookupTableOwned,
|
||||
) {
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if ct.is_trivial() {
|
||||
self.trivial_pbs_assign(ct, acc);
|
||||
return;
|
||||
}
|
||||
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
// Compute the programmable bootstrapping with fixed test polynomial
|
||||
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
|
||||
|
||||
// Compute a key switch
|
||||
keyswitch_lwe_ciphertext(
|
||||
&self.key_switching_key,
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_ks,
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
ShortintBootstrappingKey::Classic(fourier_bsk) => {
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
|
||||
// Compute a bootstrap
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut ct.ct,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
}
|
||||
ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
} => {
|
||||
if *deterministic_execution {
|
||||
multi_bit_deterministic_programmable_bootstrap_lwe_ciphertext(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut ct.ct,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
} else {
|
||||
multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut ct.ct,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
});
|
||||
|
||||
ct.degree = acc.degree;
|
||||
ct.set_noise_level(NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
pub(crate) fn programmable_bootstrap_keyswitch_assign(
|
||||
&self,
|
||||
ct: &mut Ciphertext,
|
||||
acc: &LookupTableOwned,
|
||||
) {
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
if ct.is_trivial() {
|
||||
self.trivial_pbs_assign(ct, acc);
|
||||
return;
|
||||
}
|
||||
|
||||
ShortintEngine::with_thread_local_mut(|engine| {
|
||||
let (mut ciphertext_buffers, buffers) = engine.get_buffers(self);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
ShortintBootstrappingKey::Classic(fourier_bsk) => {
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
|
||||
// Compute a bootstrap
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized(
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
}
|
||||
ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
} => {
|
||||
if *deterministic_execution {
|
||||
multi_bit_deterministic_programmable_bootstrap_lwe_ciphertext(
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
} else {
|
||||
multi_bit_programmable_bootstrap_lwe_ciphertext(
|
||||
&ct.ct,
|
||||
&mut ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
&acc.acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Compute a key switch
|
||||
keyswitch_lwe_ciphertext(
|
||||
&self.key_switching_key,
|
||||
&ciphertext_buffers.buffer_lwe_after_pbs,
|
||||
&mut ct.ct,
|
||||
);
|
||||
});
|
||||
|
||||
ct.degree = acc.degree;
|
||||
ct.set_noise_level(NoiseLevel::NOMINAL);
|
||||
}
|
||||
|
||||
pub(crate) fn keyswitch_programmable_bootstrap_many_lut(
|
||||
&self,
|
||||
ct: &Ciphertext,
|
||||
@@ -1362,52 +1244,12 @@ impl ServerKey {
|
||||
&mut ciphertext_buffers.buffer_lwe_after_ks,
|
||||
);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
ShortintBootstrappingKey::Classic(fourier_bsk) => {
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
|
||||
// Compute the blind rotation
|
||||
blind_rotate_assign_mem_optimized(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut acc,
|
||||
fourier_bsk,
|
||||
fft,
|
||||
stack,
|
||||
);
|
||||
}
|
||||
ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
} => {
|
||||
if *deterministic_execution {
|
||||
multi_bit_deterministic_blind_rotate_assign(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
} else {
|
||||
multi_bit_blind_rotate_assign(
|
||||
&ciphertext_buffers.buffer_lwe_after_ks,
|
||||
&mut acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
apply_blind_rotate(
|
||||
&self.bootstrapping_key,
|
||||
&ciphertext_buffers.buffer_lwe_after_ks.as_view(),
|
||||
&mut acc,
|
||||
buffers,
|
||||
);
|
||||
});
|
||||
|
||||
// The accumulator has been rotated, we can now proceed with the various sample extractions
|
||||
@@ -1447,41 +1289,7 @@ impl ServerKey {
|
||||
// Compute the programmable bootstrapping with fixed test polynomial
|
||||
let (_, buffers) = engine.get_buffers(self);
|
||||
|
||||
match &self.bootstrapping_key {
|
||||
ShortintBootstrappingKey::Classic(fourier_bsk) => {
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
|
||||
// Compute the blind rotation
|
||||
blind_rotate_assign_mem_optimized(&ct.ct, &mut acc, fourier_bsk, fft, stack);
|
||||
}
|
||||
ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
} => {
|
||||
if *deterministic_execution {
|
||||
multi_bit_deterministic_blind_rotate_assign(
|
||||
&ct.ct,
|
||||
&mut acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
} else {
|
||||
multi_bit_blind_rotate_assign(&ct.ct, &mut acc, fourier_bsk, *thread_count);
|
||||
}
|
||||
}
|
||||
};
|
||||
apply_blind_rotate(&self.bootstrapping_key, &ct.ct, &mut acc, buffers);
|
||||
});
|
||||
|
||||
// The accumulator has been rotated, we can now proceed with the various sample extractions
|
||||
@@ -1685,3 +1493,70 @@ impl ServerKey {
|
||||
.min_by_key(|op| op.number_of_pbs())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn apply_blind_rotate<Scalar, InputCont, OutputCont>(
|
||||
bootstrapping_key: &ShortintBootstrappingKey,
|
||||
in_buffer: &LweCiphertext<InputCont>,
|
||||
acc: &mut GlweCiphertext<OutputCont>,
|
||||
buffers: &mut ComputationBuffers,
|
||||
) where
|
||||
Scalar: UnsignedTorus + CastInto<usize> + CastFrom<usize> + Sync,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
#[cfg(feature = "pbs-stats")]
|
||||
let _ = PBS_COUNT.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
match bootstrapping_key {
|
||||
ShortintBootstrappingKey::Classic(fourier_bsk) => {
|
||||
let fft = Fft::new(fourier_bsk.polynomial_size());
|
||||
let fft = fft.as_view();
|
||||
buffers.resize(
|
||||
programmable_bootstrap_lwe_ciphertext_mem_optimized_requirement::<u64>(
|
||||
fourier_bsk.glwe_size(),
|
||||
fourier_bsk.polynomial_size(),
|
||||
fft,
|
||||
)
|
||||
.unwrap()
|
||||
.unaligned_bytes_required(),
|
||||
);
|
||||
let stack = buffers.stack();
|
||||
|
||||
// Compute the blind rotation
|
||||
blind_rotate_assign_mem_optimized(in_buffer, acc, fourier_bsk, fft, stack);
|
||||
}
|
||||
ShortintBootstrappingKey::MultiBit {
|
||||
fourier_bsk,
|
||||
thread_count,
|
||||
deterministic_execution,
|
||||
} => {
|
||||
if *deterministic_execution {
|
||||
multi_bit_deterministic_blind_rotate_assign(
|
||||
in_buffer,
|
||||
acc,
|
||||
fourier_bsk,
|
||||
*thread_count,
|
||||
);
|
||||
} else {
|
||||
multi_bit_blind_rotate_assign(in_buffer, acc, fourier_bsk, *thread_count);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) fn apply_programmable_bootstrap<InputCont, OutputCont>(
|
||||
bootstrapping_key: &ShortintBootstrappingKey,
|
||||
in_buffer: &LweCiphertext<InputCont>,
|
||||
out_buffer: &mut LweCiphertext<OutputCont>,
|
||||
acc: &LookupTableOwned,
|
||||
buffers: &mut ComputationBuffers,
|
||||
) where
|
||||
InputCont: Container<Element = u64>,
|
||||
OutputCont: ContainerMut<Element = u64>,
|
||||
{
|
||||
let mut glwe_out = acc.acc.clone();
|
||||
|
||||
apply_blind_rotate(bootstrapping_key, in_buffer, &mut glwe_out, buffers);
|
||||
|
||||
extract_lwe_sample_from_glwe_ciphertext(&glwe_out, out_buffer, MonomialDegree(0));
|
||||
}
|
||||
|
||||
@@ -891,11 +891,8 @@ impl WopbsKey {
|
||||
) -> Ciphertext {
|
||||
let extracted_bits = self.extract_bits(delta_log, ct_in, nb_bit_to_extract);
|
||||
|
||||
let ciphertext_list = self.circuit_bootstrap_with_bits(
|
||||
&extracted_bits.as_view(),
|
||||
&lut.lut(),
|
||||
LweCiphertextCount(1),
|
||||
);
|
||||
let ciphertext_list =
|
||||
self.circuit_bootstrap_with_bits(&extracted_bits, &lut.lut(), LweCiphertextCount(1));
|
||||
|
||||
// Here the output list contains a single ciphertext, we can consume the container to
|
||||
// convert it to a single ciphertext
|
||||
|
||||
Reference in New Issue
Block a user