mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
Compare commits
7 Commits
tfhe-test-
...
cb/feat/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5522a89631 | ||
|
|
b7b8fbf657 | ||
|
|
0b81a2db58 | ||
|
|
dd7489692f | ||
|
|
ef55a9e076 | ||
|
|
c6d1f0189c | ||
|
|
48d92b8e69 |
@@ -3,9 +3,12 @@
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, SignedDecomposer};
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
DecompositionLevel, DecompositionTermNonNative, SignedDecomposer, SignedDecomposerNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
use crate::core_crypto::commons::parameters::PlaintextCount;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
@@ -115,13 +118,29 @@ pub fn encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
output.iter_mut().zip(gen_iter).enumerate()
|
||||
{
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
let factor = encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
let factor = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => DecompositionTermNonNative::new(
|
||||
decomp_level,
|
||||
decomp_base_log,
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg_custom_mod(ciphertext_modulus.get_custom_modulus().cast_into()),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand(),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
{
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(
|
||||
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
|
||||
)
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
|
||||
}
|
||||
};
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = generator
|
||||
@@ -252,13 +271,29 @@ pub fn par_encrypt_constant_ggsw_ciphertext<Scalar, KeyCont, OutputCont, Gen>(
|
||||
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|
||||
|(level_index, (mut level_matrix, mut generator))| {
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
let factor = encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
let factor = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => DecompositionTermNonNative::new(
|
||||
decomp_level,
|
||||
decomp_base_log,
|
||||
encoded.0.wrapping_neg_custom_mod(
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand(),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
{
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(
|
||||
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
|
||||
)
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
|
||||
}
|
||||
};
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special
|
||||
// treatment
|
||||
@@ -315,15 +350,35 @@ fn encrypt_constant_ggsw_level_matrix_row<Scalar, KeyCont, OutputCont, Gen>(
|
||||
let mut body = row_as_glwe.get_mut_body();
|
||||
body.as_mut().copy_from_slice(sk_poly.as_ref());
|
||||
|
||||
slice_wrapping_scalar_mul_assign(body.as_mut(), factor);
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => slice_wrapping_scalar_mul_assign_custom_mod(
|
||||
body.as_mut(),
|
||||
factor,
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
slice_wrapping_scalar_mul_assign(body.as_mut(), factor)
|
||||
}
|
||||
}
|
||||
|
||||
encrypt_glwe_ciphertext_assign(glwe_secret_key, row_as_glwe, noise_parameters, generator);
|
||||
} else {
|
||||
// The last row needs a slightly different treatment
|
||||
let mut body = row_as_glwe.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
body.as_mut().fill(Scalar::ZERO);
|
||||
body.as_mut()[0] = factor.wrapping_neg();
|
||||
let encoded = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => {
|
||||
factor.wrapping_neg_custom_mod(ciphertext_modulus.get_custom_modulus().cast_into())
|
||||
}
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
factor.wrapping_neg()
|
||||
}
|
||||
};
|
||||
body.as_mut()[0] = encoded;
|
||||
|
||||
encrypt_glwe_ciphertext_assign(glwe_secret_key, row_as_glwe, noise_parameters, generator);
|
||||
}
|
||||
@@ -371,13 +426,29 @@ pub fn encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
output.iter_mut().zip(gen_iter).enumerate()
|
||||
{
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
let factor = encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
let factor = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => DecompositionTermNonNative::new(
|
||||
decomp_level,
|
||||
decomp_base_log,
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg_custom_mod(ciphertext_modulus.get_custom_modulus().cast_into()),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand(),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
{
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(
|
||||
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
|
||||
)
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
|
||||
}
|
||||
};
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = loop_generator
|
||||
@@ -545,13 +616,29 @@ pub fn par_encrypt_constant_seeded_ggsw_ciphertext_with_existing_generator<
|
||||
output.par_iter_mut().zip(gen_iter).enumerate().for_each(
|
||||
|(level_index, (mut level_matrix, mut generator))| {
|
||||
let decomp_level = DecompositionLevel(level_index + 1);
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
let factor = encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
let factor = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => DecompositionTermNonNative::new(
|
||||
decomp_level,
|
||||
decomp_base_log,
|
||||
encoded.0.wrapping_neg_custom_mod(
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand(),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo =>
|
||||
// We scale the factor down from the native torus to whatever our torus is, the
|
||||
// encryption process will scale it back up
|
||||
{
|
||||
encoded
|
||||
.0
|
||||
.wrapping_neg()
|
||||
.wrapping_mul(
|
||||
Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)),
|
||||
)
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
|
||||
}
|
||||
};
|
||||
|
||||
// We iterate over the rows of the level matrix, the last row needs special treatment
|
||||
let gen_iter = generator
|
||||
@@ -708,7 +795,18 @@ fn encrypt_constant_seeded_ggsw_level_matrix_row<Scalar, KeyCont, OutputCont, Ge
|
||||
let mut body = row_as_glwe.get_mut_body();
|
||||
body.as_mut().copy_from_slice(sk_poly.as_ref());
|
||||
|
||||
slice_wrapping_scalar_mul_assign(body.as_mut(), factor);
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => slice_wrapping_scalar_mul_assign_custom_mod(
|
||||
body.as_mut(),
|
||||
factor,
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
slice_wrapping_scalar_mul_assign(body.as_mut(), factor)
|
||||
}
|
||||
}
|
||||
|
||||
encrypt_seeded_glwe_ciphertext_assign_with_existing_generator(
|
||||
glwe_secret_key,
|
||||
@@ -719,9 +817,18 @@ fn encrypt_constant_seeded_ggsw_level_matrix_row<Scalar, KeyCont, OutputCont, Ge
|
||||
} else {
|
||||
// The last row needs a slightly different treatment
|
||||
let mut body = row_as_glwe.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
body.as_mut().fill(Scalar::ZERO);
|
||||
body.as_mut()[0] = factor.wrapping_neg();
|
||||
let encoded = match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => {
|
||||
factor.wrapping_neg_custom_mod(ciphertext_modulus.get_custom_modulus().cast_into())
|
||||
}
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
factor.wrapping_neg()
|
||||
}
|
||||
};
|
||||
body.as_mut()[0] = encoded;
|
||||
|
||||
encrypt_seeded_glwe_ciphertext_assign_with_existing_generator(
|
||||
glwe_secret_key,
|
||||
@@ -828,20 +935,40 @@ where
|
||||
|
||||
let decomp_base_log = ggsw_ciphertext.decomposition_base_log();
|
||||
|
||||
let decomposer = SignedDecomposer::new(decomp_base_log, decomp_level);
|
||||
|
||||
let plaintext_ref = decrypted_plaintext_list.get(0);
|
||||
|
||||
// Glwe decryption maps to a smaller torus potentially, map back to the native torus
|
||||
let rounded = decomposer.closest_representable(
|
||||
(*plaintext_ref.0).wrapping_mul(
|
||||
ggsw_ciphertext
|
||||
.ciphertext_modulus()
|
||||
.get_scaling_to_native_torus(),
|
||||
),
|
||||
);
|
||||
let decoded =
|
||||
rounded.wrapping_div(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));
|
||||
let ciphertext_modulus = ggsw_ciphertext.ciphertext_modulus();
|
||||
|
||||
Plaintext(decoded)
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::NonNative => {
|
||||
let decomposer =
|
||||
SignedDecomposerNonNative::new(decomp_base_log, decomp_level, ciphertext_modulus);
|
||||
|
||||
let rounded = decomposer.closest_representable(*plaintext_ref.0);
|
||||
let delta = DecompositionTermNonNative::new(
|
||||
DecompositionLevel(decomp_level.0),
|
||||
decomp_base_log,
|
||||
Scalar::ONE,
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand();
|
||||
|
||||
let decoded = rounded.wrapping_div(delta);
|
||||
|
||||
Plaintext(decoded)
|
||||
}
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
let decomposer = SignedDecomposer::new(decomp_base_log, decomp_level);
|
||||
|
||||
// Glwe decryption maps to a smaller torus potentially, map back to the native torus
|
||||
let rounded = decomposer.closest_representable(
|
||||
(*plaintext_ref.0)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
let decoded = rounded
|
||||
.wrapping_div(Scalar::ONE << (Scalar::BITS - (decomp_base_log.0 * decomp_level.0)));
|
||||
|
||||
Plaintext(decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ use crate::core_crypto::algorithms::polynomial_algorithms::*;
|
||||
use crate::core_crypto::algorithms::slice_algorithms::{
|
||||
slice_wrapping_scalar_div_assign, slice_wrapping_scalar_mul_assign,
|
||||
};
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
@@ -26,6 +27,46 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
fill_glwe_mask_and_body_for_encryption_assign_native_mod_compatible(
|
||||
glwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
} else {
|
||||
fill_glwe_mask_and_body_for_encryption_assign_non_native_mod(
|
||||
glwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_glwe_mask_and_body_for_encryption_assign_native_mod_compatible<
|
||||
KeyCont,
|
||||
BodyCont,
|
||||
MaskCont,
|
||||
Scalar,
|
||||
Gen,
|
||||
>(
|
||||
glwe_secret_key: &GlweSecretKey<KeyCont>,
|
||||
output_mask: &mut GlweMask<MaskCont>,
|
||||
output_body: &mut GlweBody<BodyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
@@ -37,6 +78,8 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
|
||||
output_body.as_mut(),
|
||||
@@ -44,8 +87,9 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
|
||||
}
|
||||
@@ -57,6 +101,52 @@ pub fn fill_glwe_mask_and_body_for_encryption_assign<KeyCont, BodyCont, MaskCont
|
||||
);
|
||||
}
|
||||
|
||||
pub fn fill_glwe_mask_and_body_for_encryption_assign_non_native_mod<
|
||||
KeyCont,
|
||||
BodyCont,
|
||||
MaskCont,
|
||||
Scalar,
|
||||
Gen,
|
||||
>(
|
||||
glwe_secret_key: &GlweSecretKey<KeyCont>,
|
||||
output_mask: &mut GlweMask<MaskCont>,
|
||||
output_body: &mut GlweBody<BodyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus(),
|
||||
"Mismatched moduli between output_mask ({:?}) and output_body ({:?})",
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.unsigned_torus_slice_wrapping_add_random_noise_custom_mod_assign(
|
||||
output_body.as_mut(),
|
||||
noise_parameters,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
polynomial_wrapping_add_multisum_assign_custom_mod(
|
||||
&mut output_body.as_mut_polynomial(),
|
||||
&output_mask.as_polynomial_list(),
|
||||
&glwe_secret_key.as_polynomial_list(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Variant of [`encrypt_glwe_ciphertext`] which assumes that the plaintexts to encrypt are already
|
||||
/// loaded in the body of the output [`GLWE ciphertext`](`GlweCiphertext`), this is sometimes useful
|
||||
/// to avoid allocating a [`PlaintextList`] in situ.
|
||||
@@ -242,6 +332,51 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
fill_glwe_mask_and_body_for_encryption_native_mod_compatible(
|
||||
glwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
} else {
|
||||
fill_glwe_mask_and_body_for_encryption_non_ative_mod(
|
||||
glwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_glwe_mask_and_body_for_encryption_native_mod_compatible<
|
||||
KeyCont,
|
||||
InputCont,
|
||||
BodyCont,
|
||||
MaskCont,
|
||||
Scalar,
|
||||
Gen,
|
||||
>(
|
||||
glwe_secret_key: &GlweSecretKey<KeyCont>,
|
||||
output_mask: &mut GlweMask<MaskCont>,
|
||||
output_body: &mut GlweBody<BodyCont>,
|
||||
encoded: &PlaintextList<InputCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
@@ -250,6 +385,8 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.fill_slice_with_random_noise_custom_mod(
|
||||
output_body.as_mut(),
|
||||
@@ -262,8 +399,9 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
&encoded.as_polynomial(),
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
slice_wrapping_scalar_mul_assign(output_body.as_mut(), torus_scaling);
|
||||
}
|
||||
@@ -275,6 +413,60 @@ pub fn fill_glwe_mask_and_body_for_encryption<KeyCont, InputCont, BodyCont, Mask
|
||||
);
|
||||
}
|
||||
|
||||
pub fn fill_glwe_mask_and_body_for_encryption_non_ative_mod<
|
||||
KeyCont,
|
||||
InputCont,
|
||||
BodyCont,
|
||||
MaskCont,
|
||||
Scalar,
|
||||
Gen,
|
||||
>(
|
||||
glwe_secret_key: &GlweSecretKey<KeyCont>,
|
||||
output_mask: &mut GlweMask<MaskCont>,
|
||||
output_body: &mut GlweBody<BodyCont>,
|
||||
encoded: &PlaintextList<InputCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
BodyCont: ContainerMut<Element = Scalar>,
|
||||
MaskCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
generator.fill_slice_with_random_noise_custom_mod(
|
||||
output_body.as_mut(),
|
||||
noise_parameters,
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let ciphertext_modulus = ciphertext_modulus.get_custom_modulus().cast_into();
|
||||
|
||||
polynomial_wrapping_add_assign_custom_mod(
|
||||
&mut output_body.as_mut_polynomial(),
|
||||
&encoded.as_polynomial(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
polynomial_wrapping_add_multisum_assign_custom_mod(
|
||||
&mut output_body.as_mut_polynomial(),
|
||||
&output_mask.as_polynomial_list(),
|
||||
&glwe_secret_key.as_polynomial_list(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
}
|
||||
|
||||
/// Encrypt a (scalar) plaintext list in a [`GLWE ciphertext`](`GlweCiphertext`).
|
||||
///
|
||||
/// # Formal Definition
|
||||
@@ -576,16 +768,30 @@ pub fn decrypt_glwe_ciphertext<Scalar, KeyCont, InputCont, OutputCont>(
|
||||
output_plaintext_list
|
||||
.as_mut()
|
||||
.copy_from_slice(body.as_ref());
|
||||
polynomial_wrapping_sub_multisum_assign(
|
||||
&mut output_plaintext_list.as_mut_polynomial(),
|
||||
&mask.as_polynomial_list(),
|
||||
&glwe_secret_key.as_polynomial_list(),
|
||||
);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let ciphertext_modulus_kind = ciphertext_modulus.kind();
|
||||
|
||||
match ciphertext_modulus_kind {
|
||||
CiphertextModulusKind::NonNative => polynomial_wrapping_sub_multisum_assign_custom_mod(
|
||||
&mut output_plaintext_list.as_mut_polynomial(),
|
||||
&mask.as_polynomial_list(),
|
||||
&glwe_secret_key.as_polynomial_list(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
CiphertextModulusKind::Native | CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
polynomial_wrapping_sub_multisum_assign(
|
||||
&mut output_plaintext_list.as_mut_polynomial(),
|
||||
&mask.as_polynomial_list(),
|
||||
&glwe_secret_key.as_polynomial_list(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus_kind {
|
||||
slice_wrapping_scalar_div_assign(
|
||||
output_plaintext_list.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -720,10 +926,11 @@ pub fn trivially_encrypt_glwe_ciphertext<Scalar, InputCont, OutputCont>(
|
||||
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
body.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -809,10 +1016,11 @@ where
|
||||
let mut body = new_ct.get_mut_body();
|
||||
body.as_mut().copy_from_slice(encoded.as_ref());
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
body.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator};
|
||||
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, RandomGenerator};
|
||||
@@ -36,6 +37,57 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
fill_lwe_mask_and_body_for_encryption_native_mod_compatible(
|
||||
lwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
} else {
|
||||
fill_lwe_mask_and_body_for_encryption_non_native_mod(
|
||||
lwe_secret_key,
|
||||
output_mask,
|
||||
output_body,
|
||||
encoded,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn fill_lwe_mask_and_body_for_encryption_native_mod_compatible<
|
||||
Scalar,
|
||||
KeyCont,
|
||||
OutputCont,
|
||||
Gen,
|
||||
>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
output_mask: &mut LweMask<OutputCont>,
|
||||
output_body: LweBodyRefMut<Scalar>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus(),
|
||||
"Mismatched moduli between mask ({:?}) and body ({:?})",
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
|
||||
// generate an error from the normal distribution described by std_dev
|
||||
@@ -43,7 +95,7 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
*output_body.data = (*output_body.data).wrapping_add(encoded.0);
|
||||
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
let torus_scaling = ciphertext_modulus.get_scaling_to_native_torus();
|
||||
let torus_scaling = ciphertext_modulus.get_power_of_two_scaling_to_native_torus();
|
||||
slice_wrapping_scalar_mul_assign(output_mask.as_mut(), torus_scaling);
|
||||
*output_body.data = (*output_body.data).wrapping_mul(torus_scaling);
|
||||
}
|
||||
@@ -55,6 +107,51 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
|
||||
));
|
||||
}
|
||||
|
||||
pub fn fill_lwe_mask_and_body_for_encryption_non_native_mod<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
output_mask: &mut LweMask<OutputCont>,
|
||||
output_body: LweBodyRefMut<Scalar>,
|
||||
encoded: Plaintext<Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert_eq!(
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus(),
|
||||
"Mismatched moduli between mask ({:?}) and body ({:?})",
|
||||
output_mask.ciphertext_modulus(),
|
||||
output_body.ciphertext_modulus()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output_mask.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
|
||||
// generate an error from the normal distribution described by std_dev
|
||||
*output_body.data = generator.random_noise_custom_mod(noise_parameters, ciphertext_modulus);
|
||||
*output_body.data = (*output_body.data).wrapping_add_custom_mod(
|
||||
encoded.0,
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
|
||||
// compute the multisum between the secret key and the mask
|
||||
*output_body.data = (*output_body.data).wrapping_add_custom_mod(
|
||||
slice_wrapping_dot_product_custom_mod(
|
||||
output_mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`).
|
||||
///
|
||||
/// See the [`LWE ciphertext formal definition`](`LweCiphertext#lwe-encryption`) for the definition
|
||||
@@ -296,9 +393,10 @@ pub fn trivially_encrypt_lwe_ciphertext<Scalar, OutputCont>(
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data =
|
||||
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
// Manage non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,9 +472,10 @@ where
|
||||
*output_body.data = encoded.0;
|
||||
|
||||
let ciphertext_modulus = output_body.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
*output_body.data =
|
||||
(*output_body.data).wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
*output_body.data = (*output_body.data)
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
new_ct
|
||||
@@ -394,6 +493,24 @@ pub fn decrypt_lwe_ciphertext<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
decrypt_lwe_ciphertext_native_mod_compatible(lwe_secret_key, lwe_ciphertext)
|
||||
} else {
|
||||
decrypt_lwe_ciphertext_non_native_mod(lwe_secret_key, lwe_ciphertext)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_lwe_ciphertext_native_mod_compatible<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
@@ -409,6 +526,8 @@ where
|
||||
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let (mask, body) = lwe_ciphertext.get_mask_and_body();
|
||||
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
@@ -423,11 +542,44 @@ where
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
))
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt_lwe_ciphertext_non_native_mod<Scalar, KeyCont, InputCont>(
|
||||
lwe_secret_key: &LweSecretKey<KeyCont>,
|
||||
lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
) -> Plaintext<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
KeyCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_ciphertext.lwe_size().to_lwe_dimension() == lwe_secret_key.lwe_dimension(),
|
||||
"Mismatch between LweDimension of output ciphertext and input secret key. \
|
||||
Got {:?} in output, and {:?} in secret key.",
|
||||
lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
lwe_secret_key.lwe_dimension()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
|
||||
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
let (mask, body) = lwe_ciphertext.get_mask_and_body();
|
||||
|
||||
Plaintext((*body.data).wrapping_sub_custom_mod(
|
||||
slice_wrapping_dot_product_custom_mod(
|
||||
mask.as_ref(),
|
||||
lwe_secret_key.as_ref(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
))
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext list in an output [`LWE ciphertext list`](`LweCiphertextList`).
|
||||
///
|
||||
/// See this [`formal definition`](`encrypt_lwe_ciphertext#formal-definition`) for the definition
|
||||
@@ -775,8 +927,6 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_public_key.lwe_size().to_lwe_dimension()
|
||||
);
|
||||
|
||||
let ciphertext_modulus = output.ciphertext_modulus();
|
||||
|
||||
output.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
let mut tmp_zero_encryption =
|
||||
@@ -798,17 +948,7 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
|
||||
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
|
||||
}
|
||||
|
||||
let body = output.get_mut_body();
|
||||
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(encoded.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
lwe_ciphertext_plaintext_add_assign(output, encoded);
|
||||
}
|
||||
|
||||
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using a
|
||||
@@ -923,10 +1063,10 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
let (mut mask, body) = tmp_zero_encryption.get_mut_mask_and_body();
|
||||
random_generator
|
||||
.fill_slice_with_random_uniform_custom_mod(mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
if ciphertext_modulus.is_non_native_power_of_two() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*body.data = *public_encryption_of_zero_body.data;
|
||||
@@ -937,17 +1077,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
|
||||
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
|
||||
}
|
||||
|
||||
// Add encoded plaintext
|
||||
let body = output.get_mut_body();
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(encoded.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
encoded
|
||||
.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
lwe_ciphertext_plaintext_add_assign(output, encoded);
|
||||
}
|
||||
|
||||
/// Convenience function to share the core logic of the seeded LWE encryption between all functions
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
SignedDecomposer, SignedDecomposerNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -99,6 +101,34 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
if lwe_keyswitch_key
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus()
|
||||
{
|
||||
keyswitch_lwe_ciphertext_native_mod_compatible(
|
||||
lwe_keyswitch_key,
|
||||
input_lwe_ciphertext,
|
||||
output_lwe_ciphertext,
|
||||
)
|
||||
} else {
|
||||
keyswitch_lwe_ciphertext_non_native_mod(
|
||||
lwe_keyswitch_key,
|
||||
input_lwe_ciphertext,
|
||||
output_lwe_ciphertext,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
|
||||
input_lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension()
|
||||
@@ -116,6 +146,25 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
input_lwe_ciphertext.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
input_lwe_ciphertext.ciphertext_modulus(),
|
||||
);
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_lwe_ciphertext.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_lwe_ciphertext.ciphertext_modulus(),
|
||||
);
|
||||
assert!(lwe_keyswitch_key
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus());
|
||||
|
||||
// Clear the output ciphertext, as it will get updated gradually
|
||||
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
|
||||
@@ -145,3 +194,79 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn keyswitch_lwe_ciphertext_non_native_mod<Scalar, KSKCont, InputCont, OutputCont>(
|
||||
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
|
||||
input_lwe_ciphertext: &LweCiphertext<InputCont>,
|
||||
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
KSKCont: Container<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension()
|
||||
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched input LweDimension. \
|
||||
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension()
|
||||
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
"Mismatched output LweDimension. \
|
||||
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
|
||||
);
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
input_lwe_ciphertext.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
input_lwe_ciphertext.ciphertext_modulus(),
|
||||
);
|
||||
assert_eq!(
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_lwe_ciphertext.ciphertext_modulus(),
|
||||
"Mismatched CiphertextModulus. \
|
||||
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
|
||||
lwe_keyswitch_key.ciphertext_modulus(),
|
||||
output_lwe_ciphertext.ciphertext_modulus(),
|
||||
);
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// Clear the output ciphertext, as it will get updated gradually
|
||||
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
|
||||
|
||||
// Copy the input body to the output ciphertext
|
||||
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
|
||||
|
||||
// We instantiate a decomposer
|
||||
let decomposer = SignedDecomposerNonNative::new(
|
||||
lwe_keyswitch_key.decomposition_base_log(),
|
||||
lwe_keyswitch_key.decomposition_level_count(),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
|
||||
.iter()
|
||||
.zip(input_lwe_ciphertext.get_mask().as_ref())
|
||||
{
|
||||
let decomposition_iter = decomposer.decompose(input_mask_element);
|
||||
// Loop over the levels
|
||||
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
|
||||
{
|
||||
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
|
||||
output_lwe_ciphertext.as_mut(),
|
||||
level_key_ciphertext.as_ref(),
|
||||
decomposed.value(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,9 @@
|
||||
use crate::core_crypto::algorithms::*;
|
||||
use crate::core_crypto::commons::dispersion::DispersionParameter;
|
||||
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm};
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
|
||||
use crate::core_crypto::commons::parameters::*;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
@@ -75,59 +77,217 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey input LweDimension is not equal \
|
||||
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey output LweDimension is not equal \
|
||||
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
|
||||
|
||||
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
|
||||
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(lwe_keyswitch_key.iter_mut())
|
||||
{
|
||||
// We fill the buffer with the powers of the key elmements
|
||||
for (level, message) in (1..=decomp_level_count.0)
|
||||
.rev()
|
||||
.map(DecompositionLevel)
|
||||
.zip(decomposition_plaintexts_buffer.iter_mut())
|
||||
{
|
||||
// Here we take the decomposition term from the native torus, bring it to the torus we
|
||||
// are working with by dividing by the scaling factor and the encryption will take care
|
||||
// of mapping that back to the native torus
|
||||
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
|
||||
.to_recomposition_summand()
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
encrypt_lwe_ciphertext_list(
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
generate_lwe_keyswitch_key_native_mod_compatible(
|
||||
input_lwe_sk,
|
||||
output_lwe_sk,
|
||||
&mut keyswitch_key_block,
|
||||
&decomposition_plaintexts_buffer,
|
||||
lwe_keyswitch_key,
|
||||
noise_parameters,
|
||||
generator,
|
||||
);
|
||||
)
|
||||
} else {
|
||||
generate_lwe_keyswitch_key_non_native_mod(
|
||||
input_lwe_sk,
|
||||
output_lwe_sk,
|
||||
lwe_keyswitch_key,
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_lwe_keyswitch_key_native_mod_compatible<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
KSKeyCont,
|
||||
Gen,
|
||||
>(
|
||||
input_lwe_sk: &LweSecretKey<InputKeyCont>,
|
||||
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
|
||||
lwe_keyswitch_key: &mut LweKeyswitchKey<KSKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
fn implementation<Scalar, Gen>(
|
||||
input_lwe_sk: LweSecretKeyView<'_, Scalar>,
|
||||
output_lwe_sk: LweSecretKeyView<'_, Scalar>,
|
||||
mut lwe_keyswitch_key: LweKeyswitchKeyMutView<'_, Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey input LweDimension is not equal \
|
||||
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey output LweDimension is not equal \
|
||||
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(lwe_keyswitch_key
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus());
|
||||
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
|
||||
|
||||
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
|
||||
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(lwe_keyswitch_key.iter_mut())
|
||||
{
|
||||
// We fill the buffer with the powers of the key elmements
|
||||
for (level, message) in (1..=decomp_level_count.0)
|
||||
.rev()
|
||||
.map(DecompositionLevel)
|
||||
.zip(decomposition_plaintexts_buffer.iter_mut())
|
||||
{
|
||||
// Here we take the decomposition term from the native torus, bring it to the torus
|
||||
// we are working with by dividing by the scaling factor and the
|
||||
// encryption will take care of mapping that back to the native
|
||||
// torus
|
||||
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
|
||||
.to_recomposition_summand()
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&output_lwe_sk,
|
||||
&mut keyswitch_key_block,
|
||||
&decomposition_plaintexts_buffer,
|
||||
noise_parameters,
|
||||
generator,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
implementation(
|
||||
input_lwe_sk.as_view(),
|
||||
output_lwe_sk.as_view(),
|
||||
lwe_keyswitch_key.as_mut_view(),
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn generate_lwe_keyswitch_key_non_native_mod<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
KSKeyCont,
|
||||
Gen,
|
||||
>(
|
||||
input_lwe_sk: &LweSecretKey<InputKeyCont>,
|
||||
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
|
||||
lwe_keyswitch_key: &mut LweKeyswitchKey<KSKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
fn implementation<Scalar, Gen>(
|
||||
input_lwe_sk: LweSecretKeyView<'_, Scalar>,
|
||||
output_lwe_sk: LweSecretKeyView<'_, Scalar>,
|
||||
mut lwe_keyswitch_key: LweKeyswitchKeyMutView<'_, Scalar>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
generator: &mut EncryptionRandomGenerator<Gen>,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
Gen: ByteRandomGenerator,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey input LweDimension is not equal \
|
||||
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
|
||||
"The destination LweKeyswitchKey output LweDimension is not equal \
|
||||
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(!lwe_keyswitch_key
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus());
|
||||
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
|
||||
|
||||
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
|
||||
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(lwe_keyswitch_key.iter_mut())
|
||||
{
|
||||
// We fill the buffer with the powers of the key elmements
|
||||
for (level, message) in (1..=decomp_level_count.0)
|
||||
.rev()
|
||||
.map(DecompositionLevel)
|
||||
.zip(decomposition_plaintexts_buffer.iter_mut())
|
||||
{
|
||||
*message.0 = DecompositionTermNonNative::new(
|
||||
level,
|
||||
decomp_base_log,
|
||||
*input_key_element,
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand();
|
||||
}
|
||||
|
||||
encrypt_lwe_ciphertext_list(
|
||||
&output_lwe_sk,
|
||||
&mut keyswitch_key_block,
|
||||
&decomposition_plaintexts_buffer,
|
||||
noise_parameters,
|
||||
generator,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
implementation(
|
||||
input_lwe_sk.as_view(),
|
||||
output_lwe_sk.as_view(),
|
||||
lwe_keyswitch_key.as_mut_view(),
|
||||
noise_parameters,
|
||||
generator,
|
||||
)
|
||||
}
|
||||
|
||||
/// Allocate a new [`LWE keyswitch key`](`LweKeyswitchKey`) and fill it with an actual keyswitching
|
||||
/// key constructed from an input and an output key [`LWE secret key`](`LweSecretKey`).
|
||||
///
|
||||
@@ -236,6 +396,46 @@ pub fn generate_seeded_lwe_keyswitch_key<
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
generate_seeded_lwe_keyswitch_key_native_mod_compatible(
|
||||
input_lwe_sk,
|
||||
output_lwe_sk,
|
||||
lwe_keyswitch_key,
|
||||
noise_parameters,
|
||||
noise_seeder,
|
||||
)
|
||||
} else {
|
||||
generate_seeded_lwe_keyswitch_key_non_native_mod(
|
||||
input_lwe_sk,
|
||||
output_lwe_sk,
|
||||
lwe_keyswitch_key,
|
||||
noise_parameters,
|
||||
noise_seeder,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_seeded_lwe_keyswitch_key_native_mod_compatible<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
KSKeyCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_sk: &LweSecretKey<InputKeyCont>,
|
||||
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
|
||||
lwe_keyswitch_key: &mut SeededLweKeyswitchKey<KSKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
|
||||
@@ -255,6 +455,7 @@ pub fn generate_seeded_lwe_keyswitch_key<
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
@@ -282,7 +483,87 @@ pub fn generate_seeded_lwe_keyswitch_key<
|
||||
// of mapping that back to the native torus
|
||||
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
|
||||
.to_recomposition_summand()
|
||||
.wrapping_div(ciphertext_modulus.get_scaling_to_native_torus());
|
||||
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
|
||||
}
|
||||
|
||||
encrypt_seeded_lwe_ciphertext_list_with_existing_generator(
|
||||
output_lwe_sk,
|
||||
&mut keyswitch_key_block,
|
||||
&decomposition_plaintexts_buffer,
|
||||
noise_parameters,
|
||||
&mut generator,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_seeded_lwe_keyswitch_key_non_native_mod<
|
||||
Scalar,
|
||||
InputKeyCont,
|
||||
OutputKeyCont,
|
||||
KSKeyCont,
|
||||
NoiseSeeder,
|
||||
>(
|
||||
input_lwe_sk: &LweSecretKey<InputKeyCont>,
|
||||
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
|
||||
lwe_keyswitch_key: &mut SeededLweKeyswitchKey<KSKeyCont>,
|
||||
noise_parameters: impl DispersionParameter,
|
||||
noise_seeder: &mut NoiseSeeder,
|
||||
) where
|
||||
Scalar: UnsignedTorus,
|
||||
InputKeyCont: Container<Element = Scalar>,
|
||||
OutputKeyCont: Container<Element = Scalar>,
|
||||
KSKeyCont: ContainerMut<Element = Scalar>,
|
||||
// Maybe Sized allows to pass Box<dyn Seeder>.
|
||||
NoiseSeeder: Seeder + ?Sized,
|
||||
{
|
||||
assert!(
|
||||
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
|
||||
"The destination SeededLweKeyswitchKey input LweDimension is not equal \
|
||||
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
|
||||
lwe_keyswitch_key.input_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
assert!(
|
||||
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
|
||||
"The destination SeededLweKeyswitchKey output LweDimension is not equal \
|
||||
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
|
||||
lwe_keyswitch_key.output_key_lwe_dimension(),
|
||||
input_lwe_sk.lwe_dimension()
|
||||
);
|
||||
|
||||
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
|
||||
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
|
||||
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
// The plaintexts used to encrypt a key element will be stored in this buffer
|
||||
let mut decomposition_plaintexts_buffer =
|
||||
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
|
||||
|
||||
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
|
||||
lwe_keyswitch_key.compression_seed().seed,
|
||||
noise_seeder,
|
||||
);
|
||||
|
||||
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
|
||||
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
|
||||
.as_ref()
|
||||
.iter()
|
||||
.zip(lwe_keyswitch_key.iter_mut())
|
||||
{
|
||||
// We fill the buffer with the powers of the key elmements
|
||||
for (level, message) in (1..=decomp_level_count.0)
|
||||
.rev()
|
||||
.map(DecompositionLevel)
|
||||
.zip(decomposition_plaintexts_buffer.iter_mut())
|
||||
{
|
||||
*message.0 = DecompositionTermNonNative::new(
|
||||
level,
|
||||
decomp_base_log,
|
||||
*input_key_element,
|
||||
ciphertext_modulus,
|
||||
)
|
||||
.to_recomposition_summand();
|
||||
}
|
||||
|
||||
encrypt_seeded_lwe_ciphertext_list_with_existing_generator(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
//! like addition, multiplication, etc.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::*;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -71,6 +72,22 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_add_assign_native_mod_compatible(lhs, rhs)
|
||||
} else {
|
||||
lwe_ciphertext_add_assign_non_native_mod(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_add_assign_native_mod_compatible<Scalar, LhsCont, RhsCont>(
|
||||
lhs: &mut LweCiphertext<LhsCont>,
|
||||
rhs: &LweCiphertext<RhsCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(
|
||||
lhs.ciphertext_modulus(),
|
||||
@@ -79,10 +96,37 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus()
|
||||
);
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref());
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_add_assign_non_native_mod<Scalar, LhsCont, RhsCont>(
|
||||
lhs: &mut LweCiphertext<LhsCont>,
|
||||
rhs: &LweCiphertext<RhsCont>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
LhsCont: ContainerMut<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus(),
|
||||
"Mismatched moduli between lhs ({:?}) and rhs ({:?}) LweCiphertext",
|
||||
lhs.ciphertext_modulus(),
|
||||
rhs.ciphertext_modulus()
|
||||
);
|
||||
let ciphertext_modulus = rhs.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
lhs.as_mut(),
|
||||
rhs.as_ref(),
|
||||
ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
);
|
||||
}
|
||||
|
||||
/// Add the right-hand side [`LWE ciphertext`](`LweCiphertext`) to the left-hand side [`LWE
|
||||
/// ciphertext`](`LweCiphertext`) writing the result in the output [`LWE
|
||||
/// ciphertext`](`LweCiphertext`).
|
||||
@@ -235,19 +279,51 @@ pub fn lwe_ciphertext_plaintext_add_assign<Scalar, InCont>(
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_plaintext_add_assign_native_mod_compatible(lhs, rhs)
|
||||
} else {
|
||||
lwe_ciphertext_plaintext_add_assign_non_native_mod(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_add_assign_native_mod_compatible<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_add(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
);
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
match ciphertext_modulus.kind() {
|
||||
CiphertextModulusKind::Native => *body.data = (*body.data).wrapping_add(rhs.0),
|
||||
CiphertextModulusKind::NonNativePowerOfTwo => {
|
||||
*body.data = (*body.data).wrapping_add(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
)
|
||||
}
|
||||
CiphertextModulusKind::NonNative => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_add_assign_non_native_mod<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
*body.data = (*body.data)
|
||||
.wrapping_add_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
|
||||
}
|
||||
|
||||
/// Add the right-hand side encoded [`Plaintext`] to the left-hand side [`LWE
|
||||
/// ciphertext`](`LweCiphertext`) updating it in-place.
|
||||
///
|
||||
@@ -310,19 +386,49 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let ciphertext_modulus = lhs.ciphertext_modulus();
|
||||
if ciphertext_modulus.is_compatible_with_native_modulus() {
|
||||
lwe_ciphertext_plaintext_sub_assign_native_mod_compatible(lhs, rhs)
|
||||
} else {
|
||||
lwe_ciphertext_plaintext_sub_assign_non_native_mod(lhs, rhs)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_sub_assign_native_mod_compatible<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if ciphertext_modulus.is_native_modulus() {
|
||||
*body.data = (*body.data).wrapping_sub(rhs.0);
|
||||
} else {
|
||||
*body.data = (*body.data).wrapping_sub(
|
||||
rhs.0
|
||||
.wrapping_mul(ciphertext_modulus.get_scaling_to_native_torus()),
|
||||
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn lwe_ciphertext_plaintext_sub_assign_non_native_mod<Scalar, InCont>(
|
||||
lhs: &mut LweCiphertext<InCont>,
|
||||
rhs: Plaintext<Scalar>,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
InCont: ContainerMut<Element = Scalar>,
|
||||
{
|
||||
let body = lhs.get_mut_body();
|
||||
let ciphertext_modulus = body.ciphertext_modulus();
|
||||
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
*body.data = (*body.data)
|
||||
.wrapping_sub_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
|
||||
}
|
||||
|
||||
/// Compute the opposite of the input [`LWE ciphertext`](`LweCiphertext`) and update it in place.
|
||||
///
|
||||
/// # Example
|
||||
|
||||
@@ -270,6 +270,10 @@ pub fn multi_bit_blind_rotate_assign<Scalar, InputCont, OutputCont, KeyCont>(
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
assert!(accumulator
|
||||
.ciphertext_modulus()
|
||||
.is_compatible_with_native_modulus());
|
||||
|
||||
let (lwe_mask, lwe_body) = input.get_mask_and_body();
|
||||
|
||||
// No way to chunk the result of ggsw_iter at the moment
|
||||
|
||||
@@ -485,6 +485,8 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
|
||||
InputGlweCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(out.ciphertext_modulus(), glwe.ciphertext_modulus());
|
||||
let ciphertext_modulus = out.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
impl_add_external_product_assign(
|
||||
out.as_mut_view(),
|
||||
@@ -494,7 +496,6 @@ pub fn add_external_product_assign_mem_optimized<Scalar, OutputGlweCont, InputGl
|
||||
stack,
|
||||
);
|
||||
|
||||
let ciphertext_modulus = out.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
@@ -774,6 +775,8 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
|
||||
GgswCont: Container<Element = c64>,
|
||||
{
|
||||
assert_eq!(ct0.ciphertext_modulus(), ct1.ciphertext_modulus());
|
||||
let ciphertext_modulus = ct0.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
|
||||
cmux(
|
||||
ct0.as_mut_view(),
|
||||
@@ -783,7 +786,6 @@ pub fn cmux_assign_mem_optimized<Scalar, Cont0, Cont1, GgswCont>(
|
||||
stack,
|
||||
);
|
||||
|
||||
let ciphertext_modulus = ct0.ciphertext_modulus();
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to 53
|
||||
// MSBs with information. In our representation of power of 2 moduli < native modulus we
|
||||
|
||||
86
tfhe/src/core_crypto/algorithms/misc.rs
Normal file
86
tfhe/src/core_crypto/algorithms/misc.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
//! Miscellaneous algorithms.
|
||||
|
||||
use crate::core_crypto::prelude::*;
|
||||
|
||||
#[inline]
|
||||
pub fn divide_round_to_u128<Scalar>(numerator: Scalar, denominator: Scalar) -> u128
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let numerator_128: u128 = numerator.cast_into();
|
||||
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
|
||||
let denominator_128: u128 = denominator.cast_into();
|
||||
// That's the rounding
|
||||
(numerator_128 + half_denominator) / denominator_128
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn divide_round_to_u128_custom_mod<Scalar>(
|
||||
numerator: Scalar,
|
||||
denominator: Scalar,
|
||||
modulus: u128,
|
||||
) -> u128
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let numerator_128: u128 = numerator.cast_into();
|
||||
let half_denominator: u128 = (denominator / Scalar::TWO).cast_into();
|
||||
let denominator_128: u128 = denominator.cast_into();
|
||||
// That's the rounding
|
||||
((numerator_128 + half_denominator) % modulus) / denominator_128
|
||||
}
|
||||
|
||||
pub fn odd_modular_inverse_pow_2<Scalar>(odd_value_to_invert: Scalar, log2_modulo: usize) -> Scalar
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
let t = log2_modulo.ilog2() + if log2_modulo.is_power_of_two() { 0 } else { 1 };
|
||||
let mut y = Scalar::ONE;
|
||||
let e = odd_value_to_invert;
|
||||
|
||||
for i in 1..=t {
|
||||
// 1 << (1 << i) == 2 ^ {2 ^ i}
|
||||
let curr_mod = Scalar::ONE.shl(1 << i);
|
||||
// y = y * (2 - y * e) mod 2 ^ {2 ^ i}
|
||||
// Here using wrapping ops is ok as the modulus used is a power of 2, as long as 2 ^ {2 ^ i}
|
||||
// is smaller than Scalar::BITS, we are good to go, the discarded values would not have been
|
||||
// Used anyways, and 2 ^ {2 ^ i} is compatible with a native modulus
|
||||
y = (y.wrapping_mul(Scalar::TWO.wrapping_sub(y.wrapping_mul(e)))).wrapping_rem(curr_mod);
|
||||
}
|
||||
|
||||
y.wrapping_rem(Scalar::ONE.shl(log2_modulo))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_divide_round() {
|
||||
use rand::Rng;
|
||||
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
const NB_TESTS: usize = 1_000_000_000;
|
||||
const SCALING: f64 = u64::MAX as f64;
|
||||
for _ in 0..NB_TESTS {
|
||||
let num: f64 = rng.gen();
|
||||
let mut denom = 0.0f64;
|
||||
while denom == 0.0f64 {
|
||||
denom = rng.gen();
|
||||
}
|
||||
|
||||
let num = (num * SCALING).round();
|
||||
let denom = (denom * SCALING).round();
|
||||
|
||||
let rounded = (num / denom).round();
|
||||
let expected_rounded_u64: u64 = rounded as u64;
|
||||
|
||||
let num_u64: u64 = num as u64;
|
||||
let denom_u64: u64 = denom as u64;
|
||||
|
||||
// sanity check
|
||||
assert_eq!(num, num_u64 as f64);
|
||||
assert_eq!(denom, denom_u64 as f64);
|
||||
|
||||
let rounded_u128 = divide_round_to_u128(num_u64, denom_u64);
|
||||
|
||||
assert_eq!(expected_rounded_u64, rounded_u128 as u64);
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ pub mod lwe_programmable_bootstrapping;
|
||||
pub mod lwe_public_key_generation;
|
||||
pub mod lwe_secret_key_generation;
|
||||
pub mod lwe_wopbs;
|
||||
pub mod misc;
|
||||
pub mod polynomial_algorithms;
|
||||
pub mod seeded_ggsw_ciphertext_decompression;
|
||||
pub mod seeded_ggsw_ciphertext_list_decompression;
|
||||
|
||||
@@ -35,6 +35,19 @@ pub fn polynomial_wrapping_add_assign<Scalar, OutputCont, InputCont>(
|
||||
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref())
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_add_assign_custom_mod<Scalar, OutputCont, InputCont>(
|
||||
lhs: &mut Polynomial<OutputCont>,
|
||||
rhs: &Polynomial<InputCont>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(lhs.polynomial_size(), rhs.polynomial_size());
|
||||
slice_wrapping_add_assign_custom_mod(lhs.as_mut(), rhs.as_ref(), custom_modulus)
|
||||
}
|
||||
|
||||
/// Subtract a polynomial to the output polynomial.
|
||||
///
|
||||
/// # Note
|
||||
@@ -64,6 +77,19 @@ pub fn polynomial_wrapping_sub_assign<Scalar, OutputCont, InputCont>(
|
||||
slice_wrapping_sub_assign(lhs.as_mut(), rhs.as_ref())
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_sub_assign_custom_mod<Scalar, OutputCont, InputCont>(
|
||||
lhs: &mut Polynomial<OutputCont>,
|
||||
rhs: &Polynomial<InputCont>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont: Container<Element = Scalar>,
|
||||
{
|
||||
assert_eq!(lhs.polynomial_size(), rhs.polynomial_size());
|
||||
slice_wrapping_sub_assign_custom_mod(lhs.as_mut(), rhs.as_ref(), custom_modulus)
|
||||
}
|
||||
|
||||
/// Add the sum of the element-wise product between two lists of polynomials to the output
|
||||
/// polynomial.
|
||||
///
|
||||
@@ -105,6 +131,27 @@ pub fn polynomial_wrapping_add_multisum_assign<Scalar, OutputCont, InputCont1, I
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_add_multisum_assign_custom_mod<
|
||||
Scalar,
|
||||
OutputCont,
|
||||
InputCont1,
|
||||
InputCont2,
|
||||
>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
poly_list_1: &PolynomialList<InputCont1>,
|
||||
poly_list_2: &PolynomialList<InputCont2>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont1: Container<Element = Scalar>,
|
||||
InputCont2: Container<Element = Scalar>,
|
||||
{
|
||||
for (poly_1, poly_2) in poly_list_1.iter().zip(poly_list_2.iter()) {
|
||||
polynomial_wrapping_add_mul_assign_custom_mod(output, &poly_1, &poly_2, custom_modulus);
|
||||
}
|
||||
}
|
||||
|
||||
/// Add the result of the product between two polynomials, reduced modulo $(X^{N}+1)$, to the
|
||||
/// output polynomial.
|
||||
///
|
||||
@@ -176,6 +223,63 @@ pub fn polynomial_wrapping_add_mul_assign<Scalar, OutputCont, InputCont1, InputC
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_add_mul_assign_custom_mod<Scalar, OutputCont, InputCont1, InputCont2>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
lhs: &Polynomial<InputCont1>,
|
||||
rhs: &Polynomial<InputCont2>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont1: Container<Element = Scalar>,
|
||||
InputCont2: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == lhs.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input lhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
lhs.polynomial_size(),
|
||||
);
|
||||
assert!(
|
||||
output.polynomial_size() == rhs.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input rhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
rhs.polynomial_size(),
|
||||
);
|
||||
|
||||
let polynomial_size = output.polynomial_size();
|
||||
|
||||
if polynomial_size.0.is_power_of_two() && polynomial_size.0 > KARATUSBA_STOP {
|
||||
let mut tmp = Polynomial::new(Scalar::ZERO, polynomial_size);
|
||||
|
||||
polynomial_karatsuba_wrapping_mul_custom_mod(&mut tmp, lhs, rhs, custom_modulus);
|
||||
polynomial_wrapping_add_assign_custom_mod(output, &tmp, custom_modulus);
|
||||
} else {
|
||||
let degree = output.degree();
|
||||
for (lhs_degree, &lhs_coeff) in lhs.iter().enumerate() {
|
||||
for (rhs_degree, &rhs_coeff) in rhs.iter().enumerate() {
|
||||
let target_degree = lhs_degree + rhs_degree;
|
||||
if target_degree <= degree {
|
||||
let output_coefficient = &mut output.as_mut()[target_degree];
|
||||
|
||||
*output_coefficient = (*output_coefficient).wrapping_add_custom_mod(
|
||||
lhs_coeff.wrapping_mul_custom_mod(rhs_coeff, custom_modulus),
|
||||
custom_modulus,
|
||||
);
|
||||
} else {
|
||||
let target_degree = target_degree % polynomial_size.0;
|
||||
let output_coefficient = &mut output.as_mut()[target_degree];
|
||||
|
||||
*output_coefficient = (*output_coefficient).wrapping_sub_custom_mod(
|
||||
lhs_coeff.wrapping_mul_custom_mod(rhs_coeff, custom_modulus),
|
||||
custom_modulus,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Divides (mod $(X^{N}+1)$), the output polynomial with a monic monomial of a given degree i.e.
|
||||
/// $X^{degree}$.
|
||||
///
|
||||
@@ -301,6 +405,27 @@ pub fn polynomial_wrapping_sub_multisum_assign<Scalar, OutputCont, InputCont1, I
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_sub_multisum_assign_custom_mod<
|
||||
Scalar,
|
||||
OutputCont,
|
||||
InputCont1,
|
||||
InputCont2,
|
||||
>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
poly_list_1: &PolynomialList<InputCont1>,
|
||||
poly_list_2: &PolynomialList<InputCont2>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont1: Container<Element = Scalar>,
|
||||
InputCont2: Container<Element = Scalar>,
|
||||
{
|
||||
for (poly_1, poly_2) in poly_list_1.iter().zip(poly_list_2.iter()) {
|
||||
polynomial_wrapping_sub_mul_assign_custom_mod(output, &poly_1, &poly_2, custom_modulus);
|
||||
}
|
||||
}
|
||||
|
||||
/// Subtract the result of the product between two polynomials, reduced modulo $(X^{N}+1)$, to the
|
||||
/// output polynomial.
|
||||
///
|
||||
@@ -372,6 +497,63 @@ pub fn polynomial_wrapping_sub_mul_assign<Scalar, OutputCont, InputCont1, InputC
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_wrapping_sub_mul_assign_custom_mod<Scalar, OutputCont, InputCont1, InputCont2>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
lhs: &Polynomial<InputCont1>,
|
||||
rhs: &Polynomial<InputCont2>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
InputCont1: Container<Element = Scalar>,
|
||||
InputCont2: Container<Element = Scalar>,
|
||||
{
|
||||
assert!(
|
||||
output.polynomial_size() == lhs.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input lhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
lhs.polynomial_size(),
|
||||
);
|
||||
assert!(
|
||||
output.polynomial_size() == rhs.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input rhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
rhs.polynomial_size(),
|
||||
);
|
||||
|
||||
let polynomial_size = output.polynomial_size();
|
||||
|
||||
if polynomial_size.0.is_power_of_two() && polynomial_size.0 > KARATUSBA_STOP {
|
||||
let mut tmp = Polynomial::new(Scalar::ZERO, polynomial_size);
|
||||
|
||||
polynomial_karatsuba_wrapping_mul_custom_mod(&mut tmp, lhs, rhs, custom_modulus);
|
||||
polynomial_wrapping_sub_assign_custom_mod(output, &tmp, custom_modulus);
|
||||
} else {
|
||||
let degree = output.degree();
|
||||
for (lhs_degree, &lhs_coeff) in lhs.iter().enumerate() {
|
||||
for (rhs_degree, &rhs_coeff) in rhs.iter().enumerate() {
|
||||
let target_degree = lhs_degree + rhs_degree;
|
||||
if target_degree <= degree {
|
||||
let output_coefficient = &mut output.as_mut()[target_degree];
|
||||
|
||||
*output_coefficient = (*output_coefficient).wrapping_sub_custom_mod(
|
||||
lhs_coeff.wrapping_mul_custom_mod(rhs_coeff, custom_modulus),
|
||||
custom_modulus,
|
||||
);
|
||||
} else {
|
||||
let target_degree = target_degree % polynomial_size.0;
|
||||
let output_coefficient = &mut output.as_mut()[target_degree];
|
||||
|
||||
*output_coefficient = (*output_coefficient).wrapping_add_custom_mod(
|
||||
lhs_coeff.wrapping_mul_custom_mod(rhs_coeff, custom_modulus),
|
||||
custom_modulus,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Fill the ouptut polynomial, with the result of the product of two polynomials, reduced modulo
|
||||
/// $(X^{N} + 1)$ with the schoolbook algorithm Complexity: $O(N^{2})$
|
||||
///
|
||||
@@ -530,6 +712,172 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub fn polynomial_karatsuba_wrapping_mul_custom_mod<Scalar, OutputCont, LhsCont, RhsCont>(
|
||||
output: &mut Polynomial<OutputCont>,
|
||||
p: &Polynomial<LhsCont>,
|
||||
q: &Polynomial<RhsCont>,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
OutputCont: ContainerMut<Element = Scalar>,
|
||||
LhsCont: Container<Element = Scalar>,
|
||||
RhsCont: Container<Element = Scalar>,
|
||||
{
|
||||
// check same dimensions
|
||||
assert!(
|
||||
output.polynomial_size() == p.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input lhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
p.polynomial_size(),
|
||||
);
|
||||
assert!(
|
||||
output.polynomial_size() == q.polynomial_size(),
|
||||
"Output polynomial size {:?} is not the same as input rhs polynomial {:?}.",
|
||||
output.polynomial_size(),
|
||||
q.polynomial_size(),
|
||||
);
|
||||
|
||||
let poly_size = output.polynomial_size().0;
|
||||
|
||||
// check dimensions are a power of 2
|
||||
assert!(poly_size.is_power_of_two());
|
||||
|
||||
// allocate slices for the rec
|
||||
let mut a0 = vec![Scalar::ZERO; poly_size];
|
||||
let mut a1 = vec![Scalar::ZERO; poly_size];
|
||||
let mut a2 = vec![Scalar::ZERO; poly_size];
|
||||
let mut input_a2_p = vec![Scalar::ZERO; poly_size / 2];
|
||||
let mut input_a2_q = vec![Scalar::ZERO; poly_size / 2];
|
||||
|
||||
// prepare for splitting
|
||||
let bottom = 0..(poly_size / 2);
|
||||
let top = (poly_size / 2)..poly_size;
|
||||
|
||||
// induction
|
||||
induction_karatsuba_custom_mod(
|
||||
&mut a0,
|
||||
&p[bottom.clone()],
|
||||
&q[bottom.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
induction_karatsuba_custom_mod(&mut a1, &p[top.clone()], &q[top.clone()], custom_modulus);
|
||||
slice_wrapping_add_custom_mod(
|
||||
&mut input_a2_p,
|
||||
&p[bottom.clone()],
|
||||
&p[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_custom_mod(
|
||||
&mut input_a2_q,
|
||||
&q[bottom.clone()],
|
||||
&q[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
induction_karatsuba_custom_mod(&mut a2, &input_a2_p, &input_a2_q, custom_modulus);
|
||||
|
||||
// rebuild the result
|
||||
let output: &mut [Scalar] = output.as_mut();
|
||||
slice_wrapping_sub_custom_mod(output, &a0, &a1, custom_modulus);
|
||||
slice_wrapping_sub_assign_custom_mod(
|
||||
&mut output[bottom.clone()],
|
||||
&a2[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
&mut output[bottom.clone()],
|
||||
&a0[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
&mut output[bottom.clone()],
|
||||
&a1[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
&mut output[top.clone()],
|
||||
&a2[bottom.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_sub_assign_custom_mod(
|
||||
&mut output[top.clone()],
|
||||
&a0[bottom.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_sub_assign_custom_mod(&mut output[top], &a1[bottom], custom_modulus);
|
||||
}
|
||||
|
||||
fn induction_karatsuba_custom_mod<Scalar>(
|
||||
res: &mut [Scalar],
|
||||
p: &[Scalar],
|
||||
q: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
// stop the induction when polynomials have KARATUSBA_STOP elements
|
||||
if p.len() <= KARATUSBA_STOP {
|
||||
// schoolbook algorithm
|
||||
for (lhs_degree, &lhs_elt) in p.iter().enumerate() {
|
||||
let res = &mut res[lhs_degree..];
|
||||
for (&rhs_elt, res) in q.iter().zip(res) {
|
||||
*res = (*res).wrapping_add_custom_mod(
|
||||
lhs_elt.wrapping_mul_custom_mod(rhs_elt, custom_modulus),
|
||||
custom_modulus,
|
||||
)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let poly_size = res.len();
|
||||
|
||||
// allocate slices for the rec
|
||||
let mut a0 = vec![Scalar::ZERO; poly_size / 2];
|
||||
let mut a1 = vec![Scalar::ZERO; poly_size / 2];
|
||||
let mut a2 = vec![Scalar::ZERO; poly_size / 2];
|
||||
let mut input_a2_p = vec![Scalar::ZERO; poly_size / 4];
|
||||
let mut input_a2_q = vec![Scalar::ZERO; poly_size / 4];
|
||||
|
||||
// prepare for splitting
|
||||
let bottom = 0..(poly_size / 4);
|
||||
let top = (poly_size / 4)..(poly_size / 2);
|
||||
|
||||
// rec
|
||||
induction_karatsuba_custom_mod(
|
||||
&mut a0,
|
||||
&p[bottom.clone()],
|
||||
&q[bottom.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
induction_karatsuba_custom_mod(&mut a1, &p[top.clone()], &q[top.clone()], custom_modulus);
|
||||
slice_wrapping_add_custom_mod(
|
||||
&mut input_a2_p,
|
||||
&p[bottom.clone()],
|
||||
&p[top.clone()],
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_custom_mod(&mut input_a2_q, &q[bottom], &q[top], custom_modulus);
|
||||
induction_karatsuba_custom_mod(&mut a2, &input_a2_p, &input_a2_q, custom_modulus);
|
||||
|
||||
// rebuild the result
|
||||
slice_wrapping_sub_custom_mod(
|
||||
&mut res[(poly_size / 4)..(3 * poly_size / 4)],
|
||||
&a2,
|
||||
&a0,
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_sub_assign_custom_mod(
|
||||
&mut res[(poly_size / 4)..(3 * poly_size / 4)],
|
||||
&a1,
|
||||
custom_modulus,
|
||||
);
|
||||
slice_wrapping_add_assign_custom_mod(&mut res[0..(poly_size / 2)], &a0, custom_modulus);
|
||||
slice_wrapping_add_assign_custom_mod(
|
||||
&mut res[(poly_size / 2)..poly_size],
|
||||
&a1,
|
||||
custom_modulus,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use rand::Rng;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededGlweCiphertext`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -37,10 +38,11 @@ pub fn decompress_seeded_glwe_ciphertext_with_existing_generator<
|
||||
|
||||
// generate a uniformly random mask
|
||||
generator.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
output_body
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededGlweCiphertextList`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -39,10 +40,11 @@ pub fn decompress_seeded_glwe_ciphertext_list_with_existing_generator<
|
||||
// generate a uniformly random mask
|
||||
generator
|
||||
.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
output_body.as_mut().copy_from_slice(body_in.as_ref());
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededLweCiphertext`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -30,10 +31,11 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
|
||||
|
||||
// generate a uniformly random mask
|
||||
generator.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the specific encoding for non native power of 2
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*output_body.data = *input_seeded_lwe.get_body().data;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
//! Module with primitives pertaining to [`SeededLweCiphertextList`] decompression.
|
||||
|
||||
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
|
||||
use crate::core_crypto::commons::math::random::RandomGenerator;
|
||||
use crate::core_crypto::commons::traits::*;
|
||||
use crate::core_crypto::entities::*;
|
||||
@@ -39,10 +40,11 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
|
||||
// generate a uniformly random mask
|
||||
generator
|
||||
.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// Manage the non native power of 2 encoding
|
||||
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
|
||||
slice_wrapping_scalar_mul_assign(
|
||||
output_mask.as_mut(),
|
||||
ciphertext_modulus.get_scaling_to_native_torus(),
|
||||
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
|
||||
);
|
||||
}
|
||||
*output_body.data = *body_in.data;
|
||||
|
||||
@@ -36,6 +36,31 @@ where
|
||||
})
|
||||
}
|
||||
|
||||
/// This primitive is meant to manage the dot product avoiding overflow on multiplication by casting
|
||||
/// to u128, for example for u64, avoiding overflow on each multiplication (as u64::MAX * u64::MAX <
|
||||
/// u128::MAX)
|
||||
pub fn slice_wrapping_dot_product_custom_mod<Scalar>(
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
modulus: Scalar,
|
||||
) -> Scalar
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter()
|
||||
.zip(rhs.iter())
|
||||
.fold(Scalar::ZERO, |acc, (&left, &right)| {
|
||||
acc.wrapping_add_custom_mod(left.wrapping_mul_custom_mod(right, modulus), modulus)
|
||||
})
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one element-wise.
|
||||
///
|
||||
/// # Note
|
||||
@@ -76,6 +101,33 @@ where
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_add(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_add_custom_mod<Scalar>(
|
||||
output: &mut [Scalar],
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
assert!(
|
||||
output.len() == lhs.len(),
|
||||
"output (len: {}) and rhs (len: {}) must have the same length",
|
||||
output.len(),
|
||||
lhs.len()
|
||||
);
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.zip(lhs.iter().zip(rhs.iter()))
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_add_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -108,6 +160,25 @@ where
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_add_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter_mut()
|
||||
.zip(rhs.iter())
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Add a slice containing unsigned integers to another one mutiplied by a scalar.
|
||||
///
|
||||
/// Let *a*,*b* be two slices, let *c* be a scalar, this computes: *a <- a+bc*
|
||||
@@ -185,6 +256,33 @@ where
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_sub(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_sub_custom_mod<Scalar>(
|
||||
output: &mut [Scalar],
|
||||
lhs: &[Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
assert!(
|
||||
output.len() == lhs.len(),
|
||||
"output (len: {}) and rhs (len: {}) must have the same length",
|
||||
output.len(),
|
||||
lhs.len()
|
||||
);
|
||||
|
||||
output
|
||||
.iter_mut()
|
||||
.zip(lhs.iter().zip(rhs.iter()))
|
||||
.for_each(|(out, (&lhs, &rhs))| *out = lhs.wrapping_sub_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Subtract a slice containing unsigned integers to another one, element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -217,6 +315,25 @@ where
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_sub_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
|
||||
lhs.iter_mut()
|
||||
.zip(rhs.iter())
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
/// Subtract a slice containing unsigned integers to another one mutiplied by a scalar,
|
||||
/// element-wise and in place.
|
||||
///
|
||||
@@ -254,6 +371,28 @@ pub fn slice_wrapping_sub_scalar_mul_assign<Scalar>(
|
||||
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub(rhs.wrapping_mul(scalar)));
|
||||
}
|
||||
|
||||
/// This primitive is meant to manage the sub_scalar_mul operation for values that were cast to a
|
||||
/// bigger type, for example u64 to u128, avoiding overflow on each multiplication (as u64::MAX *
|
||||
/// u64::MAX < u128::MAX )
|
||||
pub fn slice_wrapping_sub_scalar_mul_assign_custom_modulus<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: &[Scalar],
|
||||
scalar: Scalar,
|
||||
modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
assert!(
|
||||
lhs.len() == rhs.len(),
|
||||
"lhs (len: {}) and rhs (len: {}) must have the same length",
|
||||
lhs.len(),
|
||||
rhs.len()
|
||||
);
|
||||
lhs.iter_mut().zip(rhs.iter()).for_each(|(lhs, &rhs)| {
|
||||
*lhs = (*lhs).wrapping_sub_custom_mod(rhs.wrapping_mul_custom_mod(scalar, modulus), modulus)
|
||||
});
|
||||
}
|
||||
|
||||
/// Compute the opposite of a slice containing unsigned integers, element-wise and in place.
|
||||
///
|
||||
/// # Note
|
||||
@@ -302,6 +441,17 @@ where
|
||||
.for_each(|lhs| *lhs = (*lhs).wrapping_mul(rhs));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_scalar_mul_assign_custom_mod<Scalar>(
|
||||
lhs: &mut [Scalar],
|
||||
rhs: Scalar,
|
||||
custom_modulus: Scalar,
|
||||
) where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
lhs.iter_mut()
|
||||
.for_each(|lhs| *lhs = (*lhs).wrapping_mul_custom_mod(rhs, custom_modulus));
|
||||
}
|
||||
|
||||
pub fn slice_wrapping_scalar_div_assign<Scalar>(lhs: &mut [Scalar], rhs: Scalar)
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
|
||||
@@ -156,3 +156,10 @@ fn test_parallel_and_seeded_bsk_gen_equivalence_u64_custom_mod() {
|
||||
CiphertextModulus::try_new_power_of_2(63).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_bsk_gen_equivalence_u64_solinas_mod() {
|
||||
test_parallel_and_seeded_bsk_gen_equivalence::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -133,6 +133,11 @@ fn test_parallel_and_seeded_lwe_list_encryption_equivalence_non_native_power_of_
|
||||
test_parallel_and_seeded_lwe_list_encryption_equivalence(TEST_PARAMS_3_BITS_63_U64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_lwe_list_encryption_equivalence_solinas_mod_u64() {
|
||||
test_parallel_and_seeded_lwe_list_encryption_equivalence(TEST_PARAMS_3_BITS_SOLINAS_U64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parallel_and_seeded_lwe_list_encryption_equivalence_native_mod_u32() {
|
||||
test_parallel_and_seeded_lwe_list_encryption_equivalence(DUMMY_NATIVE_U32);
|
||||
|
||||
@@ -3,6 +3,7 @@ use super::*;
|
||||
fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
|
||||
let lwe_dimension = params.lwe_dimension;
|
||||
let lwe_modular_std_dev = params.lwe_modular_std_dev;
|
||||
let glwe_moduluar_std_dev = params.glwe_modular_std_dev;
|
||||
let ciphertext_modulus = params.ciphertext_modulus;
|
||||
let message_modulus_log = params.message_modulus_log;
|
||||
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
|
||||
@@ -21,6 +22,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<S
|
||||
while msg != Scalar::ZERO {
|
||||
msg = msg.wrapping_sub(Scalar::ONE);
|
||||
for _ in 0..NB_TESTS {
|
||||
println!("{msg}");
|
||||
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
|
||||
lwe_dimension,
|
||||
&mut rsc.secret_random_generator,
|
||||
@@ -54,7 +56,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<S
|
||||
let ct = allocate_and_encrypt_new_lwe_ciphertext(
|
||||
&big_lwe_sk,
|
||||
plaintext,
|
||||
lwe_modular_std_dev,
|
||||
glwe_moduluar_std_dev,
|
||||
ciphertext_modulus,
|
||||
&mut rsc.encryption_random_generator,
|
||||
);
|
||||
|
||||
@@ -112,3 +112,10 @@ fn test_seeded_lwe_ksk_gen_equivalence_u32_custom_mod() {
|
||||
fn test_seeded_lwe_ksk_gen_equivalence_u64_custom_mod() {
|
||||
test_seeded_lwe_ksk_gen_equivalence::<u64>(CiphertextModulus::try_new_power_of_2(63).unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seeded_lwe_ksk_gen_equivalence_u64_solinas_mod() {
|
||||
test_seeded_lwe_ksk_gen_equivalence::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -126,7 +126,10 @@ fn lwe_encrypt_pbs_decrypt_custom_mod<
|
||||
}
|
||||
}
|
||||
|
||||
create_parametrized_test!(lwe_encrypt_pbs_decrypt_custom_mod);
|
||||
create_parametrized_test!(lwe_encrypt_pbs_decrypt_custom_mod {
|
||||
TEST_PARAMS_4_BITS_NATIVE_U64,
|
||||
TEST_PARAMS_3_BITS_63_U64
|
||||
});
|
||||
|
||||
// DISCLAIMER: all parameters here are not guaranteed to be secure or yield correct computations
|
||||
pub const TEST_PARAMS_4_BITS_NATIVE_U128: TestParams<u128> = TestParams {
|
||||
|
||||
@@ -128,6 +128,25 @@ pub const DUMMY_31_U32: TestParams<u32> = TestParams {
|
||||
ciphertext_modulus: unsafe { CiphertextModulus::new_unchecked(1 << 31) },
|
||||
};
|
||||
|
||||
pub const TEST_PARAMS_3_BITS_SOLINAS_U64: TestParams<u64> = TestParams {
|
||||
lwe_dimension: LweDimension(742),
|
||||
glwe_dimension: GlweDimension(1),
|
||||
polynomial_size: PolynomialSize(2048),
|
||||
lwe_modular_std_dev: StandardDev(0.000007069849454709433),
|
||||
glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533),
|
||||
pbs_base_log: DecompositionBaseLog(23),
|
||||
pbs_level: DecompositionLevelCount(1),
|
||||
ks_level: DecompositionLevelCount(5),
|
||||
ks_base_log: DecompositionBaseLog(3),
|
||||
pfks_level: DecompositionLevelCount(1),
|
||||
pfks_base_log: DecompositionBaseLog(23),
|
||||
pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533),
|
||||
cbs_level: DecompositionLevelCount(0),
|
||||
cbs_base_log: DecompositionBaseLog(0),
|
||||
message_modulus_log: CiphertextModulusLog(3),
|
||||
ciphertext_modulus: unsafe { CiphertextModulus::new_unchecked((1 << 64) - (1 << 32) + 1) },
|
||||
};
|
||||
|
||||
// Our representation of non native power of 2 moduli puts the information in the MSBs and leaves
|
||||
// the LSBs empty, this is what this function is checking
|
||||
pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]>>(
|
||||
@@ -135,26 +154,42 @@ pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]
|
||||
modulus: CiphertextModulus<Scalar>,
|
||||
) -> bool {
|
||||
if !modulus.is_native_modulus() {
|
||||
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
|
||||
// we want the bits under the mask to be 0
|
||||
let power_2_diff_mask = modulus.get_scaling_to_native_torus() - Scalar::ONE;
|
||||
return input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO);
|
||||
// Power of two has a specific encoding in MSBs of the native torus to re-use native torus
|
||||
// implementations for free
|
||||
if modulus.is_power_of_two() {
|
||||
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
|
||||
// we want the bits under the mask to be 0
|
||||
let power_2_diff_mask =
|
||||
modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
|
||||
return input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO);
|
||||
} else {
|
||||
// Custom moduli non power of two use the "true" modulus representation
|
||||
return input
|
||||
.as_ref()
|
||||
.iter()
|
||||
.all(|&x| x < Scalar::cast_from(modulus.get_custom_modulus()));
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
// See above
|
||||
// See above comments for modulus check logic
|
||||
pub fn check_scalar_respects_mod<Scalar: UnsignedInteger>(
|
||||
input: Scalar,
|
||||
modulus: CiphertextModulus<Scalar>,
|
||||
) -> bool {
|
||||
if !modulus.is_native_modulus() {
|
||||
let power_2_diff_mask = modulus.get_scaling_to_native_torus() - Scalar::ONE;
|
||||
return (input & power_2_diff_mask) == Scalar::ZERO;
|
||||
if modulus.is_power_of_two() {
|
||||
let power_2_diff_mask =
|
||||
modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
|
||||
return (input & power_2_diff_mask) == Scalar::ZERO;
|
||||
} else {
|
||||
return input < Scalar::cast_from(modulus.get_custom_modulus());
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
@@ -243,6 +278,7 @@ macro_rules! create_parametrized_test{
|
||||
create_parametrized_test!($name
|
||||
{
|
||||
TEST_PARAMS_4_BITS_NATIVE_U64,
|
||||
TEST_PARAMS_3_BITS_SOLINAS_U64,
|
||||
TEST_PARAMS_3_BITS_63_U64
|
||||
});
|
||||
};
|
||||
|
||||
@@ -21,6 +21,12 @@ pub struct CiphertextModulus<Scalar: UnsignedInteger> {
|
||||
_scalar: PhantomData<Scalar>,
|
||||
}
|
||||
|
||||
pub enum CiphertextModulusKind {
|
||||
Native,
|
||||
NonNativePowerOfTwo,
|
||||
NonNative,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize)]
|
||||
struct SerialiazableLweCiphertextModulus {
|
||||
pub modulus: u128,
|
||||
@@ -121,6 +127,30 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn try_new(modulus: u128) -> Result<Self, &'static str> {
|
||||
if Scalar::BITS < 128 && modulus > (1 << Scalar::BITS) {
|
||||
Err("Modulus is bigger than the maximum value of the associated Scalar type")
|
||||
} else {
|
||||
let res = match modulus {
|
||||
0 => CiphertextModulus::new_native(),
|
||||
modulus => {
|
||||
let Some(non_zero_modulus) = NonZeroU128::new(modulus) else {
|
||||
panic!("Got zero modulus for CiphertextModulusInner::Custom variant",)
|
||||
};
|
||||
CiphertextModulus {
|
||||
inner: CiphertextModulusInner::Custom(non_zero_modulus),
|
||||
_scalar: PhantomData,
|
||||
}
|
||||
}
|
||||
};
|
||||
let canonicalized_result = res.canonicalize();
|
||||
if Scalar::BITS > 64 && !canonicalized_result.is_compatible_with_native_modulus() {
|
||||
return Err("Non power of 2 moduli are not supported for types wider than u64");
|
||||
}
|
||||
Ok(canonicalized_result)
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn canonicalize(self) -> Self {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => self,
|
||||
@@ -151,10 +181,15 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
|
||||
res.canonicalize()
|
||||
}
|
||||
|
||||
pub fn get_scaling_to_native_torus(&self) -> Scalar {
|
||||
#[track_caller]
|
||||
pub fn get_power_of_two_scaling_to_native_torus(&self) -> Scalar {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => Scalar::ONE,
|
||||
CiphertextModulusInner::Custom(modulus) => {
|
||||
assert!(
|
||||
modulus.is_power_of_two(),
|
||||
"Cannot get scaling for non power of two modulus {modulus:}"
|
||||
);
|
||||
Scalar::ONE.wrapping_shl(Scalar::BITS as u32 - modulus.ilog2())
|
||||
}
|
||||
}
|
||||
@@ -182,12 +217,32 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
|
||||
self.is_native_modulus() || self.is_power_of_two()
|
||||
}
|
||||
|
||||
pub const fn is_non_native_power_of_two(&self) -> bool {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => false,
|
||||
CiphertextModulusInner::Custom(modulus) => modulus.is_power_of_two(),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn is_power_of_two(&self) -> bool {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => true,
|
||||
CiphertextModulusInner::Custom(modulus) => modulus.is_power_of_two(),
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn kind(&self) -> CiphertextModulusKind {
|
||||
match self.inner {
|
||||
CiphertextModulusInner::Native => CiphertextModulusKind::Native,
|
||||
CiphertextModulusInner::Custom(modulus) => {
|
||||
if modulus.is_power_of_two() {
|
||||
CiphertextModulusKind::NonNativePowerOfTwo
|
||||
} else {
|
||||
CiphertextModulusKind::NonNative
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar: UnsignedInteger> std::fmt::Display for CiphertextModulus<Scalar> {
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecompositionIter;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
SignedDecompositionIter, SignedDecompositionIterNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
use crate::core_crypto::prelude::misc::divide_round_to_u128_custom_mod;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// A structure which allows to decompose unsigned integers into a set of smaller terms.
|
||||
@@ -174,3 +178,215 @@ where
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A structure which allows to decompose unsigned integers into a set of smaller terms for moduli
|
||||
/// which are non power of 2.
|
||||
///
|
||||
/// See the [module level](super) documentation for a description of the signed decomposition.
|
||||
#[derive(Debug)]
|
||||
pub struct SignedDecomposerNonNative<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
pub(crate) base_log: usize,
|
||||
pub(crate) level_count: usize,
|
||||
pub(crate) ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
}
|
||||
|
||||
impl<Scalar> SignedDecomposerNonNative<Scalar>
|
||||
where
|
||||
Scalar: UnsignedInteger,
|
||||
{
|
||||
/// Create a new decomposer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
|
||||
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn new(
|
||||
base_log: DecompositionBaseLog,
|
||||
level_count: DecompositionLevelCount,
|
||||
ciphertext_modulus: CiphertextModulus<Scalar>,
|
||||
) -> SignedDecomposerNonNative<Scalar> {
|
||||
debug_assert!(
|
||||
Scalar::BITS > base_log.0 * level_count.0,
|
||||
"Decomposed bits exceeds the size of the integer to be decomposed"
|
||||
);
|
||||
SignedDecomposerNonNative {
|
||||
base_log: base_log.0,
|
||||
level_count: level_count.0,
|
||||
ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the logarithm in base two of the base of this decomposer.
|
||||
///
|
||||
/// If the decomposer uses a base $B=2^b$, this returns $b$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn base_log(&self) -> DecompositionBaseLog {
|
||||
DecompositionBaseLog(self.base_log)
|
||||
}
|
||||
|
||||
/// Return the number of levels of this decomposer.
|
||||
///
|
||||
/// If the decomposer uses $l$ levels, this returns $l$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(decomposer.level_count(), DecompositionLevelCount(3));
|
||||
/// ```
|
||||
pub fn level_count(&self) -> DecompositionLevelCount {
|
||||
DecompositionLevelCount(self.level_count)
|
||||
}
|
||||
|
||||
/// Return the ciphertext modulus of this decomposer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::<u64>::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// assert_eq!(
|
||||
/// decomposer.ciphertext_modulus(),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap()
|
||||
/// );
|
||||
/// ```
|
||||
pub fn ciphertext_modulus(&self) -> CiphertextModulus<Scalar> {
|
||||
self.ciphertext_modulus
|
||||
}
|
||||
|
||||
/// Return the closet value representable by the decomposition.
|
||||
///
|
||||
/// For some input integer `k`, decomposition base `B`, decomposition level count `l` and given
|
||||
/// ciphertext modulus `q` the performed operation is the following:
|
||||
///
|
||||
/// $$
|
||||
/// \lfloor \frac{k\cdot q}{B^{l}} \rceil \cdot B^{l}
|
||||
/// $$
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let closest = decomposer.closest_representable(16982820785129133100u64);
|
||||
/// assert_eq!(closest, 16983074190859960320u64);
|
||||
/// ```
|
||||
#[inline]
|
||||
pub fn closest_representable(&self, input: Scalar) -> Scalar {
|
||||
let ciphertext_modulus = self.ciphertext_modulus.get_custom_modulus();
|
||||
// Floored approach
|
||||
// B^l
|
||||
let base_to_level_count = 1 << (self.base_log * self.level_count);
|
||||
// sr = floor(q/(B^l))
|
||||
let smallest_representable = ciphertext_modulus / base_to_level_count;
|
||||
|
||||
let input_128: u128 = input.cast_into();
|
||||
// rounded = round(input/sr)
|
||||
let rounded =
|
||||
divide_round_to_u128_custom_mod(input_128, smallest_representable, ciphertext_modulus);
|
||||
// rounded * sr
|
||||
let closest_representable = rounded * smallest_representable;
|
||||
Scalar::cast_from(closest_representable)
|
||||
}
|
||||
|
||||
/// Generate an iterator over the terms of the decomposition of the input.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// The returned iterator yields the terms $\tilde{\theta}\_i$ in order of decreasing $i$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::numeric::UnsignedInteger;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
///
|
||||
/// // These two values allow to take each arm of the half basis check below
|
||||
/// for value in [1u64 << 63, 16982820785129133100u64] {
|
||||
/// for term in decomposer.decompose(value) {
|
||||
/// assert!(1 <= term.level().0);
|
||||
/// assert!(term.level().0 <= 3);
|
||||
/// let term = term.value();
|
||||
/// let abs_term = if term < decomposer.ciphertext_modulus().get_custom_modulus() as u64 / 2
|
||||
/// {
|
||||
/// term
|
||||
/// } else {
|
||||
/// decomposer.ciphertext_modulus().get_custom_modulus() as u64 - term
|
||||
/// };
|
||||
/// println!("abs_term: {abs_term}");
|
||||
/// let half_basis = 2u64.pow(4) / 2u64;
|
||||
/// println!("half_basis: {half_basis}");
|
||||
/// assert!(abs_term <= half_basis);
|
||||
/// }
|
||||
/// assert_eq!(decomposer.decompose(1).count(), 3);
|
||||
/// }
|
||||
/// ```
|
||||
pub fn decompose(&self, input: Scalar) -> SignedDecompositionIterNonNative<Scalar> {
|
||||
// Note that there would be no sense of making the decomposition on an input which was
|
||||
// not rounded to the closest representable first. We then perform it before decomposing.
|
||||
SignedDecompositionIterNonNative::new(
|
||||
self.closest_representable(input),
|
||||
DecompositionBaseLog(self.base_log),
|
||||
DecompositionLevelCount(self.level_count),
|
||||
self.ciphertext_modulus,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm};
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::numeric::UnsignedInteger;
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
|
||||
@@ -122,3 +125,158 @@ fn decompose_one_level<S: UnsignedInteger>(base_log: usize, state: &mut S, mod_b
|
||||
*state += carry;
|
||||
res.wrapping_sub(carry << base_log)
|
||||
}
|
||||
|
||||
/// An iterator that yields the terms of the signed decomposition of an integer.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// This iterator yields the decomposition in reverse order. That means that the highest level
|
||||
/// will be yielded first.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
// The base log of the decomposition
|
||||
base_log: usize,
|
||||
// The number of levels of the decomposition
|
||||
level_count: usize,
|
||||
// The internal state of the decomposition
|
||||
state: T,
|
||||
// The current level
|
||||
current_level: usize,
|
||||
// A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form:
|
||||
// ...0001111
|
||||
mod_b_mask: T,
|
||||
// Ciphertext modulus
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
// A flag which store whether the iterator is a fresh one (for the recompose method)
|
||||
fresh: bool,
|
||||
}
|
||||
|
||||
impl<T> SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
pub(crate) fn new(
|
||||
input: T,
|
||||
base_log: DecompositionBaseLog,
|
||||
level: DecompositionLevelCount,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> SignedDecompositionIterNonNative<T> {
|
||||
let base_to_the_level = 1 << (base_log.0 * level.0);
|
||||
let smallest_representable = ciphertext_modulus.get_custom_modulus() / base_to_the_level;
|
||||
|
||||
let input_128: u128 = input.cast_into();
|
||||
let state = T::cast_from(input_128 / smallest_representable);
|
||||
|
||||
SignedDecompositionIterNonNative {
|
||||
base_log: base_log.0,
|
||||
level_count: level.0,
|
||||
state,
|
||||
current_level: level.0,
|
||||
mod_b_mask: (T::ONE << base_log.0) - T::ONE,
|
||||
ciphertext_modulus,
|
||||
fresh: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_fresh(&self) -> bool {
|
||||
self.fresh
|
||||
}
|
||||
|
||||
/// Return the logarithm in base two of the base of this decomposition.
|
||||
///
|
||||
/// If the decomposition uses a base $B=2^b$, this returns $b$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let val = 9_223_372_036_854_775_808u64;
|
||||
/// let decomp = decomposer.decompose(val);
|
||||
/// assert_eq!(decomp.base_log(), DecompositionBaseLog(4));
|
||||
/// ```
|
||||
pub fn base_log(&self) -> DecompositionBaseLog {
|
||||
DecompositionBaseLog(self.base_log)
|
||||
}
|
||||
|
||||
/// Return the number of levels of this decomposition.
|
||||
///
|
||||
/// If the decomposition uses $l$ levels, this returns $l$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
/// );
|
||||
/// let val = 9_223_372_036_854_775_808u64;
|
||||
/// let decomp = decomposer.decompose(val);
|
||||
/// assert_eq!(decomp.level_count(), DecompositionLevelCount(3));
|
||||
/// ```
|
||||
pub fn level_count(&self) -> DecompositionLevelCount {
|
||||
DecompositionLevelCount(self.level_count)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Iterator for SignedDecompositionIterNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
type Item = DecompositionTermNonNative<T>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
// The iterator is not fresh anymore
|
||||
self.fresh = false;
|
||||
// We check if the decomposition is over
|
||||
if self.current_level == 0 {
|
||||
return None;
|
||||
}
|
||||
// We decompose the current level
|
||||
let output = decompose_one_level_non_native(
|
||||
self.base_log,
|
||||
&mut self.state,
|
||||
self.mod_b_mask,
|
||||
T::cast_from(self.ciphertext_modulus.get_custom_modulus()),
|
||||
);
|
||||
self.current_level -= 1;
|
||||
// We return the output for this level
|
||||
Some(DecompositionTermNonNative::new(
|
||||
DecompositionLevel(self.current_level + 1),
|
||||
DecompositionBaseLog(self.base_log),
|
||||
output,
|
||||
self.ciphertext_modulus,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn decompose_one_level_non_native<S: UnsignedInteger>(
|
||||
base_log: usize,
|
||||
state: &mut S,
|
||||
mod_b_mask: S,
|
||||
ciphertext_modulus: S,
|
||||
) -> S {
|
||||
let res = *state & mod_b_mask;
|
||||
*state >>= base_log;
|
||||
let mut carry = (res.wrapping_sub(S::ONE) | *state) & res;
|
||||
carry >>= base_log - 1;
|
||||
*state += carry;
|
||||
res.wrapping_add(ciphertext_modulus)
|
||||
.wrapping_sub(carry << base_log)
|
||||
.wrapping_rem(ciphertext_modulus)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::DecompositionLevel;
|
||||
use crate::core_crypto::commons::numeric::{Numeric, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::DecompositionBaseLog;
|
||||
use crate::core_crypto::commons::traits::CastInto;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
|
||||
@@ -91,3 +93,119 @@ where
|
||||
DecompositionLevel(self.level)
|
||||
}
|
||||
}
|
||||
|
||||
/// A member of the decomposition.
|
||||
///
|
||||
/// If we decompose a value $\theta$ as a sum $\sum\_{i=1}^l\tilde{\theta}\_i\frac{q}{B^i}$, this
|
||||
/// represents a $\tilde{\theta}\_i$.
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
|
||||
pub struct DecompositionTermNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
level: usize,
|
||||
base_log: usize,
|
||||
value: T,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
}
|
||||
|
||||
impl<T> DecompositionTermNonNative<T>
|
||||
where
|
||||
T: UnsignedInteger,
|
||||
{
|
||||
// Creates a new decomposition term.
|
||||
pub(crate) fn new(
|
||||
level: DecompositionLevel,
|
||||
base_log: DecompositionBaseLog,
|
||||
value: T,
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> DecompositionTermNonNative<T> {
|
||||
DecompositionTermNonNative {
|
||||
level: level.0,
|
||||
base_log: base_log.0,
|
||||
value,
|
||||
ciphertext_modulus,
|
||||
}
|
||||
}
|
||||
|
||||
/// Turn this term into a summand.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$ of the decomposition, this method returns
|
||||
/// $\tilde{\theta}\_i\frac{q}{B^i}$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.to_recomposition_summand(), 1048576);
|
||||
/// ```
|
||||
pub fn to_recomposition_summand(&self) -> T {
|
||||
// Floored approach
|
||||
// * floor(q / B^j)
|
||||
let base_to_the_level = T::ONE << (self.base_log * self.level);
|
||||
let digit_radix =
|
||||
T::cast_from(self.ciphertext_modulus.get_custom_modulus()) / base_to_the_level;
|
||||
|
||||
self.value.wrapping_mul_custom_mod(
|
||||
digit_radix,
|
||||
self.ciphertext_modulus.get_custom_modulus().cast_into(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Return the value of the term.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$, this returns its actual value.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::SignedDecomposerNonNative;
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.value(), 1);
|
||||
/// ```
|
||||
pub fn value(&self) -> T {
|
||||
self.value
|
||||
}
|
||||
|
||||
/// Return the level of the term.
|
||||
///
|
||||
/// If our member represents one $\tilde{\theta}\_i$, this returns the value of $i$.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use tfhe::core_crypto::commons::math::decomposition::{
|
||||
/// DecompositionLevel, SignedDecomposerNonNative,
|
||||
/// };
|
||||
/// use tfhe::core_crypto::commons::parameters::{
|
||||
/// CiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
/// };
|
||||
/// let decomposer = SignedDecomposerNonNative::new(
|
||||
/// DecompositionBaseLog(4),
|
||||
/// DecompositionLevelCount(3),
|
||||
/// CiphertextModulus::try_new(1 << 32).unwrap(),
|
||||
/// );
|
||||
/// let output = decomposer.decompose(2u64.pow(19)).next().unwrap();
|
||||
/// assert_eq!(output.level(), DecompositionLevel(3));
|
||||
/// ```
|
||||
pub fn level(&self) -> DecompositionLevel {
|
||||
DecompositionLevel(self.level)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
|
||||
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulus;
|
||||
use crate::core_crypto::commons::math::decomposition::{
|
||||
SignedDecomposer, SignedDecomposerNonNative,
|
||||
};
|
||||
use crate::core_crypto::commons::math::random::{RandomGenerable, Uniform};
|
||||
use crate::core_crypto::commons::math::torus::UnsignedTorus;
|
||||
use crate::core_crypto::commons::numeric::{Numeric, SignedInteger, UnsignedInteger};
|
||||
use crate::core_crypto::commons::parameters::{DecompositionBaseLog, DecompositionLevelCount};
|
||||
use crate::core_crypto::commons::test_tools::{any_uint, any_usize, random_usize_between};
|
||||
use crate::core_crypto::commons::traits::CastInto;
|
||||
use std::fmt::Debug;
|
||||
|
||||
// Return a random decomposition valid for the size of the T type.
|
||||
@@ -69,7 +73,7 @@ fn test_round_to_closest_representable<T: UnsignedTorus>() {
|
||||
let bit: usize = log_b * level_max;
|
||||
|
||||
let val = val << (bits - bit);
|
||||
let delta = delta >> (bits - (bits - bit - 1));
|
||||
let delta = delta >> (bit + 1);
|
||||
|
||||
let decomposer = SignedDecomposer::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
@@ -117,3 +121,171 @@ fn test_round_to_closest_twice_u32() {
|
||||
fn test_round_to_closest_twice_u64() {
|
||||
test_round_to_closest_twice::<u64>();
|
||||
}
|
||||
|
||||
// Return a random decomposition valid for the size of the T type.
|
||||
fn random_decomp_non_native<T: UnsignedInteger>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) -> SignedDecomposerNonNative<T> {
|
||||
let mut base_log;
|
||||
let mut level_count;
|
||||
loop {
|
||||
base_log = random_usize_between(2..T::BITS);
|
||||
level_count = random_usize_between(2..T::BITS);
|
||||
if base_log * level_count < T::BITS {
|
||||
break;
|
||||
}
|
||||
}
|
||||
SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(base_log),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
)
|
||||
}
|
||||
|
||||
fn test_round_to_closest_representable_non_native<T: UnsignedTorus>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) {
|
||||
// Manage limit cases
|
||||
{
|
||||
let log_b = any_usize();
|
||||
let level_max = any_usize();
|
||||
let bits = T::BITS;
|
||||
let log_b = (log_b % ((bits / 4) - 1)) + 1;
|
||||
let level_count = (level_max % 4) + 1;
|
||||
let rep_bits: usize = log_b * level_count;
|
||||
|
||||
let base_to_the_level_u128 = 1u128 << rep_bits;
|
||||
let smallest_representable_u128 =
|
||||
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
|
||||
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
|
||||
// Compute an epsilon that should not change the result of a closest representable
|
||||
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
|
||||
|
||||
// Around 0
|
||||
let val = T::ZERO;
|
||||
|
||||
let decomposer = SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let val_u128: u128 = val.cast_into();
|
||||
let val_plus_epsilon: T = val_u128
|
||||
.wrapping_add(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_plus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
|
||||
let val_minus_epsilon: T = val_u128
|
||||
.wrapping_add(ciphertext_modulus.get_custom_modulus())
|
||||
.wrapping_sub(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_minus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
}
|
||||
|
||||
for _ in 0..1000 {
|
||||
let log_b = any_usize();
|
||||
let level_max = any_usize();
|
||||
let bits = T::BITS;
|
||||
let log_b = (log_b % ((bits / 4) - 1)) + 1;
|
||||
let level_count = (level_max % 4) + 1;
|
||||
let rep_bits: usize = log_b * level_count;
|
||||
|
||||
let base_to_the_level_u128 = 1u128 << rep_bits;
|
||||
let base_to_the_level = T::ONE << rep_bits;
|
||||
let smallest_representable_u128 =
|
||||
ciphertext_modulus.get_custom_modulus() / base_to_the_level_u128;
|
||||
let smallest_representable: T = smallest_representable_u128.cast_into();
|
||||
let sub_smallest_representable_u128 = smallest_representable_u128 / 2;
|
||||
// Compute an epsilon that should not change the result of a closest representable
|
||||
let epsilon_u128 = any_uint::<u128>() % sub_smallest_representable_u128;
|
||||
|
||||
let multiple_of_smallest_representable = any_uint::<T>() % base_to_the_level;
|
||||
let val = multiple_of_smallest_representable * smallest_representable;
|
||||
|
||||
let decomposer = SignedDecomposerNonNative::new(
|
||||
DecompositionBaseLog(log_b),
|
||||
DecompositionLevelCount(level_count),
|
||||
ciphertext_modulus,
|
||||
);
|
||||
|
||||
let val_u128: u128 = val.cast_into();
|
||||
let val_plus_epsilon: T = val_u128
|
||||
.wrapping_add(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_plus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_plus_epsilon: {val_plus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
|
||||
let val_minus_epsilon: T = val_u128
|
||||
.wrapping_add(ciphertext_modulus.get_custom_modulus())
|
||||
.wrapping_sub(epsilon_u128)
|
||||
.wrapping_rem(ciphertext_modulus.get_custom_modulus())
|
||||
.cast_into();
|
||||
|
||||
let closest = decomposer.closest_representable(val_minus_epsilon);
|
||||
assert_eq!(
|
||||
val, closest,
|
||||
"\n val_minus_epsilon: {val_minus_epsilon:064b}, \n \
|
||||
expected_closest: {val:064b}, \n \
|
||||
closest: {closest:064b}\n \
|
||||
decomp_base_log: {}, decomp_level_count: {}",
|
||||
decomposer.base_log, decomposer.level_count
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_to_closest_representable_non_native_u64() {
|
||||
test_round_to_closest_representable_non_native::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
fn test_round_to_closest_twice_non_native<T: UnsignedTorus + Debug>(
|
||||
ciphertext_modulus: CiphertextModulus<T>,
|
||||
) {
|
||||
for _ in 0..1000 {
|
||||
let decomp = random_decomp_non_native(ciphertext_modulus);
|
||||
let input: T = any_uint();
|
||||
|
||||
let rounded_once = decomp.closest_representable(input);
|
||||
let rounded_twice = decomp.closest_representable(rounded_once);
|
||||
assert_eq!(rounded_once, rounded_twice);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_round_to_closest_twice_non_native_u64() {
|
||||
test_round_to_closest_twice_non_native::<u64>(
|
||||
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -211,6 +211,8 @@ impl<G: ByteRandomGenerator> RandomGenerator<G> {
|
||||
) where
|
||||
Scalar: UnsignedInteger + RandomGenerable<Uniform>,
|
||||
{
|
||||
// TODO
|
||||
// Implement the proper generation function for custom mod for the uniform distribution
|
||||
self.fill_slice_with_random_uniform(output);
|
||||
|
||||
if !custom_modulus.is_native_modulus() {
|
||||
|
||||
@@ -45,18 +45,30 @@ pub trait UnsignedInteger:
|
||||
/// Compute a subtraction, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_sub(self, other: Self) -> Self;
|
||||
/// Compute an addition, modulo a custom modulus.
|
||||
#[must_use]
|
||||
fn wrapping_add_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
|
||||
/// Compute a subtraction, modulo a custom modulus.
|
||||
#[must_use]
|
||||
fn wrapping_sub_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
|
||||
/// Compute a division, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_div(self, other: Self) -> Self;
|
||||
/// Compute a multiplication, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_mul(self, other: Self) -> Self;
|
||||
/// Compute a multiplication, modulo a custom modulus.
|
||||
#[must_use]
|
||||
fn wrapping_mul_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
|
||||
/// Compute the remainder, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_rem(self, other: Self) -> Self;
|
||||
/// Compute a negation, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_neg(self) -> Self;
|
||||
/// Compute a negation, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_neg_custom_mod(self, custom_modulus: Self) -> Self;
|
||||
/// Compute an exponentiation, modulo the max of the type.
|
||||
#[must_use]
|
||||
fn wrapping_pow(self, exp: u32) -> Self;
|
||||
@@ -113,6 +125,26 @@ macro_rules! implement {
|
||||
self.wrapping_sub(other)
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_add_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
|
||||
match self.overflowing_add(other) {
|
||||
(result, true) => {
|
||||
// We have (for u64) a result of the form 2^64 + p, here we compute p mod q
|
||||
let result = result.wrapping_rem(custom_modulus);
|
||||
// and here we compute 2^64 mod q and add to the result as mod is linear
|
||||
let self_max_mod_custom = Self::MAX - custom_modulus + Self::ONE;
|
||||
result.wrapping_add(self_max_mod_custom)
|
||||
}
|
||||
(result, false) => result.wrapping_rem(custom_modulus),
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_sub_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
|
||||
match self.overflowing_sub(other) {
|
||||
(result, true) => result.wrapping_add(custom_modulus),
|
||||
(result, false) => result.wrapping_rem(custom_modulus),
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_div(self, other: Self) -> Self {
|
||||
self.wrapping_div(other)
|
||||
}
|
||||
@@ -120,7 +152,70 @@ macro_rules! implement {
|
||||
fn wrapping_mul(self, other: Self) -> Self {
|
||||
self.wrapping_mul(other)
|
||||
}
|
||||
#[must_use]
|
||||
#[inline]
|
||||
fn wrapping_mul_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
|
||||
let self_u128: u128 = self.cast_into();
|
||||
let other_u128: u128 = other.cast_into();
|
||||
let custom_modulus_u128: u128 = custom_modulus.cast_into();
|
||||
let (prod, wrong) = self_u128.overflowing_mul(other_u128);
|
||||
// if we are not able to multiply directly without wrapping around
|
||||
if wrong {
|
||||
// we try to do the multiplication as
|
||||
// (a + b*2^64)*(c + d*2^64) = ac + (bc + ad)*2^64 + bd*2^128
|
||||
// with the assumption that b and d are very small
|
||||
// writing bc + ad = e + f*2^64 where again f should be small if b and d are
|
||||
// we have that the product is
|
||||
// ac + e*2^64 + (bd + f)*2^128
|
||||
// where bd + f is small
|
||||
// let 2^128 = r modulo the custom modulus
|
||||
// so that the product is ac + e*2^64 + (bd + f)*r
|
||||
// this can be computed without wrap around modulo 2^128 if
|
||||
// (bd + f)*r < 2^128 otherwise there is an error
|
||||
// there should be no error if the modulus is not close to 2^128 or is equal to
|
||||
// 2^128 -r with r not too close to 2^128
|
||||
let self_top = self_u128 >> 64;
|
||||
let other_top = other_u128 >> 64;
|
||||
let self_bottom = self_u128 - (self_top << 64);
|
||||
let other_bottom = other_u128 - (other_top << 64);
|
||||
let bottom = self_bottom.wrapping_mul(other_bottom);
|
||||
let middle1 = self_bottom.wrapping_mul(other_top);
|
||||
let middle2 = other_bottom.wrapping_mul(self_top);
|
||||
let (middle, wrong) = middle1.overflowing_add(middle2);
|
||||
assert!(
|
||||
!wrong,
|
||||
"multiplication of custom u128s failed: {:?}, {:?}",
|
||||
self_u128, other_u128
|
||||
);
|
||||
let middle_top = middle >> 64;
|
||||
let middle_bottom = middle - (middle_top << 64);
|
||||
let middle = (middle_bottom << 64).wrapping_rem(custom_modulus_u128);
|
||||
let rem = 0u128
|
||||
.wrapping_sub(1u128)
|
||||
.wrapping_rem(custom_modulus_u128)
|
||||
.wrapping_add(1u128);
|
||||
let top = self_top.wrapping_mul(other_top);
|
||||
let (top, wrong) = top.overflowing_add(middle_top);
|
||||
assert!(
|
||||
!wrong,
|
||||
"multiplication of custom u128s failed: {:?}, {:?}",
|
||||
self_u128, other_u128
|
||||
);
|
||||
let (top, wrong) = top.overflowing_mul(rem);
|
||||
assert!(
|
||||
!wrong,
|
||||
"multiplication of custom u128s failed: {:?}, {:?}",
|
||||
self_u128, other_u128
|
||||
);
|
||||
let top = top.wrapping_rem(custom_modulus_u128);
|
||||
let out = top.wrapping_add(middle).wrapping_rem(custom_modulus_u128);
|
||||
out.wrapping_add(bottom)
|
||||
.wrapping_rem(custom_modulus_u128)
|
||||
.cast_into()
|
||||
} else {
|
||||
prod.wrapping_rem(custom_modulus_u128).cast_into()
|
||||
}
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_rem(self, other: Self) -> Self {
|
||||
self.wrapping_rem(other)
|
||||
}
|
||||
@@ -129,6 +224,12 @@ macro_rules! implement {
|
||||
self.wrapping_neg()
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_neg_custom_mod(self, custom_modulus: Self) -> Self {
|
||||
custom_modulus
|
||||
.wrapping_sub_custom_mod(self, custom_modulus)
|
||||
.wrapping_rem(custom_modulus)
|
||||
}
|
||||
#[inline]
|
||||
fn wrapping_shl(self, rhs: u32) -> Self {
|
||||
self.wrapping_shl(rhs)
|
||||
}
|
||||
@@ -205,4 +306,72 @@ mod test {
|
||||
.to_string()
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrapping_add_custom_mod() {
|
||||
let a = u64::MAX;
|
||||
let b = u64::MAX;
|
||||
let custom_modulus_u128 = (1u128 << 64) - (1 << 32) + 1;
|
||||
let custom_modulus = custom_modulus_u128 as u64;
|
||||
|
||||
let a_u128: u128 = a.into();
|
||||
let b_u128: u128 = b.into();
|
||||
|
||||
let expected_res = ((a_u128 + b_u128) % custom_modulus_u128) as u64;
|
||||
|
||||
let res = a.wrapping_add_custom_mod(b, custom_modulus);
|
||||
assert_eq!(expected_res, res);
|
||||
|
||||
const NB_REPS: usize = 100_000_000;
|
||||
|
||||
use rand::Rng;
|
||||
let mut thread_rng = rand::thread_rng();
|
||||
for _ in 0..NB_REPS {
|
||||
let a: u64 = thread_rng.gen();
|
||||
let b: u64 = thread_rng.gen();
|
||||
|
||||
let a_u128: u128 = a.into();
|
||||
let b_u128: u128 = b.into();
|
||||
|
||||
let expected_res = ((a_u128 + b_u128) % custom_modulus_u128) as u64;
|
||||
|
||||
let res = a.wrapping_add_custom_mod(b, custom_modulus);
|
||||
assert_eq!(expected_res, res, "a: {a}, b: {b}");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrapping_sub_custom_mod() {
|
||||
let custom_modulus_u128 = (1u128 << 64) - (1 << 32) + 1;
|
||||
let custom_modulus = custom_modulus_u128 as u64;
|
||||
|
||||
let a = 0u64;
|
||||
let b = u64::MAX % custom_modulus;
|
||||
|
||||
let a_u128: u128 = a.into();
|
||||
let b_u128: u128 = b.into();
|
||||
|
||||
let expected_res = ((a_u128 + custom_modulus_u128 - b_u128) % custom_modulus_u128) as u64;
|
||||
|
||||
let res = a.wrapping_sub_custom_mod(b, custom_modulus);
|
||||
assert_eq!(expected_res, res);
|
||||
|
||||
const NB_REPS: usize = 100_000_000;
|
||||
|
||||
use rand::Rng;
|
||||
let mut thread_rng = rand::thread_rng();
|
||||
for _ in 0..NB_REPS {
|
||||
let a: u64 = thread_rng.gen();
|
||||
let b: u64 = thread_rng.gen();
|
||||
|
||||
let a_u128: u128 = a.into();
|
||||
let b_u128: u128 = b.into();
|
||||
|
||||
let expected_res =
|
||||
((a_u128 + custom_modulus_u128 - b_u128) % custom_modulus_u128) as u64;
|
||||
|
||||
let res = a.wrapping_sub_custom_mod(b, custom_modulus);
|
||||
assert_eq!(expected_res, res, "a: {a}, b: {b}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,7 +248,7 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweKeyswitchKey<C>
|
||||
|
||||
/// Return a view of the [`LweKeyswitchKey`]. This is useful if an algorithm takes a view by
|
||||
/// value.
|
||||
pub fn as_view(&self) -> LweKeyswitchKey<&'_ [Scalar]> {
|
||||
pub fn as_view(&self) -> LweKeyswitchKeyView<'_, Scalar> {
|
||||
LweKeyswitchKey::from_container(
|
||||
self.as_ref(),
|
||||
self.decomp_base_log,
|
||||
@@ -280,7 +280,7 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweKeyswitchKey<C>
|
||||
|
||||
impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> LweKeyswitchKey<C> {
|
||||
/// Mutable variant of [`LweKeyswitchKey::as_view`].
|
||||
pub fn as_mut_view(&mut self) -> LweKeyswitchKey<&'_ mut [Scalar]> {
|
||||
pub fn as_mut_view(&mut self) -> LweKeyswitchKeyMutView<'_, Scalar> {
|
||||
let decomp_base_log = self.decomp_base_log;
|
||||
let decomp_level_count = self.decomp_level_count;
|
||||
let output_lwe_size = self.output_lwe_size;
|
||||
@@ -303,6 +303,8 @@ impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> LweKeyswitchKey
|
||||
|
||||
/// An [`LweKeyswitchKey`] owning the memory for its own storage.
|
||||
pub type LweKeyswitchKeyOwned<Scalar> = LweKeyswitchKey<Vec<Scalar>>;
|
||||
pub type LweKeyswitchKeyView<'a, Scalar> = LweKeyswitchKey<&'a [Scalar]>;
|
||||
pub type LweKeyswitchKeyMutView<'a, Scalar> = LweKeyswitchKey<&'a mut [Scalar]>;
|
||||
|
||||
impl<Scalar: UnsignedInteger> LweKeyswitchKeyOwned<Scalar> {
|
||||
/// Allocate memory and create a new owned [`LweKeyswitchKey`].
|
||||
|
||||
@@ -88,10 +88,26 @@ impl<Scalar, C: Container<Element = Scalar>> LweSecretKey<C> {
|
||||
pub fn into_container(self) -> C {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub fn as_view(&self) -> LweSecretKeyView<'_, Scalar> {
|
||||
LweSecretKey {
|
||||
data: self.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Scalar, C: ContainerMut<Element = Scalar>> LweSecretKey<C> {
|
||||
pub fn as_mut_view(&mut self) -> LweSecretKeyMutView<'_, Scalar> {
|
||||
LweSecretKey {
|
||||
data: self.as_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// An [`LweSecretKey`] owning the memory for its own storage.
|
||||
pub type LweSecretKeyOwned<Scalar> = LweSecretKey<Vec<Scalar>>;
|
||||
pub type LweSecretKeyView<'a, Scalar> = LweSecretKey<&'a [Scalar]>;
|
||||
pub type LweSecretKeyMutView<'a, Scalar> = LweSecretKey<&'a mut [Scalar]>;
|
||||
|
||||
impl<Scalar> LweSecretKeyOwned<Scalar>
|
||||
where
|
||||
|
||||
@@ -265,6 +265,7 @@ where
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
let ciphertext_modulus = lut.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
let monomial_degree = pbs_modulus_switch(
|
||||
*lwe_body,
|
||||
lut_poly_size,
|
||||
|
||||
@@ -187,6 +187,7 @@ where
|
||||
stack: PodStack<'_>,
|
||||
) {
|
||||
let align = CACHELINE_ALIGN;
|
||||
let ciphertext_modulus = accumulator.ciphertext_modulus();
|
||||
|
||||
let (mut local_accumulator_lo, stack) =
|
||||
stack.collect_aligned(align, accumulator.as_ref().iter().map(|i| *i as u64));
|
||||
@@ -229,7 +230,7 @@ where
|
||||
accumulator.ciphertext_modulus(),
|
||||
);
|
||||
|
||||
let ciphertext_modulus = local_accumulator.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
if !ciphertext_modulus.is_native_modulus() {
|
||||
// When we convert back from the fourier domain, integer values will contain up to
|
||||
// about 100 MSBs with information. In our representation of power of 2
|
||||
|
||||
@@ -226,6 +226,7 @@ impl<'a> FourierLweBootstrapKeyView<'a> {
|
||||
|
||||
let lut_poly_size = lut.polynomial_size();
|
||||
let ciphertext_modulus = lut.ciphertext_modulus();
|
||||
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
|
||||
let monomial_degree = pbs_modulus_switch(
|
||||
*lwe_body,
|
||||
lut_poly_size,
|
||||
|
||||
@@ -3,9 +3,7 @@
|
||||
//! The TFHE-rs preludes include convenient imports.
|
||||
//! Having `tfhe::core_crypto::prelude::*;` should be enough to start using the lib.
|
||||
|
||||
pub use super::algorithms::{
|
||||
add_external_product_assign, polynomial_algorithms, slice_algorithms, *,
|
||||
};
|
||||
pub use super::algorithms::{polynomial_algorithms, slice_algorithms, *};
|
||||
pub use super::commons::computation_buffers::ComputationBuffers;
|
||||
pub use super::commons::dispersion::*;
|
||||
pub use super::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator};
|
||||
|
||||
Reference in New Issue
Block a user