fix(hpu): Remove dedicated decomposition changes for RTL

There was an error in the RTL implementation that prevented the decomposition
from being balanced.
The RTL now uses the same decomposition as the SW, so a dedicated keyswitch
implementation is no longer needed.
This commit is contained in:
Baptiste Roux
2025-04-29 09:23:59 +02:00
parent 1c0202562b
commit 4bbe002d8b
4 changed files with 4 additions and 540 deletions

View File

@@ -13,7 +13,6 @@ use tfhe::core_crypto::entities::{
Cleartext, LweCiphertextOwned, LweCiphertextView, LweKeyswitchKey, NttLweBootstrapKey,
Plaintext,
};
use tfhe::core_crypto::hpu::algorithms::lwe_keyswitch::hpu_keyswitch_lwe_ciphertext;
use tfhe::core_crypto::hpu::glwe_lookuptable::create_hpu_lookuptable;
use tfhe::core_crypto::prelude::*;
use tfhe::shortint::prelude::ClassicPBSParameters;
@@ -654,7 +653,7 @@ impl HpuSim {
// TODO add a check on trivialness for fast simulation ?
// TODO assert ordering (i.e. KS+PBS)
hpu_keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
blind_rotate_ntt64_bnf_assign(bfr_after_ks, &mut tfhe_lut, &bsk);
assert_eq!(

View File

@@ -584,11 +584,9 @@ pub(crate) fn add_external_product_ntt64_bnf_assign<InputGlweCont>(
// DOMAIN In this section, we perform the external product in the ntt
// domain, and accumulate the result in the output_fft_buffer variable.
let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new(
glwe.as_ref().iter().map(|s| {
decomposer.closest_representable(*s)
>> (u64::BITS
- (decomposer.level_count().0 * decomposer.base_log().0) as u32)
}),
glwe.as_ref()
.iter()
.map(|s| decomposer.init_decomposer_state(*s)),
decomposer.base_log(),
decomposer.level_count(),
substack0,

View File

@@ -1,532 +0,0 @@
//! Module containing primitives pertaining to [`LWE ciphertext
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::parameters::{
DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) from the [`LWE secret key`](`LweSecretKey`)
/// it is encrypted under to another [`LWE secret key`](`LweSecretKey`).
///
/// The result is written to `output_lwe_ciphertext`; its previous content is overwritten.
/// Mask elements are decomposed with `SignedDecomposer::decompose_raw`.
///
/// # Panics
///
/// Panics if the input (resp. output) `LweDimension` of `lwe_keyswitch_key` differs from the
/// one of `input_lwe_ciphertext` (resp. `output_lwe_ciphertext`).
/// Panics if the modulus of `output_lwe_ciphertext` is not equal to the modulus of
/// `lwe_keyswitch_key`.
/// Panics if the modulus of the input or of the output ciphertext is not a power of two.
///
/// # Formal Definition
///
/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`).
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Parameters used to build the LweKeyswitchKey
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Random generators
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Input and output secret keys
/// let input_lwe_secret_key =
///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
///     output_lwe_dimension,
///     &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
///     &input_lwe_secret_key,
///     &output_lwe_secret_key,
///     decomp_base_log,
///     decomp_level_count,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// // Encode the message in the 4 most significant bits
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Encrypt it under the input key
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
///     &input_lwe_secret_key,
///     plaintext,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
///     0,
///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
///     ciphertext_modulus,
/// );
///
/// hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // The decomposer works on the high 4 bits, which hold our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // The key must be applicable to the input ciphertext...
    let ksk_input_dimension = lwe_keyswitch_key.input_key_lwe_dimension();
    let ct_input_dimension = input_lwe_ciphertext.lwe_size().to_lwe_dimension();
    assert!(
        ksk_input_dimension == ct_input_dimension,
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        ksk_input_dimension,
        ct_input_dimension,
    );

    // ... and must produce a ciphertext of the size of the output buffer.
    let ksk_output_dimension = lwe_keyswitch_key.output_key_lwe_dimension();
    let ct_output_dimension = output_lwe_ciphertext.lwe_size().to_lwe_dimension();
    assert!(
        ksk_output_dimension == ct_output_dimension,
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        ksk_output_dimension,
        ct_output_dimension,
    );

    let output_modulus = output_lwe_ciphertext.ciphertext_modulus();
    let ksk_modulus = lwe_keyswitch_key.ciphertext_modulus();
    assert_eq!(
        ksk_modulus, output_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        ksk_modulus, output_modulus
    );
    assert!(
        output_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    let input_modulus = input_lwe_ciphertext.ciphertext_modulus();
    assert!(
        input_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    // Start from an all-zero output: the keyswitch terms are accumulated into it below
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);

    // The output body starts as a copy of the input body
    let input_body = *input_lwe_ciphertext.get_body().data;
    *output_lwe_ciphertext.get_mut_body().data = input_body;

    // When the moduli differ and the output one is not native, the copied body has to be rounded
    // to the output modulus
    let needs_body_rounding =
        output_modulus != input_modulus && !output_modulus.is_native_modulus();
    if needs_body_rounding {
        let modulus_bits = output_modulus.get_custom_modulus().ilog2() as usize;
        let output_decomposer = SignedDecomposer::new(
            DecompositionBaseLog(modulus_bits),
            DecompositionLevelCount(1),
        );
        let unrounded_body = *output_lwe_ciphertext.get_mut_body().data;
        *output_lwe_ciphertext.get_mut_body().data =
            output_decomposer.closest_representable(unrounded_body);
    }

    // Decomposer configured from the keyswitch key parameters
    let decomposer = SignedDecomposer::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );

    // One keyswitch key block per input mask element
    let input_mask = input_lwe_ciphertext.get_mask();
    for (ksk_block, &mask_element) in lwe_keyswitch_key.iter().zip(input_mask.as_ref()) {
        let level_terms = decomposer.decompose_raw(mask_element);

        // One key ciphertext per decomposition level
        ksk_block
            .iter()
            .zip(level_terms)
            .for_each(|(level_ciphertext, term)| {
                slice_wrapping_sub_scalar_mul_assign(
                    output_lwe_ciphertext.as_mut(),
                    level_ciphertext.as_ref(),
                    term.value(),
                );
            });
    }
}
/// Multi-threaded version of [`hpu_keyswitch_lwe_ciphertext`].
///
/// Every thread available in the current rayon thread pool is requested for the computation,
/// which is delegated to [`par_hpu_keyswitch_lwe_ciphertext_with_thread_count`].
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Parameters used to build the LweKeyswitchKey
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Random generators
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Input and output secret keys
/// let input_lwe_secret_key =
///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
///     output_lwe_dimension,
///     &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
///     &input_lwe_secret_key,
///     &output_lwe_secret_key,
///     decomp_base_log,
///     decomp_level_count,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// // Encode the message in the 4 most significant bits
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Encrypt it under the input key
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
///     &input_lwe_secret_key,
///     plaintext,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
///     0,
///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
///     ciphertext_modulus,
/// );
///
/// // Use all threads available in the current rayon thread pool
/// par_hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // The decomposer works on the high 4 bits, which hold our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn par_hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // Request as many threads as the current rayon pool offers
    par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
        lwe_keyswitch_key,
        input_lwe_ciphertext,
        output_lwe_ciphertext,
        ThreadCount(rayon::current_num_threads()),
    );
}
/// Multi-threaded version of [`hpu_keyswitch_lwe_ciphertext`] with a caller-chosen thread count.
///
/// At most `thread_count` threads are used. If `thread_count` exceeds the number of threads of
/// the current rayon thread pool, the computation is limited to the threads of that pool.
/// `thread_count` must not be 0.
///
/// The input mask and the keyswitch key are split into chunks; each chunk is accumulated in its
/// own buffer and the buffers are then summed into `output_lwe_ciphertext`.
///
/// # Panics
///
/// Panics if the input (resp. output) `LweDimension` of `lwe_keyswitch_key` differs from the
/// one of `input_lwe_ciphertext` (resp. `output_lwe_ciphertext`).
/// Panics if the modulus of `output_lwe_ciphertext` is not equal to the modulus of
/// `lwe_keyswitch_key`.
/// Panics if the modulus of the input or of the output ciphertext is not a power of two.
/// Panics if `thread_count` is 0.
///
/// # Example
///
/// ```rust
/// use tfhe::core_crypto::prelude::*;
///
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
/// // computations
/// // Parameters used to build the LweKeyswitchKey
/// let input_lwe_dimension = LweDimension(742);
/// let lwe_noise_distribution =
///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
/// let output_lwe_dimension = LweDimension(2048);
/// let decomp_base_log = DecompositionBaseLog(3);
/// let decomp_level_count = DecompositionLevelCount(5);
/// let ciphertext_modulus = CiphertextModulus::new_native();
///
/// // Random generators
/// let mut seeder = new_seeder();
/// let seeder = seeder.as_mut();
/// let mut encryption_generator =
///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
///
/// // Input and output secret keys
/// let input_lwe_secret_key =
///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
///     output_lwe_dimension,
///     &mut secret_generator,
/// );
///
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
///     &input_lwe_secret_key,
///     &output_lwe_secret_key,
///     decomp_base_log,
///     decomp_level_count,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// // Encode the message in the 4 most significant bits
/// let msg = 3u64;
/// let plaintext = Plaintext(msg << 60);
///
/// // Encrypt it under the input key
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
///     &input_lwe_secret_key,
///     plaintext,
///     lwe_noise_distribution,
///     ciphertext_modulus,
///     &mut encryption_generator,
/// );
///
/// let mut output_lwe = LweCiphertext::new(
///     0,
///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
///     ciphertext_modulus,
/// );
///
/// // Try to use 4 threads for the keyswitch if enough are available
/// // in the current rayon thread pool
/// par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
///     &ksk,
///     &input_lwe,
///     &mut output_lwe,
///     ThreadCount(4),
/// );
///
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
///
/// // Round and remove encoding
/// // The decomposer works on the high 4 bits, which hold our encoding.
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
///
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
///
/// // Remove the encoding
/// let cleartext = rounded >> 60;
///
/// // Check we recovered the original message
/// assert_eq!(cleartext, msg);
/// ```
pub fn par_hpu_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    // The key must be applicable to the input ciphertext...
    let ksk_input_dimension = lwe_keyswitch_key.input_key_lwe_dimension();
    let ct_input_dimension = input_lwe_ciphertext.lwe_size().to_lwe_dimension();
    assert!(
        ksk_input_dimension == ct_input_dimension,
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        ksk_input_dimension,
        ct_input_dimension,
    );

    // ... and must produce a ciphertext of the size of the output buffer.
    let ksk_output_dimension = lwe_keyswitch_key.output_key_lwe_dimension();
    let ct_output_dimension = output_lwe_ciphertext.lwe_size().to_lwe_dimension();
    assert!(
        ksk_output_dimension == ct_output_dimension,
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        ksk_output_dimension,
        ct_output_dimension,
    );

    let output_modulus = output_lwe_ciphertext.ciphertext_modulus();
    let ksk_modulus = lwe_keyswitch_key.ciphertext_modulus();
    assert_eq!(
        ksk_modulus, output_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        ksk_modulus, output_modulus
    );
    assert!(
        output_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    let input_modulus = input_lwe_ciphertext.ciphertext_modulus();
    assert!(
        input_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    assert!(
        thread_count.0 != 0,
        "Got thread_count == 0, this is not supported"
    );

    // Size of the per-chunk accumulation buffers allocated below
    let output_lwe_size = output_lwe_ciphertext.lwe_size();

    // Start from an all-zero output
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);

    // The output body starts as a copy of the input body
    let input_body = *input_lwe_ciphertext.get_body().data;
    *output_lwe_ciphertext.get_mut_body().data = input_body;

    // When the moduli differ and the output one is not native, the copied body has to be rounded
    // to the output modulus
    let needs_body_rounding =
        output_modulus != input_modulus && !output_modulus.is_native_modulus();
    if needs_body_rounding {
        let modulus_bits = output_modulus.get_custom_modulus().ilog2() as usize;
        let output_decomposer = SignedDecomposer::new(
            DecompositionBaseLog(modulus_bits),
            DecompositionLevelCount(1),
        );
        let unrounded_body = *output_lwe_ciphertext.get_mut_body().data;
        *output_lwe_ciphertext.get_mut_body().data =
            output_decomposer.closest_representable(unrounded_body);
    }

    // Decomposer configured from the keyswitch key parameters
    let decomposer = SignedDecomposer::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );

    // Never ask for more threads than the current rayon pool has
    let worker_count = thread_count.0.min(rayon::current_num_threads());

    // Smallest chunk_size such that worker_count * chunk_size >= input_lwe_size
    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(worker_count);

    // One partial accumulator is produced per chunk of (key blocks, mask elements)
    let mut partial_accumulators = Vec::with_capacity(worker_count);
    lwe_keyswitch_key
        .par_chunks(chunk_size)
        .zip(
            input_lwe_ciphertext
                .get_mask()
                .as_ref()
                .par_chunks(chunk_size),
        )
        .map(|(ksk_block_chunk, mask_chunk)| {
            let mut accumulator =
                LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_modulus);

            for (ksk_block, &mask_element) in ksk_block_chunk.iter().zip(mask_chunk.iter()) {
                let level_terms = decomposer.decompose_raw(mask_element);

                // One key ciphertext per decomposition level
                ksk_block
                    .iter()
                    .zip(level_terms)
                    .for_each(|(level_ciphertext, term)| {
                        slice_wrapping_sub_scalar_mul_assign(
                            accumulator.as_mut(),
                            level_ciphertext.as_ref(),
                            term.value(),
                        );
                    });
            }

            accumulator
        })
        .collect_into_vec(&mut partial_accumulators);

    // Sum all the partial accumulators element-wise (wrapping) into one of them
    let reduced = partial_accumulators
        .par_iter_mut()
        .reduce_with(|acc, other| {
            for (dst, &src) in acc.as_mut().iter_mut().zip(other.as_ref().iter()) {
                *dst = (*dst).wrapping_add(src);
            }
            acc
        })
        .unwrap();

    // The output mask is exactly the reduced mask
    output_lwe_ciphertext
        .get_mut_mask()
        .as_mut()
        .copy_from_slice(reduced.get_mask().as_ref());

    // The output body already holds the (possibly rounded) input body: add the reduced keyswitch
    // body to it to complete the keyswitch
    let reduced_ksed_body = *reduced.get_body().data;
    let body_so_far = *output_lwe_ciphertext.get_mut_body().data;
    *output_lwe_ciphertext.get_mut_body().data = body_so_far.wrapping_add(reduced_ksed_body);
}

View File

@@ -1,3 +1,2 @@
pub mod lwe_keyswitch;
pub mod modswitch;
pub mod order;