feat(hpu): Add dedicated keyswitch implementation

Use a dedicated keyswitch implementation that used unbalanced keyswitch.
Enable to generate bit-accurate stimulus without the need of a feature
flag inside the decomposer implementation.
This commit is contained in:
Baptiste Roux
2025-03-31 21:40:56 +02:00
parent 3c78fdbdb0
commit d2554b273c
5 changed files with 608 additions and 41 deletions

View File

@@ -13,6 +13,7 @@ use tfhe::core_crypto::entities::{
Cleartext, LweCiphertextOwned, LweCiphertextView, LweKeyswitchKey, NttLweBootstrapKey,
Plaintext,
};
use tfhe::core_crypto::hpu::algorithms::lwe_keyswitch::hpu_keyswitch_lwe_ciphertext;
use tfhe::core_crypto::hpu::from_with::FromWith;
use tfhe::core_crypto::hpu::glwe_lookuptable::create_hpu_lookuptable;
use tfhe::core_crypto::prelude::*;
@@ -648,7 +649,7 @@ impl HpuSim {
// TODO add a check on trivialness for fast simulation ?
// TODO assert ordering (i.e. KS+PBS)
keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
hpu_keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
blind_rotate_ntt64_bnf_assign(bfr_after_ks, &mut tfhe_lut, &bsk);
assert_eq!(

View File

@@ -587,9 +587,11 @@ pub(crate) fn add_external_product_ntt64_bnf_assign<InputGlweCont>(
// DOMAIN In this section, we perform the external product in the ntt
// domain, and accumulate the result in the output_fft_buffer variable.
let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref()
.iter()
.map(|s| decomposer.init_decomposer_state(*s)),
glwe.as_ref().iter().map(|s| {
decomposer.closest_representable(*s)
>> (u64::BITS
- (decomposer.level_count().0 * decomposer.base_log().0) as u32)
}),
decomposer.base_log(),
decomposer.level_count(),
substack0,

View File

@@ -55,7 +55,6 @@ pub fn native_closest_representable<Scalar: UnsignedInteger>(
/// returns 1 if the following if condition is true otherwise 0
///
/// (val > B / 2) || ((val == B / 2) && (random == 1))
#[cfg(any(not(feature = "hpu"), test))]
#[inline(always)]
fn balanced_rounding_condition_bit_trick<Scalar: UnsignedInteger>(
val: Scalar,
@@ -147,43 +146,33 @@ where
#[inline(always)]
pub fn init_decomposer_state(&self, input: Scalar) -> Scalar {
// Default mode -> use balanced decomposition
#[cfg(not(feature = "hpu"))]
{
// The closest number representable by the decomposition can be computed by performing
// the rounding at the appropriate bit.
// The closest number representable by the decomposition can be computed by performing
// the rounding at the appropriate bit.
// We compute the number of least significant bits which can not be represented by the
// decomposition
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
let rep_bit_count = self.level_count * self.base_log;
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
// Move the representable bits + 1 to the LSB, with our example :
// |-----| 64 - (64 - 12 - 1) == 13 bits
// 0....0XX...XX
let mut res = input >> (non_rep_bit_count - 1);
// Fetch the first bit value as we need it for a balanced rounding
let rounding_bit = res & Scalar::ONE;
// Add one to do the rounding by adding the half interval
res += Scalar::ONE;
// Discard the LSB which was the one deciding in which direction we round
res >>= 1;
// Keep the low base_log * level bits
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
res &= mod_mask;
// Control bit about whether we should balance the state
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
let need_balance =
balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
// Balance depending on the control bit
res.wrapping_sub(need_balance << rep_bit_count)
}
// Hpu used unbalanced decomposition
#[cfg(feature = "hpu")]
{
self.closest_representable(input) >> (Scalar::BITS - (self.level_count * self.base_log))
}
// We compute the number of least significant bits which can not be represented by the
// decomposition
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
let rep_bit_count = self.level_count * self.base_log;
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
// Move the representable bits + 1 to the LSB, with our example :
// |-----| 64 - (64 - 12 - 1) == 13 bits
// 0....0XX...XX
let mut res = input >> (non_rep_bit_count - 1);
// Fetch the first bit value as we need it for a balanced rounding
let rounding_bit = res & Scalar::ONE;
// Add one to do the rounding by adding the half interval
res += Scalar::ONE;
// Discard the LSB which was the one deciding in which direction we round
res >>= 1;
// Keep the low base_log * level bits
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
res &= mod_mask;
// Control bit about whether we should balance the state
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random ==
// 1)
let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
// Balance depending on the control bit
res.wrapping_sub(need_balance << rep_bit_count)
}
/// Generate an iterator over the terms of the decomposition of the input.
@@ -228,6 +217,48 @@ where
)
}
/// Generate an iterator over the terms of the decomposition of the input.
/// # Warning
/// This used unbalanced decomposition and shouldn't be used with one-level decomposition
/// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer;
/// use tfhe::core_crypto::commons::numeric::UnsignedInteger;
/// use tfhe::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
/// let decomposer =
/// SignedDecomposer::<u32>::new(DecompositionBaseLog(4), DecompositionLevelCount(3));
/// // 2147483647 == 2^31 - 1 and has a decomposition term == to half_basis
/// for term in decomposer.decompose(2147483647u32) {
/// assert!(1 <= term.level().0);
/// assert!(term.level().0 <= 3);
/// let signed_term = term.value().into_signed();
/// let half_basis = 2i32.pow(4) / 2i32;
/// assert!(
/// -half_basis <= signed_term,
/// "{} <= {signed_term} failed",
/// -half_basis
/// );
/// assert!(
/// signed_term <= half_basis,
/// "{signed_term} <= {half_basis} failed"
/// );
/// }
/// assert_eq!(decomposer.decompose(1).count(), 3);
/// ```
pub fn decompose_raw(&self, input: Scalar) -> SignedDecompositionIter<Scalar> {
// Note that there would be no sense of making the decomposition on an input which was
// not rounded to the closest representable first. We then perform it before decomposing.
SignedDecompositionIter::new(
self.closest_representable(input)
>> (Scalar::BITS - (self.level_count * self.base_log)),
DecompositionBaseLog(self.base_log),
DecompositionLevelCount(self.level_count),
)
}
/// Recomposes a decomposed value by summing all the terms.
///
/// If the input iterator yields $\tilde{\theta}\_i$, this returns

View File

@@ -0,0 +1,532 @@
//! Module containing primitives pertaining to [`LWE ciphertext
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) encrypted under an
/// [`LWE secret key`](`LweSecretKey`) to another [`LWE secret key`](`LweSecretKey`).
///
/// # Panics
///
/// Panics if the modulus of the inputs are not power of twos.
/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
/// modulus.
///
/// # Formal Definition
///
/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`).
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweKeyswitchKey creation
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
/// &input_lwe_secret_key,
/// &output_lwe_secret_key,
/// decomp_base_log,
/// decomp_level_count,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// // Create the plaintext
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Create a new LweCiphertext
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
/// &input_lwe_secret_key,
/// plaintext,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
/// 0,
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
/// ciphertext_modulus,
/// );
///
/// hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
input_lwe_ciphertext: &LweCiphertext<InputCont>,
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
Scalar: UnsignedInteger,
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension()
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched input LweDimension. \
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension()
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched output LweDimension. \
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
output_ciphertext_modulus,
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_ciphertext_modulus
);
assert!(
output_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);
let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
assert!(
input_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);
// Clear the output ciphertext, as it will get updated gradually
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
// If the moduli are not the same, we need to round the body in the output ciphertext
if output_ciphertext_modulus != input_ciphertext_modulus
&& !output_ciphertext_modulus.is_native_modulus()
{
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
let output_decomposer = SignedDecomposer::new(
DecompositionBaseLog(modulus_bits),
DecompositionLevelCount(1),
);
*output_lwe_ciphertext.get_mut_body().data =
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
}
// We instantiate a decomposer
let decomposer = SignedDecomposer::new(
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
);
for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
.iter()
.zip(input_lwe_ciphertext.get_mask().as_ref())
{
let decomposition_iter = decomposer.decompose_raw(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
output_lwe_ciphertext.as_mut(),
level_key_ciphertext.as_ref(),
decomposed.value(),
);
}
}
}
/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`].
///
/// This will use all threads available in the current rayon thread pool.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweKeyswitchKey creation
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
/// &input_lwe_secret_key,
/// &output_lwe_secret_key,
/// decomp_base_log,
/// decomp_level_count,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// // Create the plaintext
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Create a new LweCiphertext
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
/// &input_lwe_secret_key,
/// plaintext,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
/// 0,
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
/// ciphertext_modulus,
/// );
///
/// // Use all threads available in the current rayon thread pool
/// par_hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn par_hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
input_lwe_ciphertext: &LweCiphertext<InputCont>,
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
Scalar: UnsignedInteger + Send + Sync,
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
let thread_count = ThreadCount(rayon::current_num_threads());
par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
lwe_keyswitch_key,
input_lwe_ciphertext,
output_lwe_ciphertext,
thread_count,
);
}
/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`].
///
/// This will try to use `thread_count` threads for the computation, if this number is bigger than
/// the available number of threads in the current rayon thread pool then only the number of
/// available threads will be used. Note that `thread_count` cannot be 0.
///
/// # Panics
///
/// Panics if the modulus of the inputs are not power of twos.
/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
/// modulus.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Define parameters for LweKeyswitchKey creation
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Create the PRNG
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Create the LweSecretKey
/// let input_lwe_secret_key =
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
/// output_lwe_dimension,
/// &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
/// &input_lwe_secret_key,
/// &output_lwe_secret_key,
/// decomp_base_log,
/// decomp_level_count,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// // Create the plaintext
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Create a new LweCiphertext
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
/// &input_lwe_secret_key,
/// plaintext,
/// lwe_noise_distribution,
/// ciphertext_modulus,
/// &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
/// 0,
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
/// ciphertext_modulus,
/// );
///
/// // Try to use 4 threads for the keyswitch if enough are available
/// // in the current rayon thread pool
/// par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
/// &ksk,
/// &input_lwe,
/// &mut output_lwe,
/// ThreadCount(4),
/// );
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn par_hpu_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
input_lwe_ciphertext: &LweCiphertext<InputCont>,
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
thread_count: ThreadCount,
) where
Scalar: UnsignedInteger + Send + Sync,
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension()
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched input LweDimension. \
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension()
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched output LweDimension. \
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
output_ciphertext_modulus,
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_ciphertext_modulus
);
assert!(
output_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);
let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
assert!(
input_ciphertext_modulus.is_compatible_with_native_modulus(),
"This operation currently only supports power of 2 moduli"
);
assert!(
thread_count.0 != 0,
"Got thread_count == 0, this is not supported"
);
// Clear the output ciphertext, as it will get updated gradually
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
let output_lwe_size = output_lwe_ciphertext.lwe_size();
// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
// If the moduli are not the same, we need to round the body in the output ciphertext
if output_ciphertext_modulus != input_ciphertext_modulus
&& !output_ciphertext_modulus.is_native_modulus()
{
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
let output_decomposer = SignedDecomposer::new(
DecompositionBaseLog(modulus_bits),
DecompositionLevelCount(1),
);
*output_lwe_ciphertext.get_mut_body().data =
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
}
// We instantiate a decomposer
let decomposer = SignedDecomposer::new(
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
);
// Don't go above the current number of threads
let thread_count = thread_count.0.min(rayon::current_num_threads());
let mut intermediate_accumulators = Vec::with_capacity(thread_count);
// Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size
let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
lwe_keyswitch_key
.par_chunks(chunk_size)
.zip(
input_lwe_ciphertext
.get_mask()
.as_ref()
.par_chunks(chunk_size),
)
.map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
let mut buffer =
LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus);
for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
.iter()
.zip(input_mask_element_chunk.iter())
{
let decomposition_iter = decomposer.decompose_raw(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in
keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign(
buffer.as_mut(),
level_key_ciphertext.as_ref(),
decomposed.value(),
);
}
}
buffer
})
.collect_into_vec(&mut intermediate_accumulators);
let reduced = intermediate_accumulators
.par_iter_mut()
.reduce_with(|lhs, rhs| {
lhs.as_mut()
.iter_mut()
.zip(rhs.as_ref().iter())
.for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src));
lhs
})
.unwrap();
output_lwe_ciphertext
.get_mut_mask()
.as_mut()
.copy_from_slice(reduced.get_mask().as_ref());
let reduced_ksed_body = *reduced.get_body().data;
// Add the reduced body of the keyswitch to the output body to complete the keyswitch
*output_lwe_ciphertext.get_mut_body().data =
(*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
}

View File

@@ -1,2 +1,3 @@
pub mod lwe_keyswitch;
pub mod modswitch;
pub mod order;