mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 14:47:56 -05:00
feat(hpu): Add dedicated keyswitch implementation
Use a dedicated keyswitch implementation that used unbalanced keyswitch. Enable to generate bit-accurate stimulus without the need of a feature flag inside the decomposer implementation.
This commit is contained in:
@@ -13,6 +13,7 @@ use tfhe::core_crypto::entities::{
|
||||
Cleartext, LweCiphertextOwned, LweCiphertextView, LweKeyswitchKey, NttLweBootstrapKey,
|
||||
Plaintext,
|
||||
};
|
||||
use tfhe::core_crypto::hpu::algorithms::lwe_keyswitch::hpu_keyswitch_lwe_ciphertext;
|
||||
use tfhe::core_crypto::hpu::from_with::FromWith;
|
||||
use tfhe::core_crypto::hpu::glwe_lookuptable::create_hpu_lookuptable;
|
||||
use tfhe::core_crypto::prelude::*;
|
||||
@@ -648,7 +649,7 @@ impl HpuSim {
|
||||
|
||||
// TODO add a check on trivialness for fast simulation ?
|
||||
// TODO assert ordering (i.e. KS+PBS)
|
||||
keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
|
||||
hpu_keyswitch_lwe_ciphertext(ksk, &cpu_reg, bfr_after_ks);
|
||||
blind_rotate_ntt64_bnf_assign(bfr_after_ks, &mut tfhe_lut, &bsk);
|
||||
|
||||
assert_eq!(
|
||||
|
||||
@@ -587,9 +587,11 @@ pub(crate) fn add_external_product_ntt64_bnf_assign<InputGlweCont>(
|
||||
// DOMAIN In this section, we perform the external product in the ntt
|
||||
// domain, and accumulate the result in the output_fft_buffer variable.
|
||||
let (mut decomposition, substack1) = TensorSignedDecompositionLendingIter::new(
|
||||
glwe.as_ref()
|
||||
.iter()
|
||||
.map(|s| decomposer.init_decomposer_state(*s)),
|
||||
glwe.as_ref().iter().map(|s| {
|
||||
decomposer.closest_representable(*s)
|
||||
>> (u64::BITS
|
||||
- (decomposer.level_count().0 * decomposer.base_log().0) as u32)
|
||||
}),
|
||||
decomposer.base_log(),
|
||||
decomposer.level_count(),
|
||||
substack0,
|
||||
|
||||
@@ -55,7 +55,6 @@ pub fn native_closest_representable<Scalar: UnsignedInteger>(
|
||||
/// returns 1 if the following if condition is true otherwise 0
|
||||
///
|
||||
/// (val > B / 2) || ((val == B / 2) && (random == 1))
|
||||
#[cfg(any(not(feature = "hpu"), test))]
|
||||
#[inline(always)]
|
||||
fn balanced_rounding_condition_bit_trick<Scalar: UnsignedInteger>(
|
||||
val: Scalar,
|
||||
@@ -147,43 +146,33 @@ where
|
||||
|
||||
#[inline(always)]
|
||||
pub fn init_decomposer_state(&self, input: Scalar) -> Scalar {
|
||||
// Default mode -> use balanced decomposition
|
||||
#[cfg(not(feature = "hpu"))]
|
||||
{
|
||||
// The closest number representable by the decomposition can be computed by performing
|
||||
// the rounding at the appropriate bit.
|
||||
// The closest number representable by the decomposition can be computed by performing
|
||||
// the rounding at the appropriate bit.
|
||||
|
||||
// We compute the number of least significant bits which can not be represented by the
|
||||
// decomposition
|
||||
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
|
||||
let rep_bit_count = self.level_count * self.base_log;
|
||||
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
|
||||
// Move the representable bits + 1 to the LSB, with our example :
|
||||
// |-----| 64 - (64 - 12 - 1) == 13 bits
|
||||
// 0....0XX...XX
|
||||
let mut res = input >> (non_rep_bit_count - 1);
|
||||
// Fetch the first bit value as we need it for a balanced rounding
|
||||
let rounding_bit = res & Scalar::ONE;
|
||||
// Add one to do the rounding by adding the half interval
|
||||
res += Scalar::ONE;
|
||||
// Discard the LSB which was the one deciding in which direction we round
|
||||
res >>= 1;
|
||||
// Keep the low base_log * level bits
|
||||
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
|
||||
res &= mod_mask;
|
||||
// Control bit about whether we should balance the state
|
||||
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random == 1)
|
||||
let need_balance =
|
||||
balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
|
||||
// Balance depending on the control bit
|
||||
res.wrapping_sub(need_balance << rep_bit_count)
|
||||
}
|
||||
|
||||
// Hpu used unbalanced decomposition
|
||||
#[cfg(feature = "hpu")]
|
||||
{
|
||||
self.closest_representable(input) >> (Scalar::BITS - (self.level_count * self.base_log))
|
||||
}
|
||||
// We compute the number of least significant bits which can not be represented by the
|
||||
// decomposition
|
||||
// Example with level_count = 3, base_log = 4 and BITS == 64 -> 52
|
||||
let rep_bit_count = self.level_count * self.base_log;
|
||||
let non_rep_bit_count: usize = Scalar::BITS - rep_bit_count;
|
||||
// Move the representable bits + 1 to the LSB, with our example :
|
||||
// |-----| 64 - (64 - 12 - 1) == 13 bits
|
||||
// 0....0XX...XX
|
||||
let mut res = input >> (non_rep_bit_count - 1);
|
||||
// Fetch the first bit value as we need it for a balanced rounding
|
||||
let rounding_bit = res & Scalar::ONE;
|
||||
// Add one to do the rounding by adding the half interval
|
||||
res += Scalar::ONE;
|
||||
// Discard the LSB which was the one deciding in which direction we round
|
||||
res >>= 1;
|
||||
// Keep the low base_log * level bits
|
||||
let mod_mask = Scalar::MAX >> (Scalar::BITS - rep_bit_count);
|
||||
res &= mod_mask;
|
||||
// Control bit about whether we should balance the state
|
||||
// This is equivalent to res > 2^(base_log * l) || (res == 2^(base_log * l) && random ==
|
||||
// 1)
|
||||
let need_balance = balanced_rounding_condition_bit_trick(res, rep_bit_count, rounding_bit);
|
||||
// Balance depending on the control bit
|
||||
res.wrapping_sub(need_balance << rep_bit_count)
|
||||
}
|
||||
|
||||
/// Generate an iterator over the terms of the decomposition of the input.
|
||||
@@ -228,6 +217,48 @@ where
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate an iterator over the terms of the decomposition of the input.
|
||||
/// # Warning
|
||||
/// This used unbalanced decomposition and shouldn't be used with one-level decomposition
|
||||
/// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
/// use tfhe::core_crypto::commons::numeric::UnsignedInteger;
|
||||
/// use tfhe::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
/// let decomposer =
|
||||
/// SignedDecomposer::<u32>::new(DecompositionBaseLog(4), DecompositionLevelCount(3));
|
||||
/// // 2147483647 == 2^31 - 1 and has a decomposition term == to half_basis
|
||||
/// for term in decomposer.decompose(2147483647u32) {
|
||||
/// assert!(1 <= term.level().0);
|
||||
/// assert!(term.level().0 <= 3);
|
||||
/// let signed_term = term.value().into_signed();
|
||||
/// let half_basis = 2i32.pow(4) / 2i32;
|
||||
/// assert!(
|
||||
/// -half_basis <= signed_term,
|
||||
/// "{} <= {signed_term} failed",
|
||||
/// -half_basis
|
||||
/// );
|
||||
/// assert!(
|
||||
/// signed_term <= half_basis,
|
||||
/// "{signed_term} <= {half_basis} failed"
|
||||
/// );
|
||||
/// }
|
||||
/// assert_eq!(decomposer.decompose(1).count(), 3);
|
||||
/// ```
|
||||
pub fn decompose_raw(&self, input: Scalar) -> SignedDecompositionIter<Scalar> {
|
||||
// Note that there would be no sense of making the decomposition on an input which was
|
||||
// not rounded to the closest representable first. We then perform it before decomposing.
|
||||
SignedDecompositionIter::new(
|
||||
self.closest_representable(input)
|
||||
>> (Scalar::BITS - (self.level_count * self.base_log)),
|
||||
DecompositionBaseLog(self.base_log),
|
||||
DecompositionLevelCount(self.level_count),
|
||||
)
|
||||
}
|
||||
|
||||
/// Recomposes a decomposed value by summing all the terms.
|
||||
///
|
||||
/// If the input iterator yields $\tilde{\theta}\_i$, this returns
|
||||
|
||||
532
tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs
Normal file
532
tfhe/src/core_crypto/hpu/algorithms/lwe_keyswitch.rs
Normal file
@@ -0,0 +1,532 @@
|
||||
//! Module containing primitives pertaining to [`LWE ciphertext
|
||||
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
use rayon::prelude::*;
|
||||
|
||||
/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) encrypted under an
|
||||
/// [`LWE secret key`](`LweSecretKey`) to another [`LWE secret key`](`LweSecretKey`).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the modulus of the inputs are not power of twos.
|
||||
/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
|
||||
/// modulus.
|
||||
///
|
||||
/// # Formal Definition
|
||||
///
|
||||
/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweKeyswitchKey creation
|
||||
/// let input_lwe_dimension = LweDimension(742);
|
||||
/// let lwe_noise_distribution =
|
||||
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
|
||||
/// let output_lwe_dimension = LweDimension(2048);
|
||||
/// let decomp_base_log = DecompositionBaseLog(3);
|
||||
/// let decomp_level_count = DecompositionLevelCount(5);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let input_lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
|
||||
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
/// output_lwe_dimension,
|
||||
/// &mut secret_generator,
|
||||
/// );
|
||||
///
|
||||
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
/// &input_lwe_secret_key,
|
||||
/// &output_lwe_secret_key,
|
||||
/// decomp_base_log,
|
||||
/// decomp_level_count,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// // Create the plaintext
|
||||
/// let msg = 3u64;
|
||||
/// let plaintext = Plaintext(msg << 60);
|
||||
///
|
||||
/// // Create a new LweCiphertext
|
||||
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
/// &input_lwe_secret_key,
|
||||
/// plaintext,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut output_lwe = LweCiphertext::new(
|
||||
/// 0,
|
||||
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
/// ciphertext_modulus,
|
||||
/// );
|
||||
///
|
||||
/// hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
|
||||
///
|
||||
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
|
||||
///
|
||||
/// // Round and remove encoding
|
||||
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
|
||||
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
|
||||
///
|
||||
/// // Remove the encoding
|
||||
/// let cleartext = rounded >> 60;
|
||||
///
|
||||
/// // Check we recovered the original message
|
||||
/// assert_eq!(cleartext, msg);
|
||||
/// ```
|
||||
pub fn hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
|
||||
input_lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension()
|
||||
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched input LweDimension. \
|
||||
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension()
|
||||
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched output LweDimension. \
|
||||
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
|
||||
let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_ciphertext_modulus,
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_ciphertext_modulus
|
||||
);
|
||||
assert!(
|
||||
output_ciphertext_modulus.is_compatible_with_native_modulus(),
|
||||
"This operation currently only supports power of 2 moduli"
|
||||
);
|
||||
|
||||
let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(
|
||||
input_ciphertext_modulus.is_compatible_with_native_modulus(),
|
||||
"This operation currently only supports power of 2 moduli"
|
||||
);
|
||||
|
||||
// Clear the output ciphertext, as it will get updated gradually
|
||||
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
// Copy the input body to the output ciphertext
|
||||
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
|
||||
|
||||
// If the moduli are not the same, we need to round the body in the output ciphertext
|
||||
if output_ciphertext_modulus != input_ciphertext_modulus
|
||||
&& !output_ciphertext_modulus.is_native_modulus()
|
||||
{
|
||||
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
|
||||
let output_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(modulus_bits),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
|
||||
*output_lwe_ciphertext.get_mut_body().data =
|
||||
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
|
||||
}
|
||||
|
||||
// We instantiate a decomposer
|
||||
let decomposer = SignedDecomposer::new(
|
||||
lwe_keyswitch_key.decomposition_base_log(),
|
||||
lwe_keyswitch_key.decomposition_level_count(),
|
||||
);
|
||||
|
||||
for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
|
||||
.iter()
|
||||
.zip(input_lwe_ciphertext.get_mask().as_ref())
|
||||
{
|
||||
let decomposition_iter = decomposer.decompose_raw(input_mask_element);
|
||||
// Loop over the levels
|
||||
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
|
||||
{
|
||||
slice_wrapping_sub_scalar_mul_assign(
|
||||
output_lwe_ciphertext.as_mut(),
|
||||
level_key_ciphertext.as_ref(),
|
||||
decomposed.value(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`].
|
||||
///
|
||||
/// This will use all threads available in the current rayon thread pool.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweKeyswitchKey creation
|
||||
/// let input_lwe_dimension = LweDimension(742);
|
||||
/// let lwe_noise_distribution =
|
||||
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
|
||||
/// let output_lwe_dimension = LweDimension(2048);
|
||||
/// let decomp_base_log = DecompositionBaseLog(3);
|
||||
/// let decomp_level_count = DecompositionLevelCount(5);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let input_lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
|
||||
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
/// output_lwe_dimension,
|
||||
/// &mut secret_generator,
|
||||
/// );
|
||||
///
|
||||
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
/// &input_lwe_secret_key,
|
||||
/// &output_lwe_secret_key,
|
||||
/// decomp_base_log,
|
||||
/// decomp_level_count,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// // Create the plaintext
|
||||
/// let msg = 3u64;
|
||||
/// let plaintext = Plaintext(msg << 60);
|
||||
///
|
||||
/// // Create a new LweCiphertext
|
||||
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
/// &input_lwe_secret_key,
|
||||
/// plaintext,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut output_lwe = LweCiphertext::new(
|
||||
/// 0,
|
||||
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
/// ciphertext_modulus,
|
||||
/// );
|
||||
///
|
||||
/// // Use all threads available in the current rayon thread pool
|
||||
/// par_hpu_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
|
||||
///
|
||||
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
|
||||
///
|
||||
/// // Round and remove encoding
|
||||
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
|
||||
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
|
||||
///
|
||||
/// // Remove the encoding
|
||||
/// let cleartext = rounded >> 60;
|
||||
///
|
||||
/// // Check we recovered the original message
|
||||
/// assert_eq!(cleartext, msg);
|
||||
/// ```
|
||||
pub fn par_hpu_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
|
||||
input_lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger + Send + Sync,
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let thread_count = ThreadCount(rayon::current_num_threads());
|
||||
par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
|
||||
lwe_keyswitch_key,
|
||||
input_lwe_ciphertext,
|
||||
output_lwe_ciphertext,
|
||||
thread_count,
|
||||
);
|
||||
}
|
||||
|
||||
/// Parallel variant of [`hpu_keyswitch_lwe_ciphertext`].
|
||||
///
|
||||
/// This will try to use `thread_count` threads for the computation, if this number is bigger than
|
||||
/// the available number of threads in the current rayon thread pool then only the number of
|
||||
/// available threads will be used. Note that `thread_count` cannot be 0.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the modulus of the inputs are not power of twos.
|
||||
/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
|
||||
/// modulus.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::prelude::*;
|
||||
///
|
||||
/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
|
||||
/// // computations
|
||||
/// // Define parameters for LweKeyswitchKey creation
|
||||
/// let input_lwe_dimension = LweDimension(742);
|
||||
/// let lwe_noise_distribution =
|
||||
/// Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
|
||||
/// let output_lwe_dimension = LweDimension(2048);
|
||||
/// let decomp_base_log = DecompositionBaseLog(3);
|
||||
/// let decomp_level_count = DecompositionLevelCount(5);
|
||||
/// let ciphertext_modulus = CiphertextModulus::new_native();
|
||||
///
|
||||
/// // Create the PRNG
|
||||
/// let mut seeder = new_seeder();
|
||||
/// let seeder = seeder.as_mut();
|
||||
/// let mut encryption_generator =
|
||||
/// EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
|
||||
/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
|
||||
///
|
||||
/// // Create the LweSecretKey
|
||||
/// let input_lwe_secret_key =
|
||||
/// allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
|
||||
/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
/// output_lwe_dimension,
|
||||
/// &mut secret_generator,
|
||||
/// );
|
||||
///
|
||||
/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
|
||||
/// &input_lwe_secret_key,
|
||||
/// &output_lwe_secret_key,
|
||||
/// decomp_base_log,
|
||||
/// decomp_level_count,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// // Create the plaintext
|
||||
/// let msg = 3u64;
|
||||
/// let plaintext = Plaintext(msg << 60);
|
||||
///
|
||||
/// // Create a new LweCiphertext
|
||||
/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
/// &input_lwe_secret_key,
|
||||
/// plaintext,
|
||||
/// lwe_noise_distribution,
|
||||
/// ciphertext_modulus,
|
||||
/// &mut encryption_generator,
|
||||
/// );
|
||||
///
|
||||
/// let mut output_lwe = LweCiphertext::new(
|
||||
/// 0,
|
||||
/// output_lwe_secret_key.lwe_dimension().to_lwe_size(),
|
||||
/// ciphertext_modulus,
|
||||
/// );
|
||||
///
|
||||
/// // Try to use 4 threads for the keyswitch if enough are available
|
||||
/// // in the current rayon thread pool
|
||||
/// par_hpu_keyswitch_lwe_ciphertext_with_thread_count(
|
||||
/// &ksk,
|
||||
/// &input_lwe,
|
||||
/// &mut output_lwe,
|
||||
/// ThreadCount(4),
|
||||
/// );
|
||||
///
|
||||
/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
|
||||
///
|
||||
/// // Round and remove encoding
|
||||
/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
|
||||
/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
|
||||
///
|
||||
/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
|
||||
///
|
||||
/// // Remove the encoding
|
||||
/// let cleartext = rounded >> 60;
|
||||
///
|
||||
/// // Check we recovered the original message
|
||||
/// assert_eq!(cleartext, msg);
|
||||
/// ```
|
||||
pub fn par_hpu_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
|
||||
input_lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
|
||||
thread_count: ThreadCount,
|
||||
) where
|
||||
Scalar: UnsignedInteger + Send + Sync,
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension()
|
||||
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched input LweDimension. \
|
||||
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension()
|
||||
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched output LweDimension. \
|
||||
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
|
||||
let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_ciphertext_modulus,
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_ciphertext_modulus
|
||||
);
|
||||
assert!(
|
||||
output_ciphertext_modulus.is_compatible_with_native_modulus(),
|
||||
"This operation currently only supports power of 2 moduli"
|
||||
);
|
||||
|
||||
let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(
|
||||
input_ciphertext_modulus.is_compatible_with_native_modulus(),
|
||||
"This operation currently only supports power of 2 moduli"
|
||||
);
|
||||
|
||||
assert!(
|
||||
thread_count.0 != 0,
|
||||
"Got thread_count == 0, this is not supported"
|
||||
);
|
||||
|
||||
// Clear the output ciphertext, as it will get updated gradually
|
||||
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let output_lwe_size = output_lwe_ciphertext.lwe_size();
|
||||
|
||||
// Copy the input body to the output ciphertext
|
||||
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
|
||||
|
||||
// If the moduli are not the same, we need to round the body in the output ciphertext
|
||||
if output_ciphertext_modulus != input_ciphertext_modulus
|
||||
&& !output_ciphertext_modulus.is_native_modulus()
|
||||
{
|
||||
let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
|
||||
let output_decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(modulus_bits),
|
||||
DecompositionLevelCount(1),
|
||||
);
|
||||
|
||||
*output_lwe_ciphertext.get_mut_body().data =
|
||||
output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
|
||||
}
|
||||
|
||||
// We instantiate a decomposer
|
||||
let decomposer = SignedDecomposer::new(
|
||||
lwe_keyswitch_key.decomposition_base_log(),
|
||||
lwe_keyswitch_key.decomposition_level_count(),
|
||||
);
|
||||
|
||||
// Don't go above the current number of threads
|
||||
let thread_count = thread_count.0.min(rayon::current_num_threads());
|
||||
let mut intermediate_accumulators = Vec::with_capacity(thread_count);
|
||||
|
||||
// Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size
|
||||
let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
|
||||
|
||||
lwe_keyswitch_key
|
||||
.par_chunks(chunk_size)
|
||||
.zip(
|
||||
input_lwe_ciphertext
|
||||
.get_mask()
|
||||
.as_ref()
|
||||
.par_chunks(chunk_size),
|
||||
)
|
||||
.map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
|
||||
let mut buffer =
|
||||
LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus);
|
||||
|
||||
for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
|
||||
.iter()
|
||||
.zip(input_mask_element_chunk.iter())
|
||||
{
|
||||
let decomposition_iter = decomposer.decompose_raw(input_mask_element);
|
||||
// Loop over the levels
|
||||
for (level_key_ciphertext, decomposed) in
|
||||
keyswitch_key_block.iter().zip(decomposition_iter)
|
||||
{
|
||||
slice_wrapping_sub_scalar_mul_assign(
|
||||
buffer.as_mut(),
|
||||
level_key_ciphertext.as_ref(),
|
||||
decomposed.value(),
|
||||
);
|
||||
}
|
||||
}
|
||||
buffer
|
||||
})
|
||||
.collect_into_vec(&mut intermediate_accumulators);
|
||||
|
||||
let reduced = intermediate_accumulators
|
||||
.par_iter_mut()
|
||||
.reduce_with(|lhs, rhs| {
|
||||
lhs.as_mut()
|
||||
.iter_mut()
|
||||
.zip(rhs.as_ref().iter())
|
||||
.for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src));
|
||||
|
||||
lhs
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
output_lwe_ciphertext
|
||||
.get_mut_mask()
|
||||
.as_mut()
|
||||
.copy_from_slice(reduced.get_mask().as_ref());
|
||||
let reduced_ksed_body = *reduced.get_body().data;
|
||||
|
||||
// Add the reduced body of the keyswitch to the output body to complete the keyswitch
|
||||
*output_lwe_ciphertext.get_mut_body().data =
|
||||
(*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
|
||||
}
|
||||
@@ -1,2 +1,3 @@
|
||||
pub mod lwe_keyswitch;
|
||||
pub mod modswitch;
|
||||
pub mod order;
|
||||
|
||||
Reference in New Issue
Block a user