feat(core): Support LWE primitives for prime Q

This commit is contained in:
Arthur Meyre
2023-05-15 16:35:58 +02:00
parent ef55a9e076
commit dd7489692f
18 changed files with 1017 additions and 113 deletions

View File

@@ -3,6 +3,7 @@
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::dispersion::DispersionParameter;
use crate::core_crypto::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator};
use crate::core_crypto::commons::math::random::{ActivatedRandomGenerator, RandomGenerator};
@@ -36,6 +37,55 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
let ciphertext_modulus = output_mask.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
fill_lwe_mask_and_body_for_encryption_native_mod_compatible(
lwe_secret_key,
output_mask,
output_body,
encoded,
noise_parameters,
generator,
)
} else {
fill_lwe_mask_and_body_for_encryption_non_native_mod(
lwe_secret_key,
output_mask,
output_body,
encoded,
noise_parameters,
generator,
)
}
}
pub fn fill_lwe_mask_and_body_for_encryption_native_mod_compatible<
Scalar,
KeyCont,
OutputCont,
Gen,
>(
lwe_secret_key: &LweSecretKey<KeyCont>,
output_mask: &mut LweMask<OutputCont>,
output_body: LweBodyRefMut<Scalar>,
encoded: Plaintext<Scalar>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
KeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert_eq!(
output_mask.ciphertext_modulus(),
output_body.ciphertext_modulus(),
"Mismatched moduli between mask ({:?}) and body ({:?})",
output_mask.ciphertext_modulus(),
output_body.ciphertext_modulus()
);
let ciphertext_modulus = output_mask.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
@@ -57,6 +107,51 @@ pub fn fill_lwe_mask_and_body_for_encryption<Scalar, KeyCont, OutputCont, Gen>(
));
}
pub fn fill_lwe_mask_and_body_for_encryption_non_native_mod<Scalar, KeyCont, OutputCont, Gen>(
lwe_secret_key: &LweSecretKey<KeyCont>,
output_mask: &mut LweMask<OutputCont>,
output_body: LweBodyRefMut<Scalar>,
encoded: Plaintext<Scalar>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
KeyCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert_eq!(
output_mask.ciphertext_modulus(),
output_body.ciphertext_modulus(),
"Mismatched moduli between mask ({:?}) and body ({:?})",
output_mask.ciphertext_modulus(),
output_body.ciphertext_modulus()
);
let ciphertext_modulus = output_mask.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
generator.fill_slice_with_random_mask_custom_mod(output_mask.as_mut(), ciphertext_modulus);
// generate an error from the normal distribution described by std_dev
*output_body.data = generator.random_noise_custom_mod(noise_parameters, ciphertext_modulus);
*output_body.data = (*output_body.data).wrapping_add_custom_mod(
encoded.0,
ciphertext_modulus.get_custom_modulus().cast_into(),
);
// compute the multisum between the secret key and the mask
*output_body.data = (*output_body.data).wrapping_add_custom_mod(
slice_wrapping_dot_product_custom_mod(
output_mask.as_ref(),
lwe_secret_key.as_ref(),
ciphertext_modulus.get_custom_modulus().cast_into(),
),
ciphertext_modulus.get_custom_modulus().cast_into(),
);
}
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`).
///
/// See the [`LWE ciphertext formal definition`](`LweCiphertext#lwe-encryption`) for the definition
@@ -298,10 +393,10 @@ pub fn trivially_encrypt_lwe_ciphertext<Scalar, OutputCont>(
*output_body.data = encoded.0;
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if !ciphertext_modulus.is_native_modulus() {
// Manage non native power of 2 encoding
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
*output_body.data = (*output_body.data)
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus())
}
}
@@ -377,8 +472,8 @@ where
*output_body.data = encoded.0;
let ciphertext_modulus = output_body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if !ciphertext_modulus.is_native_modulus() {
// Manage the non native power of 2 encoding
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
*output_body.data = (*output_body.data)
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
@@ -398,6 +493,24 @@ pub fn decrypt_lwe_ciphertext<Scalar, KeyCont, InputCont>(
lwe_secret_key: &LweSecretKey<KeyCont>,
lwe_ciphertext: &LweCiphertext<InputCont>,
) -> Plaintext<Scalar>
where
Scalar: UnsignedInteger,
KeyCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
{
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
decrypt_lwe_ciphertext_native_mod_compatible(lwe_secret_key, lwe_ciphertext)
} else {
decrypt_lwe_ciphertext_non_native_mod(lwe_secret_key, lwe_ciphertext)
}
}
pub fn decrypt_lwe_ciphertext_native_mod_compatible<Scalar, KeyCont, InputCont>(
lwe_secret_key: &LweSecretKey<KeyCont>,
lwe_ciphertext: &LweCiphertext<InputCont>,
) -> Plaintext<Scalar>
where
Scalar: UnsignedInteger,
KeyCont: Container<Element = Scalar>,
@@ -434,6 +547,39 @@ where
}
}
pub fn decrypt_lwe_ciphertext_non_native_mod<Scalar, KeyCont, InputCont>(
lwe_secret_key: &LweSecretKey<KeyCont>,
lwe_ciphertext: &LweCiphertext<InputCont>,
) -> Plaintext<Scalar>
where
Scalar: UnsignedInteger,
KeyCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
{
assert!(
lwe_ciphertext.lwe_size().to_lwe_dimension() == lwe_secret_key.lwe_dimension(),
"Mismatch between LweDimension of output ciphertext and input secret key. \
Got {:?} in output, and {:?} in secret key.",
lwe_ciphertext.lwe_size().to_lwe_dimension(),
lwe_secret_key.lwe_dimension()
);
let ciphertext_modulus = lwe_ciphertext.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
let (mask, body) = lwe_ciphertext.get_mask_and_body();
Plaintext((*body.data).wrapping_sub_custom_mod(
slice_wrapping_dot_product_custom_mod(
mask.as_ref(),
lwe_secret_key.as_ref(),
ciphertext_modulus.get_custom_modulus().cast_into(),
),
ciphertext_modulus.get_custom_modulus().cast_into(),
))
}
/// Encrypt an input plaintext list in an output [`LWE ciphertext list`](`LweCiphertextList`).
///
/// See this [`formal definition`](`encrypt_lwe_ciphertext#formal-definition`) for the definition
@@ -781,10 +927,6 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
lwe_public_key.lwe_size().to_lwe_dimension()
);
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
output.as_mut().fill(Scalar::ZERO);
let mut tmp_zero_encryption =
@@ -806,17 +948,7 @@ pub fn encrypt_lwe_ciphertext_with_public_key<Scalar, KeyCont, OutputCont, Gen>(
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
}
let body = output.get_mut_body();
if ciphertext_modulus.is_native_modulus() {
*body.data = (*body.data).wrapping_add(encoded.0);
} else {
*body.data = (*body.data).wrapping_add(
encoded
.0
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
lwe_ciphertext_plaintext_add_assign(output, encoded);
}
/// Encrypt an input plaintext in an output [`LWE ciphertext`](`LweCiphertext`) using a
@@ -920,8 +1052,6 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
let ciphertext_modulus = output.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let mut tmp_zero_encryption =
LweCiphertext::new(Scalar::ZERO, lwe_public_key.lwe_size(), ciphertext_modulus);
@@ -933,7 +1063,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
let (mut mask, body) = tmp_zero_encryption.get_mut_mask_and_body();
random_generator
.fill_slice_with_random_uniform_custom_mod(mask.as_mut(), ciphertext_modulus);
if !ciphertext_modulus.is_native_modulus() {
if ciphertext_modulus.is_non_native_power_of_two() {
slice_wrapping_scalar_mul_assign(
mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),
@@ -947,17 +1077,7 @@ pub fn encrypt_lwe_ciphertext_with_seeded_public_key<Scalar, KeyCont, OutputCont
lwe_ciphertext_add_assign(output, &tmp_zero_encryption);
}
// Add encoded plaintext
let body = output.get_mut_body();
if ciphertext_modulus.is_native_modulus() {
*body.data = (*body.data).wrapping_add(encoded.0);
} else {
*body.data = (*body.data).wrapping_add(
encoded
.0
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
}
lwe_ciphertext_plaintext_add_assign(output, encoded);
}
/// Convenience function to share the core logic of the seeded LWE encryption between all functions

View File

@@ -2,7 +2,9 @@
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::math::decomposition::SignedDecomposer;
use crate::core_crypto::commons::math::decomposition::{
SignedDecomposer, SignedDecomposerNonNative,
};
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
@@ -99,6 +101,34 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
if lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus()
{
keyswitch_lwe_ciphertext_native_mod_compatible(
lwe_keyswitch_key,
input_lwe_ciphertext,
output_lwe_ciphertext,
)
} else {
keyswitch_lwe_ciphertext_non_native_mod(
lwe_keyswitch_key,
input_lwe_ciphertext,
output_lwe_ciphertext,
)
}
}
pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
input_lwe_ciphertext: &LweCiphertext<InputCont>,
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
Scalar: UnsignedInteger,
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension()
@@ -116,6 +146,25 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus(),
);
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus(),
);
assert!(lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus());
// Clear the output ciphertext, as it will get updated gradually
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
@@ -145,3 +194,79 @@ pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
}
}
}
pub fn keyswitch_lwe_ciphertext_non_native_mod<Scalar, KSKCont, InputCont, OutputCont>(
lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
input_lwe_ciphertext: &LweCiphertext<InputCont>,
output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
) where
Scalar: UnsignedInteger,
KSKCont: Container<Element = Scalar>,
InputCont: Container<Element = Scalar>,
OutputCont: ContainerMut<Element = Scalar>,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension()
== input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched input LweDimension. \
LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension()
== output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
"Mismatched output LweDimension. \
LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
lwe_keyswitch_key.output_key_lwe_dimension(),
output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
);
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
input_lwe_ciphertext.ciphertext_modulus(),
);
assert_eq!(
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus(),
"Mismatched CiphertextModulus. \
LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
lwe_keyswitch_key.ciphertext_modulus(),
output_lwe_ciphertext.ciphertext_modulus(),
);
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
// Clear the output ciphertext, as it will get updated gradually
output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
// Copy the input body to the output ciphertext
*output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
// We instantiate a decomposer
let decomposer = SignedDecomposerNonNative::new(
lwe_keyswitch_key.decomposition_base_log(),
lwe_keyswitch_key.decomposition_level_count(),
ciphertext_modulus,
);
for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
.iter()
.zip(input_lwe_ciphertext.get_mask().as_ref())
{
let decomposition_iter = decomposer.decompose(input_mask_element);
// Loop over the levels
for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
{
slice_wrapping_sub_scalar_mul_assign_custom_modulus(
output_lwe_ciphertext.as_mut(),
level_key_ciphertext.as_ref(),
decomposed.value(),
ciphertext_modulus.get_custom_modulus().cast_into(),
);
}
}
}

