chore(shortint): add tests for the KS32 AP

This commit is contained in:
Nicolas Sarlin
2025-04-04 17:41:08 +02:00
committed by Nicolas Sarlin
parent 8a26df9177
commit 597c61bbdb
6 changed files with 104 additions and 24 deletions

View File

@@ -192,7 +192,7 @@ def filter_shortint_tests(input_args):
msg_carry_pairs.append((4, 4))
filter_expression = [
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks_pbs.*/)"
f"test(/^shortint::.*_param{multi_bit_filter}{group_filter}_message_{msg}_carry_{carry}(_compact_pk)?_ks(32)?_pbs.*/)"
for msg, carry in msg_carry_pairs
]
filter_expression.append("test(/^shortint::.*_ci_run_filter/)")

View File

@@ -551,3 +551,46 @@ impl From<KS32AtomicPatternServerKey> for AtomicPatternServerKey {
Self::KeySwitch32(value)
}
}
#[cfg(test)]
mod test {
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
use crate::shortint::{gen_keys, ServerKey};
use super::AtomicPatternServerKey;
// Test an implementation of the KS32 AP as a dynamic atomic pattern
#[test]
fn test_ks32_as_dyn_ap_ci_run_filter() {
let (client_key, server_key) =
gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
// Convert the static ks 32 server key into a dynamic one
let AtomicPatternServerKey::KeySwitch32(ks32_key) = server_key.atomic_pattern else {
panic!("We know from parameters that AP is KS32")
};
let ap_key = AtomicPatternServerKey::Dynamic(Box::new(ks32_key));
// Re create the server key with the DAP
let server_key = ServerKey::from_raw_parts(
ap_key,
server_key.message_modulus,
server_key.carry_modulus,
server_key.max_degree,
server_key.max_noise_level,
);
// Do some operation
let msg1 = 1;
let msg2 = 0;
let ct_1 = client_key.encrypt(msg1);
let ct_2 = client_key.encrypt(msg2);
let ct_3 = server_key.add(&ct_1, &ct_2);
let output = client_key.decrypt(&ct_3);
assert_eq!(output, 1);
}
}

View File

@@ -207,9 +207,11 @@ impl<AP: AtomicPattern> GenericServerKey<AP> {
#[cfg(test)]
pub(crate) mod test {
use crate::core_crypto::prelude::decrypt_lwe_ciphertext;
use crate::shortint::oprf::create_random_from_seed_modulus_switched;
use crate::shortint::{ClientKey, ServerKey};
use crate::core_crypto::prelude::{decrypt_lwe_ciphertext, LweSecretKey};
use crate::shortint::{ClientKey, ServerKey, ShortintParameterSet};
use super::*;
use rayon::prelude::*;
use statrs::distribution::ContinuousCDF;
use std::collections::HashMap;
@@ -222,22 +224,34 @@ pub(crate) mod test {
#[test]
fn oprf_compare_plain_ci_run_filter() {
use crate::shortint::gen_keys;
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
for seed in 0..1000 {
oprf_compare_plain_from_seed(Seed(seed), &ck, &sk);
oprf_compare_plain_from_seed::<u64>(Seed(seed), &ck, &sk);
}
let (ck, sk) = gen_keys(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128);
for seed in 0..1000 {
oprf_compare_plain_from_seed::<u32>(Seed(seed), &ck, &sk);
}
}
fn oprf_compare_plain_from_seed(seed: Seed, ck: &ClientKey, sk: &ServerKey) {
fn oprf_compare_plain_from_seed<Scalar: UnsignedInteger + CastFrom<u64> + CastInto<u64>>(
seed: Seed,
ck: &ClientKey,
sk: &ServerKey,
) {
let params = ck.parameters;
let random_bits_count = 2;
let input_p = 2 * params.polynomial_size().0 as u64;
let log_input_p = input_p.ilog2();
let log_input_p = input_p.ilog2() as usize;
let p_prime = 1 << random_bits_count;
@@ -255,15 +269,24 @@ pub(crate) mod test {
params
.polynomial_size()
.to_blind_rotation_input_modulus_log(),
sk.ciphertext_modulus,
CiphertextModulus::new_native(),
);
let sk = ck.small_lwe_secret_key();
let sk = LweSecretKey::from_container(
ck.small_lwe_secret_key()
.as_ref()
.iter()
.copied()
.map(|x| Scalar::cast_from(x))
.collect::<Vec<_>>(),
);
let plain_prf_input = decrypt_lwe_ciphertext(&sk, &ct)
let plain_prf_input = CastInto::<u64>::cast_into(
decrypt_lwe_ciphertext(&sk, &ct)
.0
.wrapping_add(1 << (64 - log_input_p - 1))
>> (64 - log_input_p);
.wrapping_add(Scalar::ONE << (Scalar::BITS - log_input_p - 1))
>> (Scalar::BITS - log_input_p),
);
let half_negacyclic_part = |x| 2 * (x / poly_delta) + 1;
@@ -296,8 +319,14 @@ pub(crate) mod test {
let p_value_limit: f64 = 0.000_01;
use crate::shortint::gen_keys;
use crate::shortint::parameters::test_params::TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;
use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS);
for params in [
ShortintParameterSet::from(PARAM_MESSAGE_2_CARRY_2_KS_PBS),
ShortintParameterSet::from(TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128),
] {
let (ck, sk) = gen_keys(params);
let test_uniformity = |distinct_values: u64, f: &(dyn Fn(usize) -> u64 + Sync)| {
test_uniformity(sample_count, p_value_limit, distinct_values, f)
@@ -306,11 +335,13 @@ pub(crate) mod test {
let random_bits_count = 2;
test_uniformity(1 << random_bits_count, &|seed| {
let img = sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);
let img =
sk.generate_oblivious_pseudo_random(Seed(seed as u128), random_bits_count);
ck.decrypt_message_and_carry(&img)
});
}
}
pub fn test_uniformity<F>(sample_count: usize, p_value_limit: f64, distinct_values: u64, f: F)
where

View File

@@ -44,6 +44,7 @@ use current_params::multi_bit::tuniform::p_fail_2_minus_64::ks_pbs_gpu::{
V1_1_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_3_CARRY_3_KS_PBS_TUNIFORM_2M64,
};
use current_params::noise_squashing::p_fail_2_minus_128::V1_1_NOISE_SQUASHING_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// Aliases
// Compute Gaussian

View File

@@ -1,5 +1,5 @@
use super::current_params::*;
use super::AtomicPatternParameters;
use super::{AtomicPatternParameters, KeySwitch32PBSParameters};
use super::{
ClassicPBSParameters, CompactPublicKeyEncryptionParameters, CompressionParameters,
@@ -209,3 +209,7 @@ pub const TEST_COMP_PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128: CompressionPa
pub const TEST_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128:
CompressionParameters =
V1_1_COMP_PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
// KS32 PBS AP
pub const TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128: KeySwitch32PBSParameters =
V1_1_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128;

View File

@@ -55,7 +55,8 @@ macro_rules! create_parameterized_test{
TEST_PARAM_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64
TEST_PARAM_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M64,
TEST_PARAM_MESSAGE_2_CARRY_2_KS32_PBS_TUNIFORM_2M128
});
};
}