refactor(core): factorize multiplicative factor code for GGSW encryption

- some code was repeated several times, factorize it out in a function
This commit is contained in:
Arthur Meyre
2024-05-27 11:22:07 +02:00
parent 8a31abfca4
commit dc0d72436d

View File

@@ -4,17 +4,45 @@
use crate::core_crypto::algorithms::misc::divide_round;
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::ciphertext_modulus::{CiphertextModulus, CiphertextModulusKind};
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::math::decomposition::{
DecompositionLevel, DecompositionTermNonNative, SignedDecomposer,
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative, SignedDecomposer,
};
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, Distribution, Uniform};
use crate::core_crypto::commons::parameters::PlaintextCount;
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, PlaintextCount};
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
use rayon::prelude::*;
/// Compute the multiplicative factor for a GGSW encryption based on an input value and GGSW
/// encryption parameters.
pub fn ggsw_encryption_multiplicative_factor<Scalar: UnsignedInteger>(
ciphertext_modulus: CiphertextModulus<Scalar>,
decomp_level: DecompositionLevel,
decomp_base_log: DecompositionBaseLog,
encoded: Plaintext<Scalar>,
) -> Scalar {
match ciphertext_modulus.kind() {
CiphertextModulusKind::Other => DecompositionTermNonNative::new(
decomp_level,
decomp_base_log,
encoded.0.wrapping_neg(),
ciphertext_modulus,
)
.to_approximate_recomposition_summand(),
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
let native_decomp_term =
DecompositionTerm::new(decomp_level, decomp_base_log, encoded.0.wrapping_neg())
.to_recomposition_summand();
// We scale the factor down from the native torus to whatever our power of 2 torus is,
// the encryption process will scale it back up
native_decomp_term
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
}
}
/// Encrypt a plaintext in a [`GGSW ciphertext`](`GgswCiphertext`) in the constant coefficient.
///
/// See the [`GGSW ciphertext formal definition`](`GgswCiphertext#ggsw-encryption`) for the
@@ -114,29 +142,12 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont, Outp
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
let factor = match ciphertext_modulus.kind() {
CiphertextModulusKind::Other => DecompositionTermNonNative::new(
decomp_level,
decomp_base_log,
encoded.0.wrapping_neg(),
ciphertext_modulus,
)
.to_approximate_recomposition_summand(),
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
{
encoded
.0
.wrapping_neg()
.wrapping_mul(
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
)
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
};
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
decomp_base_log,
encoded,
);
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = generator
@@ -263,29 +274,12 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, NoiseDistribution, KeyCont,
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
let factor = match ciphertext_modulus.kind() {
CiphertextModulusKind::Other => DecompositionTermNonNative::new(
decomp_level,
decomp_base_log,
encoded.0.wrapping_neg(),
ciphertext_modulus,
)
.to_approximate_recomposition_summand(),
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
{
encoded
.0
.wrapping_neg()
.wrapping_mul(
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
)
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
};
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
decomp_base_log,
encoded,
);
// We iterate over the rows of the level matrix, the last row needs special
// treatment
@@ -414,29 +408,12 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
output.iter_mut().zip(gen_iter).enumerate()
{
let decomp_level = DecompositionLevel(level_index + 1);
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
let factor = match ciphertext_modulus.kind() {
CiphertextModulusKind::Other => DecompositionTermNonNative::new(
decomp_level,
decomp_base_log,
encoded.0.wrapping_neg(),
ciphertext_modulus,
)
.to_approximate_recomposition_summand(),
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
{
encoded
.0
.wrapping_neg()
.wrapping_mul(
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
)
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
};
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
decomp_base_log,
encoded,
);
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = loop_generator
@@ -608,29 +585,12 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|(level_index, (mut level_matrix, mut generator))| {
let decomp_level = DecompositionLevel(level_index + 1);
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
let factor = match ciphertext_modulus.kind() {
CiphertextModulusKind::Other => DecompositionTermNonNative::new(
decomp_level,
decomp_base_log,
encoded.0.wrapping_neg(),
ciphertext_modulus,
)
.to_approximate_recomposition_summand(),
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
// We scale the factor down from the native torus to whatever our torus is, the
// encryption process will scale it back up
{
encoded
.0
.wrapping_neg()
.wrapping_mul(
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
)
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
};
let factor = ggsw_encryption_multiplicative_factor(
ciphertext_modulus,
decomp_level,
decomp_base_log,
encoded,
);
// We iterate over the rows of the level matrix, the last row needs special treatment
let gen_iter = generator