View File

@@ -5,7 +5,9 @@
use crate::core_crypto::algorithms::*;
use crate::core_crypto::commons::dispersion::DispersionParameter;
use crate::core_crypto::commons::generators::EncryptionRandomGenerator;
use crate::core_crypto::commons::math::decomposition::{DecompositionLevel, DecompositionTerm};
use crate::core_crypto::commons::math::decomposition::{
DecompositionLevel, DecompositionTerm, DecompositionTermNonNative,
};
use crate::core_crypto::commons::math::random::ActivatedRandomGenerator;
use crate::core_crypto::commons::parameters::*;
use crate::core_crypto::commons::traits::*;
@@ -75,60 +77,217 @@ pub fn generate_lwe_keyswitch_key<Scalar, InputKeyCont, OutputKeyCont, KSKeyCont
KSKeyCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey input LweDimension is not equal \
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey output LweDimension is not equal \
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
lwe_keyswitch_key.output_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
.as_ref()
.iter()
.zip(lwe_keyswitch_key.iter_mut())
{
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
// Here we take the decomposition term from the native torus, bring it to the torus we
// are working with by dividing by the scaling factor and the encryption will take care
// of mapping that back to the native torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
encrypt_lwe_ciphertext_list(
if ciphertext_modulus.is_compatible_with_native_modulus() {
generate_lwe_keyswitch_key_native_mod_compatible(
input_lwe_sk,
output_lwe_sk,
&mut keyswitch_key_block,
&decomposition_plaintexts_buffer,
lwe_keyswitch_key,
noise_parameters,
generator,
);
)
} else {
generate_lwe_keyswitch_key_non_native_mod(
input_lwe_sk,
output_lwe_sk,
lwe_keyswitch_key,
noise_parameters,
generator,
)
}
}
pub fn generate_lwe_keyswitch_key_native_mod_compatible<
Scalar,
InputKeyCont,
OutputKeyCont,
KSKeyCont,
Gen,
>(
input_lwe_sk: &LweSecretKey<InputKeyCont>,
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
lwe_keyswitch_key: &mut LweKeyswitchKey<KSKeyCont>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
fn implementation<Scalar, Gen>(
input_lwe_sk: LweSecretKeyView<'_, Scalar>,
output_lwe_sk: LweSecretKeyView<'_, Scalar>,
mut lwe_keyswitch_key: LweKeyswitchKeyMutView<'_, Scalar>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
Gen: ByteRandomGenerator,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey input LweDimension is not equal \
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey output LweDimension is not equal \
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
lwe_keyswitch_key.output_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus());
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
.as_ref()
.iter()
.zip(lwe_keyswitch_key.iter_mut())
{
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
// Here we take the decomposition term from the native torus, bring it to the torus
// we are working with by dividing by the scaling factor and the
// encryption will take care of mapping that back to the native
// torus
*message.0 = DecompositionTerm::new(level, decomp_base_log, *input_key_element)
.to_recomposition_summand()
.wrapping_div(ciphertext_modulus.get_power_of_two_scaling_to_native_torus());
}
encrypt_lwe_ciphertext_list(
&output_lwe_sk,
&mut keyswitch_key_block,
&decomposition_plaintexts_buffer,
noise_parameters,
generator,
);
}
}
implementation(
input_lwe_sk.as_view(),
output_lwe_sk.as_view(),
lwe_keyswitch_key.as_mut_view(),
noise_parameters,
generator,
)
}
pub fn generate_lwe_keyswitch_key_non_native_mod<
Scalar,
InputKeyCont,
OutputKeyCont,
KSKeyCont,
Gen,
>(
input_lwe_sk: &LweSecretKey<InputKeyCont>,
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
lwe_keyswitch_key: &mut LweKeyswitchKey<KSKeyCont>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
Gen: ByteRandomGenerator,
{
fn implementation<Scalar, Gen>(
input_lwe_sk: LweSecretKeyView<'_, Scalar>,
output_lwe_sk: LweSecretKeyView<'_, Scalar>,
mut lwe_keyswitch_key: LweKeyswitchKeyMutView<'_, Scalar>,
noise_parameters: impl DispersionParameter,
generator: &mut EncryptionRandomGenerator<Gen>,
) where
Scalar: UnsignedTorus,
Gen: ByteRandomGenerator,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey input LweDimension is not equal \
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
"The destination LweKeyswitchKey output LweDimension is not equal \
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
lwe_keyswitch_key.output_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(!lwe_keyswitch_key
.ciphertext_modulus()
.is_compatible_with_native_modulus());
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
.as_ref()
.iter()
.zip(lwe_keyswitch_key.iter_mut())
{
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
*message.0 = DecompositionTermNonNative::new(
level,
decomp_base_log,
*input_key_element,
ciphertext_modulus,
)
.to_recomposition_summand();
}
encrypt_lwe_ciphertext_list(
&output_lwe_sk,
&mut keyswitch_key_block,
&decomposition_plaintexts_buffer,
noise_parameters,
generator,
);
}
}
implementation(
input_lwe_sk.as_view(),
output_lwe_sk.as_view(),
lwe_keyswitch_key.as_mut_view(),
noise_parameters,
generator,
)
}
/// Allocate a new [`LWE keyswitch key`](`LweKeyswitchKey`) and fill it with an actual keyswitching
/// key constructed from an input and an output key [`LWE secret key`](`LweSecretKey`).
///
@@ -237,6 +396,46 @@ pub fn generate_seeded_lwe_keyswitch_key<
KSKeyCont: ContainerMut<Element = Scalar>,
// Maybe Sized allows to pass Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
generate_seeded_lwe_keyswitch_key_native_mod_compatible(
input_lwe_sk,
output_lwe_sk,
lwe_keyswitch_key,
noise_parameters,
noise_seeder,
)
} else {
generate_seeded_lwe_keyswitch_key_non_native_mod(
input_lwe_sk,
output_lwe_sk,
lwe_keyswitch_key,
noise_parameters,
noise_seeder,
)
}
}
pub fn generate_seeded_lwe_keyswitch_key_native_mod_compatible<
Scalar,
InputKeyCont,
OutputKeyCont,
KSKeyCont,
NoiseSeeder,
>(
input_lwe_sk: &LweSecretKey<InputKeyCont>,
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
lwe_keyswitch_key: &mut SeededLweKeyswitchKey<KSKeyCont>,
noise_parameters: impl DispersionParameter,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
// Maybe Sized allows to pass Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
@@ -297,6 +496,86 @@ pub fn generate_seeded_lwe_keyswitch_key<
}
}
pub fn generate_seeded_lwe_keyswitch_key_non_native_mod<
Scalar,
InputKeyCont,
OutputKeyCont,
KSKeyCont,
NoiseSeeder,
>(
input_lwe_sk: &LweSecretKey<InputKeyCont>,
output_lwe_sk: &LweSecretKey<OutputKeyCont>,
lwe_keyswitch_key: &mut SeededLweKeyswitchKey<KSKeyCont>,
noise_parameters: impl DispersionParameter,
noise_seeder: &mut NoiseSeeder,
) where
Scalar: UnsignedTorus,
InputKeyCont: Container<Element = Scalar>,
OutputKeyCont: Container<Element = Scalar>,
KSKeyCont: ContainerMut<Element = Scalar>,
// Maybe Sized allows to pass Box<dyn Seeder>.
NoiseSeeder: Seeder + ?Sized,
{
assert!(
lwe_keyswitch_key.input_key_lwe_dimension() == input_lwe_sk.lwe_dimension(),
"The destination SeededLweKeyswitchKey input LweDimension is not equal \
to the input LweSecretKey LweDimension. Destination: {:?}, input: {:?}",
lwe_keyswitch_key.input_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
assert!(
lwe_keyswitch_key.output_key_lwe_dimension() == output_lwe_sk.lwe_dimension(),
"The destination SeededLweKeyswitchKey output LweDimension is not equal \
to the output LweSecretKey LweDimension. Destination: {:?}, output: {:?}",
lwe_keyswitch_key.output_key_lwe_dimension(),
input_lwe_sk.lwe_dimension()
);
let decomp_base_log = lwe_keyswitch_key.decomposition_base_log();
let decomp_level_count = lwe_keyswitch_key.decomposition_level_count();
let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
// The plaintexts used to encrypt a key element will be stored in this buffer
let mut decomposition_plaintexts_buffer =
PlaintextListOwned::new(Scalar::ZERO, PlaintextCount(decomp_level_count.0));
let mut generator = EncryptionRandomGenerator::<ActivatedRandomGenerator>::new(
lwe_keyswitch_key.compression_seed().seed,
noise_seeder,
);
// Iterate over the input key elements and the destination lwe_keyswitch_key memory
for (input_key_element, mut keyswitch_key_block) in input_lwe_sk
.as_ref()
.iter()
.zip(lwe_keyswitch_key.iter_mut())
{
// We fill the buffer with the powers of the key elmements
for (level, message) in (1..=decomp_level_count.0)
.rev()
.map(DecompositionLevel)
.zip(decomposition_plaintexts_buffer.iter_mut())
{
*message.0 = DecompositionTermNonNative::new(
level,
decomp_base_log,
*input_key_element,
ciphertext_modulus,
)
.to_recomposition_summand();
}
encrypt_seeded_lwe_ciphertext_list_with_existing_generator(
output_lwe_sk,
&mut keyswitch_key_block,
&decomposition_plaintexts_buffer,
noise_parameters,
&mut generator,
);
}
}
/// Allocate a new [`seeded LWE keyswitch key`](`SeededLweKeyswitchKey`) and fill it with an actual
/// keyswitching key constructed from an input and an output key
/// [`LWE secret key`](`LweSecretKey`).

View File

@@ -2,6 +2,7 @@
//! like addition, multiplication, etc.
use crate::core_crypto::algorithms::slice_algorithms::*;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::numeric::UnsignedInteger;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
@@ -71,6 +72,22 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
Scalar: UnsignedInteger,
LhsCont: ContainerMut<Element = Scalar>,
RhsCont: Container<Element = Scalar>,
{
let ciphertext_modulus = rhs.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
lwe_ciphertext_add_assign_native_mod_compatible(lhs, rhs)
} else {
lwe_ciphertext_add_assign_non_native_mod(lhs, rhs)
}
}
pub fn lwe_ciphertext_add_assign_native_mod_compatible<Scalar, LhsCont, RhsCont>(
lhs: &mut LweCiphertext<LhsCont>,
rhs: &LweCiphertext<RhsCont>,
) where
Scalar: UnsignedInteger,
LhsCont: ContainerMut<Element = Scalar>,
RhsCont: Container<Element = Scalar>,
{
assert_eq!(
lhs.ciphertext_modulus(),
@@ -79,10 +96,37 @@ pub fn lwe_ciphertext_add_assign<Scalar, LhsCont, RhsCont>(
lhs.ciphertext_modulus(),
rhs.ciphertext_modulus()
);
let ciphertext_modulus = rhs.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
slice_wrapping_add_assign(lhs.as_mut(), rhs.as_ref());
}
pub fn lwe_ciphertext_add_assign_non_native_mod<Scalar, LhsCont, RhsCont>(
lhs: &mut LweCiphertext<LhsCont>,
rhs: &LweCiphertext<RhsCont>,
) where
Scalar: UnsignedInteger,
LhsCont: ContainerMut<Element = Scalar>,
RhsCont: Container<Element = Scalar>,
{
assert_eq!(
lhs.ciphertext_modulus(),
rhs.ciphertext_modulus(),
"Mismatched moduli between lhs ({:?}) and rhs ({:?}) LweCiphertext",
lhs.ciphertext_modulus(),
rhs.ciphertext_modulus()
);
let ciphertext_modulus = rhs.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
slice_wrapping_add_assign_custom_mod(
lhs.as_mut(),
rhs.as_ref(),
ciphertext_modulus.get_custom_modulus().cast_into(),
);
}
/// Add the right-hand side [`LWE ciphertext`](`LweCiphertext`) to the left-hand side [`LWE
/// ciphertext`](`LweCiphertext`) writing the result in the output [`LWE
/// ciphertext`](`LweCiphertext`).
@@ -235,20 +279,51 @@ pub fn lwe_ciphertext_plaintext_add_assign<Scalar, InCont>(
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let ciphertext_modulus = lhs.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
lwe_ciphertext_plaintext_add_assign_native_mod_compatible(lhs, rhs)
} else {
lwe_ciphertext_plaintext_add_assign_non_native_mod(lhs, rhs)
}
}
pub fn lwe_ciphertext_plaintext_add_assign_native_mod_compatible<Scalar, InCont>(
lhs: &mut LweCiphertext<InCont>,
rhs: Plaintext<Scalar>,
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
if ciphertext_modulus.is_native_modulus() {
*body.data = (*body.data).wrapping_add(rhs.0);
} else {
*body.data = (*body.data).wrapping_add(
rhs.0
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
);
match ciphertext_modulus.kind() {
CiphertextModulusKind::Native => *body.data = (*body.data).wrapping_add(rhs.0),
CiphertextModulusKind::NonNativePowerOfTwo => {
*body.data = (*body.data).wrapping_add(
rhs.0
.wrapping_mul(ciphertext_modulus.get_power_of_two_scaling_to_native_torus()),
)
}
CiphertextModulusKind::NonNative => unreachable!(),
}
}
pub fn lwe_ciphertext_plaintext_add_assign_non_native_mod<Scalar, InCont>(
lhs: &mut LweCiphertext<InCont>,
rhs: Plaintext<Scalar>,
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
*body.data = (*body.data)
.wrapping_add_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
}
/// Add the right-hand side encoded [`Plaintext`] to the left-hand side [`LWE
/// ciphertext`](`LweCiphertext`) updating it in-place.
///
@@ -311,6 +386,21 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let ciphertext_modulus = lhs.ciphertext_modulus();
if ciphertext_modulus.is_compatible_with_native_modulus() {
lwe_ciphertext_plaintext_sub_assign_native_mod_compatible(lhs, rhs)
} else {
lwe_ciphertext_plaintext_sub_assign_non_native_mod(lhs, rhs)
}
}
pub fn lwe_ciphertext_plaintext_sub_assign_native_mod_compatible<Scalar, InCont>(
lhs: &mut LweCiphertext<InCont>,
rhs: Plaintext<Scalar>,
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
@@ -325,6 +415,20 @@ pub fn lwe_ciphertext_plaintext_sub_assign<Scalar, InCont>(
}
}
pub fn lwe_ciphertext_plaintext_sub_assign_non_native_mod<Scalar, InCont>(
lhs: &mut LweCiphertext<InCont>,
rhs: Plaintext<Scalar>,
) where
Scalar: UnsignedInteger,
InCont: ContainerMut<Element = Scalar>,
{
let body = lhs.get_mut_body();
let ciphertext_modulus = body.ciphertext_modulus();
assert!(!ciphertext_modulus.is_compatible_with_native_modulus());
*body.data = (*body.data)
.wrapping_sub_custom_mod(rhs.0, ciphertext_modulus.get_custom_modulus().cast_into());
}
/// Compute the opposite of the input [`LWE ciphertext`](`LweCiphertext`) and update it in place.
///
/// # Example

View File

@@ -1,6 +1,7 @@
//! Module with primitives pertaining to [`SeededLweCiphertext`] decompression.
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
@@ -26,12 +27,12 @@ pub fn decompress_seeded_lwe_ciphertext_with_existing_generator<Scalar, OutputCo
);
let ciphertext_modulus = output_lwe.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
let (mut output_mask, output_body) = output_lwe.get_mut_mask_and_body();
// generate a uniformly random mask
generator.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
if !ciphertext_modulus.is_native_modulus() {
// Manage the specific encoding for non native power of 2
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),

View File

@@ -1,6 +1,7 @@
//! Module with primitives pertaining to [`SeededLweCiphertextList`] decompression.
use crate::core_crypto::algorithms::slice_algorithms::slice_wrapping_scalar_mul_assign;
use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::traits::*;
use crate::core_crypto::entities::*;
@@ -32,7 +33,6 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
);
let ciphertext_modulus = output_list.ciphertext_modulus();
assert!(ciphertext_modulus.is_compatible_with_native_modulus());
for (mut lwe_out, body_in) in output_list.iter_mut().zip(input_seeded_list.iter()) {
let (mut output_mask, output_body) = lwe_out.get_mut_mask_and_body();
@@ -40,7 +40,8 @@ pub fn decompress_seeded_lwe_ciphertext_list_with_existing_generator<
// generate a uniformly random mask
generator
.fill_slice_with_random_uniform_custom_mod(output_mask.as_mut(), ciphertext_modulus);
if !ciphertext_modulus.is_native_modulus() {
// Manage the non native power of 2 encoding
if let CiphertextModulusKind::NonNativePowerOfTwo = ciphertext_modulus.kind() {
slice_wrapping_scalar_mul_assign(
output_mask.as_mut(),
ciphertext_modulus.get_power_of_two_scaling_to_native_torus(),

View File

@@ -36,6 +36,31 @@ where
})
}
/// This primitive is meant to manage the dot product avoiding overflow on multiplication by casting
/// to u128, for example for u64, avoiding overflow on each multiplication (as u64::MAX * u64::MAX <
/// u128::MAX)
pub fn slice_wrapping_dot_product_custom_mod<Scalar>(
lhs: &[Scalar],
rhs: &[Scalar],
modulus: Scalar,
) -> Scalar
where
Scalar: UnsignedInteger,
{
assert!(
lhs.len() == rhs.len(),
"lhs (len: {}) and rhs (len: {}) must have the same length",
lhs.len(),
rhs.len()
);
lhs.iter()
.zip(rhs.iter())
.fold(Scalar::ZERO, |acc, (&left, &right)| {
acc.wrapping_add_custom_mod(left.wrapping_mul_custom_mod(right, modulus), modulus)
})
}
/// Add a slice containing unsigned integers to another one element-wise.
///
/// # Note
@@ -108,6 +133,25 @@ where
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add(rhs));
}
pub fn slice_wrapping_add_assign_custom_mod<Scalar>(
lhs: &mut [Scalar],
rhs: &[Scalar],
custom_modulus: Scalar,
) where
Scalar: UnsignedInteger,
{
assert!(
lhs.len() == rhs.len(),
"lhs (len: {}) and rhs (len: {}) must have the same length",
lhs.len(),
rhs.len()
);
lhs.iter_mut()
.zip(rhs.iter())
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_add_custom_mod(rhs, custom_modulus));
}
/// Add a slice containing unsigned integers to another one mutiplied by a scalar.
///
/// Let *a*,*b* be two slices, let *c* be a scalar, this computes: *a <- a+bc*
@@ -254,6 +298,28 @@ pub fn slice_wrapping_sub_scalar_mul_assign<Scalar>(
.for_each(|(lhs, &rhs)| *lhs = (*lhs).wrapping_sub(rhs.wrapping_mul(scalar)));
}
/// This primitive is meant to manage the sub_scalar_mul operation for values that were cast to a
/// bigger type, for example u64 to u128, avoiding overflow on each multiplication (as u64::MAX *
/// u64::MAX < u128::MAX )
pub fn slice_wrapping_sub_scalar_mul_assign_custom_modulus<Scalar>(
lhs: &mut [Scalar],
rhs: &[Scalar],
scalar: Scalar,
modulus: Scalar,
) where
Scalar: UnsignedInteger,
{
assert!(
lhs.len() == rhs.len(),
"lhs (len: {}) and rhs (len: {}) must have the same length",
lhs.len(),
rhs.len()
);
lhs.iter_mut().zip(rhs.iter()).for_each(|(lhs, &rhs)| {
*lhs = (*lhs).wrapping_sub_custom_mod(rhs.wrapping_mul_custom_mod(scalar, modulus), modulus)
});
}
/// Compute the opposite of a slice containing unsigned integers, element-wise and in place.
///
/// # Note

View File

@@ -133,6 +133,11 @@ fn test_parallel_and_seeded_lwe_list_encryption_equivalence_non_native_power_of_
test_parallel_and_seeded_lwe_list_encryption_equivalence(TEST_PARAMS_3_BITS_63_U64);
}
#[test]
fn test_parallel_and_seeded_lwe_list_encryption_equivalence_solinas_mod_u64() {
test_parallel_and_seeded_lwe_list_encryption_equivalence(TEST_PARAMS_3_BITS_SOLINAS_U64);
}
#[test]
fn test_parallel_and_seeded_lwe_list_encryption_equivalence_native_mod_u32() {
test_parallel_and_seeded_lwe_list_encryption_equivalence(DUMMY_NATIVE_U32);

View File

@@ -3,6 +3,7 @@ use super::*;
fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<Scalar>) {
let lwe_dimension = params.lwe_dimension;
let lwe_modular_std_dev = params.lwe_modular_std_dev;
let glwe_moduluar_std_dev = params.glwe_modular_std_dev;
let ciphertext_modulus = params.ciphertext_modulus;
let message_modulus_log = params.message_modulus_log;
let encoding_with_padding = get_encoding_with_padding(ciphertext_modulus);
@@ -21,6 +22,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<S
while msg != Scalar::ZERO {
msg = msg.wrapping_sub(Scalar::ONE);
for _ in 0..NB_TESTS {
println!("{msg}");
let lwe_sk = allocate_and_generate_new_binary_lwe_secret_key(
lwe_dimension,
&mut rsc.secret_random_generator,
@@ -54,7 +56,7 @@ fn lwe_encrypt_ks_decrypt_custom_mod<Scalar: UnsignedTorus>(params: TestParams<S
let ct = allocate_and_encrypt_new_lwe_ciphertext(
&big_lwe_sk,
plaintext,
lwe_modular_std_dev,
glwe_moduluar_std_dev,
ciphertext_modulus,
&mut rsc.encryption_random_generator,
);

View File

@@ -112,3 +112,10 @@ fn test_seeded_lwe_ksk_gen_equivalence_u32_custom_mod() {
fn test_seeded_lwe_ksk_gen_equivalence_u64_custom_mod() {
test_seeded_lwe_ksk_gen_equivalence::<u64>(CiphertextModulus::try_new_power_of_2(63).unwrap())
}
#[test]
fn test_seeded_lwe_ksk_gen_equivalence_u64_solinas_mod() {
test_seeded_lwe_ksk_gen_equivalence::<u64>(
CiphertextModulus::try_new((1 << 64) - (1 << 32) + 1).unwrap(),
)
}

View File

@@ -126,7 +126,10 @@ fn lwe_encrypt_pbs_decrypt_custom_mod<
}
}
create_parametrized_test!(lwe_encrypt_pbs_decrypt_custom_mod);
create_parametrized_test!(lwe_encrypt_pbs_decrypt_custom_mod {
TEST_PARAMS_4_BITS_NATIVE_U64,
TEST_PARAMS_3_BITS_63_U64
});
// DISCLAIMER: all parameters here are not guaranteed to be secure or yield correct computations
pub const TEST_PARAMS_4_BITS_NATIVE_U128: TestParams<u128> = TestParams {

View File

@@ -128,6 +128,25 @@ pub const DUMMY_31_U32: TestParams<u32> = TestParams {
ciphertext_modulus: unsafe { CiphertextModulus::new_unchecked(1 << 31) },
};
pub const TEST_PARAMS_3_BITS_SOLINAS_U64: TestParams<u64> = TestParams {
lwe_dimension: LweDimension(742),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_modular_std_dev: StandardDev(0.000007069849454709433),
glwe_modular_std_dev: StandardDev(0.00000000000000029403601535432533),
pbs_base_log: DecompositionBaseLog(23),
pbs_level: DecompositionLevelCount(1),
ks_level: DecompositionLevelCount(5),
ks_base_log: DecompositionBaseLog(3),
pfks_level: DecompositionLevelCount(1),
pfks_base_log: DecompositionBaseLog(23),
pfks_modular_std_dev: StandardDev(0.00000000000000029403601535432533),
cbs_level: DecompositionLevelCount(0),
cbs_base_log: DecompositionBaseLog(0),
message_modulus_log: CiphertextModulusLog(3),
ciphertext_modulus: unsafe { CiphertextModulus::new_unchecked((1 << 64) - (1 << 32) + 1) },
};
// Our representation of non native power of 2 moduli puts the information in the MSBs and leaves
// the LSBs empty, this is what this function is checking
pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]>>(
@@ -135,26 +154,42 @@ pub fn check_content_respects_mod<Scalar: UnsignedInteger, Input: AsRef<[Scalar]
modulus: CiphertextModulus<Scalar>,
) -> bool {
if !modulus.is_native_modulus() {
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
// we want the bits under the mask to be 0
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return input
.as_ref()
.iter()
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO);
// Power of two has a specific encoding in MSBs of the native torus to re-use native torus
// implementations for free
if modulus.is_power_of_two() {
// If our modulus is 2^60, the scaling is 2^4 = 00...00010000, minus 1 = 00...00001111
// we want the bits under the mask to be 0
let power_2_diff_mask =
modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return input
.as_ref()
.iter()
.all(|&x| (x & power_2_diff_mask) == Scalar::ZERO);
} else {
// Custom moduli non power of two use the "true" modulus representation
return input
.as_ref()
.iter()
.all(|&x| x < Scalar::cast_from(modulus.get_custom_modulus()));
}
}
true
}
// See above
// See above comments for modulus check logic
pub fn check_scalar_respects_mod<Scalar: UnsignedInteger>(
input: Scalar,
modulus: CiphertextModulus<Scalar>,
) -> bool {
if !modulus.is_native_modulus() {
let power_2_diff_mask = modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return (input & power_2_diff_mask) == Scalar::ZERO;
if modulus.is_power_of_two() {
let power_2_diff_mask =
modulus.get_power_of_two_scaling_to_native_torus() - Scalar::ONE;
return (input & power_2_diff_mask) == Scalar::ZERO;
} else {
return input < Scalar::cast_from(modulus.get_custom_modulus());
}
}
true
@@ -243,6 +278,7 @@ macro_rules! create_parametrized_test{
create_parametrized_test!($name
{
TEST_PARAMS_4_BITS_NATIVE_U64,
TEST_PARAMS_3_BITS_SOLINAS_U64,
TEST_PARAMS_3_BITS_63_U64
});
};

View File

@@ -21,6 +21,12 @@ pub struct CiphertextModulus<Scalar: UnsignedInteger> {
_scalar: PhantomData<Scalar>,
}
pub enum CiphertextModulusKind {
Native,
NonNativePowerOfTwo,
NonNative,
}
#[derive(serde::Serialize, serde::Deserialize)]
struct SerialiazableLweCiphertextModulus {
pub modulus: u128,
@@ -137,7 +143,11 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
}
}
};
Ok(res.canonicalize())
let canonicalized_result = res.canonicalize();
if Scalar::BITS > 64 && !canonicalized_result.is_compatible_with_native_modulus() {
return Err("Non power of 2 moduli are not supported for types wider than u64");
}
Ok(canonicalized_result)
}
}
@@ -171,6 +181,7 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
res.canonicalize()
}
#[track_caller]
pub fn get_power_of_two_scaling_to_native_torus(&self) -> Scalar {
match self.inner {
CiphertextModulusInner::Native => Scalar::ONE,
@@ -206,12 +217,32 @@ impl<Scalar: UnsignedInteger> CiphertextModulus<Scalar> {
self.is_native_modulus() || self.is_power_of_two()
}
pub const fn is_non_native_power_of_two(&self) -> bool {
match self.inner {
CiphertextModulusInner::Native => false,
CiphertextModulusInner::Custom(modulus) => modulus.is_power_of_two(),
}
}
pub const fn is_power_of_two(&self) -> bool {
match self.inner {
CiphertextModulusInner::Native => true,
CiphertextModulusInner::Custom(modulus) => modulus.is_power_of_two(),
}
}
pub const fn kind(&self) -> CiphertextModulusKind {
match self.inner {
CiphertextModulusInner::Native => CiphertextModulusKind::Native,
CiphertextModulusInner::Custom(modulus) => {
if modulus.is_power_of_two() {
CiphertextModulusKind::NonNativePowerOfTwo
} else {
CiphertextModulusKind::NonNative
}
}
}
}
}
impl<Scalar: UnsignedInteger> std::fmt::Display for CiphertextModulus<Scalar> {

View File

@@ -211,7 +211,8 @@ impl<G: ByteRandomGenerator> RandomGenerator<G> {
) where
Scalar: UnsignedInteger + RandomGenerable<Uniform>,
{
assert!(custom_modulus.is_compatible_with_native_modulus());
// TODO
// Implement the proper generation function for custom mod for the uniform distribution
self.fill_slice_with_random_uniform(output);
if !custom_modulus.is_native_modulus() {

View File

@@ -45,12 +45,21 @@ pub trait UnsignedInteger:
/// Compute a subtraction, modulo the max of the type.
#[must_use]
fn wrapping_sub(self, other: Self) -> Self;
/// Compute an addition, modulo a custom modulus.
#[must_use]
fn wrapping_add_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
/// Compute a subtraction, modulo a custom modulus.
#[must_use]
fn wrapping_sub_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
/// Compute a division, modulo the max of the type.
#[must_use]
fn wrapping_div(self, other: Self) -> Self;
/// Compute a multiplication, modulo the max of the type.
#[must_use]
fn wrapping_mul(self, other: Self) -> Self;
/// Compute a multiplication, modulo a custom modulus.
#[must_use]
fn wrapping_mul_custom_mod(self, other: Self, custom_modulus: Self) -> Self;
/// Compute the remainder, modulo the max of the type.
#[must_use]
fn wrapping_rem(self, other: Self) -> Self;
@@ -113,6 +122,26 @@ macro_rules! implement {
self.wrapping_sub(other)
}
#[inline]
fn wrapping_add_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
match self.overflowing_add(other) {
(result, true) => {
// We have (for u64) a result of the form 2^64 + p, here we compute p mod q
let result = result.wrapping_rem(custom_modulus);
// and here we compute 2^64 mod q and add to the result as mod is linear
let self_max_mod_custom = Self::MAX - custom_modulus + Self::ONE;
result.wrapping_add(self_max_mod_custom)
}
(result, false) => result.wrapping_rem(custom_modulus),
}
}
#[inline]
fn wrapping_sub_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
match self.overflowing_sub(other) {
(result, true) => result.wrapping_add(custom_modulus),
(result, false) => result.wrapping_rem(custom_modulus),
}
}
#[inline]
fn wrapping_div(self, other: Self) -> Self {
self.wrapping_div(other)
}
@@ -120,7 +149,17 @@ macro_rules! implement {
fn wrapping_mul(self, other: Self) -> Self {
self.wrapping_mul(other)
}
#[must_use]
#[inline]
fn wrapping_mul_custom_mod(self, other: Self, custom_modulus: Self) -> Self {
let self_u128: u128 = self.cast_into();
let other_u128: u128 = other.cast_into();
let custom_modulus_u128: u128 = custom_modulus.cast_into();
self_u128
.wrapping_mul(other_u128)
.wrapping_rem(custom_modulus_u128)
.cast_into()
}
#[inline]
fn wrapping_rem(self, other: Self) -> Self {
self.wrapping_rem(other)
}
@@ -205,4 +244,72 @@ mod test {
.to_string()
);
}
#[test]
fn test_wrapping_add_custom_mod() {
let a = u64::MAX;
let b = u64::MAX;
let custom_modulus_u128 = (1u128 << 64) - (1 << 32) + 1;
let custom_modulus = custom_modulus_u128 as u64;
let a_u128: u128 = a.into();
let b_u128: u128 = b.into();
let expected_res = ((a_u128 + b_u128) % custom_modulus_u128) as u64;
let res = a.wrapping_add_custom_mod(b, custom_modulus);
assert_eq!(expected_res, res);
const NB_REPS: usize = 100_000_000;
use rand::Rng;
let mut thread_rng = rand::thread_rng();
for _ in 0..NB_REPS {
let a: u64 = thread_rng.gen();
let b: u64 = thread_rng.gen();
let a_u128: u128 = a.into();
let b_u128: u128 = b.into();
let expected_res = ((a_u128 + b_u128) % custom_modulus_u128) as u64;
let res = a.wrapping_add_custom_mod(b, custom_modulus);
assert_eq!(expected_res, res, "a: {a}, b: {b}");
}
}
#[test]
fn test_wrapping_sub_custom_mod() {
let custom_modulus_u128 = (1u128 << 64) - (1 << 32) + 1;
let custom_modulus = custom_modulus_u128 as u64;
let a = 0u64;
let b = u64::MAX % custom_modulus;
let a_u128: u128 = a.into();
let b_u128: u128 = b.into();
let expected_res = ((a_u128 + custom_modulus_u128 - b_u128) % custom_modulus_u128) as u64;
let res = a.wrapping_sub_custom_mod(b, custom_modulus);
assert_eq!(expected_res, res);
const NB_REPS: usize = 100_000_000;
use rand::Rng;
let mut thread_rng = rand::thread_rng();
for _ in 0..NB_REPS {
let a: u64 = thread_rng.gen();
let b: u64 = thread_rng.gen();
let a_u128: u128 = a.into();
let b_u128: u128 = b.into();
let expected_res =
((a_u128 + custom_modulus_u128 - b_u128) % custom_modulus_u128) as u64;
let res = a.wrapping_sub_custom_mod(b, custom_modulus);
assert_eq!(expected_res, res, "a: {a}, b: {b}");
}
}
}

View File

@@ -248,7 +248,7 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweKeyswitchKey<C>
/// Return a view of the [`LweKeyswitchKey`]. This is useful if an algorithm takes a view by
/// value.
pub fn as_view(&self) -> LweKeyswitchKey<&'_ [Scalar]> {
pub fn as_view(&self) -> LweKeyswitchKeyView<'_, Scalar> {
LweKeyswitchKey::from_container(
self.as_ref(),
self.decomp_base_log,
@@ -280,7 +280,7 @@ impl<Scalar: UnsignedInteger, C: Container<Element = Scalar>> LweKeyswitchKey<C>
impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> LweKeyswitchKey<C> {
/// Mutable variant of [`LweKeyswitchKey::as_view`].
pub fn as_mut_view(&mut self) -> LweKeyswitchKey<&'_ mut [Scalar]> {
pub fn as_mut_view(&mut self) -> LweKeyswitchKeyMutView<'_, Scalar> {
let decomp_base_log = self.decomp_base_log;
let decomp_level_count = self.decomp_level_count;
let output_lwe_size = self.output_lwe_size;
@@ -303,6 +303,8 @@ impl<Scalar: UnsignedInteger, C: ContainerMut<Element = Scalar>> LweKeyswitchKey
/// An [`LweKeyswitchKey`] owning the memory for its own storage.
pub type LweKeyswitchKeyOwned<Scalar> = LweKeyswitchKey<Vec<Scalar>>;
pub type LweKeyswitchKeyView<'a, Scalar> = LweKeyswitchKey<&'a [Scalar]>;
pub type LweKeyswitchKeyMutView<'a, Scalar> = LweKeyswitchKey<&'a mut [Scalar]>;
impl<Scalar: UnsignedInteger> LweKeyswitchKeyOwned<Scalar> {
/// Allocate memory and create a new owned [`LweKeyswitchKey`].

View File

@@ -88,10 +88,26 @@ impl<Scalar, C: Container<Element = Scalar>> LweSecretKey<C> {
pub fn into_container(self) -> C {
self.data
}
pub fn as_view(&self) -> LweSecretKeyView<'_, Scalar> {
LweSecretKey {
data: self.as_ref(),
}
}
}
impl<Scalar, C: ContainerMut<Element = Scalar>> LweSecretKey<C> {
pub fn as_mut_view(&mut self) -> LweSecretKeyMutView<'_, Scalar> {
LweSecretKey {
data: self.as_mut(),
}
}
}
/// An [`LweSecretKey`] owning the memory for its own storage.
pub type LweSecretKeyOwned<Scalar> = LweSecretKey<Vec<Scalar>>;
pub type LweSecretKeyView<'a, Scalar> = LweSecretKey<&'a [Scalar]>;
pub type LweSecretKeyMutView<'a, Scalar> = LweSecretKey<&'a mut [Scalar]>;
impl<Scalar> LweSecretKeyOwned<Scalar>
where

View File

@@ -3,9 +3,7 @@
//! The TFHE-rs preludes include convenient imports.
//! Having `tfhe::core_crypto::prelude::*;` should be enough to start using the lib.
pub use super::algorithms::{
add_external_product_assign, polynomial_algorithms, slice_algorithms, *,
};
pub use super::algorithms::{polynomial_algorithms, slice_algorithms, *};
pub use super::commons::computation_buffers::ComputationBuffers;
pub use super::commons::dispersion::*;
pub use super::commons::generators::{EncryptionRandomGenerator, SecretRandomGenerator};