test(shortint): add normality check after KS in classic AP

This commit is contained in:
Arthur Meyre
2024-12-13 16:16:03 +00:00
committed by Nicolas Sarlin
parent 474e17f6ad
commit f2148d50be

View File

@@ -30,7 +30,8 @@ use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap_128:
use crate::core_crypto::commons::noise_formulas::modulus_switch::modulus_switch_additive_variance;
use crate::core_crypto::commons::noise_formulas::secure_noise::{
minimal_lwe_variance_for_132_bits_security_gaussian,
minimal_lwe_variance_for_132_bits_security_tuniform, variance_to_tuniform_bound_log2,
minimal_lwe_variance_for_132_bits_security_tuniform,
//variance_to_tuniform_bound_log2,
};
use crate::core_crypto::commons::parameters::{
CiphertextModulus as CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
@@ -38,7 +39,8 @@ use crate::core_crypto::commons::parameters::{
};
use crate::core_crypto::commons::test_tools::{
arithmetic_mean, clopper_pearson_exact_confidence_interval, equivalent_pfail_gaussian_noise,
mean_confidence_interval, torus_modular_diff, variance, variance_confidence_interval,
mean_confidence_interval, normality_test_f64, torus_modular_diff, variance,
variance_confidence_interval,
};
use crate::core_crypto::commons::traits::{Container, UnsignedInteger};
use crate::core_crypto::entities::{GlweSecretKey, LweCiphertext, LweSecretKey, Plaintext};
@@ -343,7 +345,7 @@ fn classic_pbs_atomic_pattern_inner_helper(
single_sks: &ServerKey,
msg: u64,
scalar_for_multiplication: u8,
) -> DecryptionAndNoiseResult {
) -> (DecryptionAndNoiseResult, DecryptionAndNoiseResult) {
assert!(params.pbs_only());
assert!(
matches!(params.encryption_key_choice(), EncryptionKeyChoice::Big),
@@ -460,12 +462,21 @@ fn classic_pbs_atomic_pattern_inner_helper(
*dst = modulus_switch(*src, br_input_modulus_log) << shift_to_map_to_native;
}
DecryptionAndNoiseResult::new(
&after_ms,
&cks.small_lwe_secret_key(),
msg,
delta,
cleartext_modulus,
(
DecryptionAndNoiseResult::new(
&after_ks_lwe,
&cks.small_lwe_secret_key(),
msg,
delta,
cleartext_modulus,
),
DecryptionAndNoiseResult::new(
&after_ms,
&cks.small_lwe_secret_key(),
msg,
delta,
cleartext_modulus,
),
)
}
@@ -475,21 +486,30 @@ fn classic_pbs_atomic_pattern_noise_helper(
single_sks: &ServerKey,
msg: u64,
scalar_for_multiplication: u8,
) -> NoiseSample {
let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper(
params,
single_cks,
single_sks,
msg,
scalar_for_multiplication,
);
) -> (NoiseSample, NoiseSample) {
let (decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) =
classic_pbs_atomic_pattern_inner_helper(
params,
single_cks,
single_sks,
msg,
scalar_for_multiplication,
);
match decryption_and_noise_result {
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
DecryptionAndNoiseResult::DecryptionFailed => {
panic!("Failed decryption, noise measurement will be wrong.")
}
}
(
match decryption_and_noise_result_after_ks {
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
DecryptionAndNoiseResult::DecryptionFailed => {
panic!("Failed decryption, noise measurement will be wrong.")
}
},
match decryption_and_noise_result_after_ms {
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
DecryptionAndNoiseResult::DecryptionFailed => {
panic!("Failed decryption, noise measurement will be wrong.")
}
},
)
}
/// Return 1 if the decryption failed, otherwise 0, allowing to sum the results of threads to get
@@ -501,15 +521,16 @@ fn classic_pbs_atomic_pattern_pfail_helper(
msg: u64,
scalar_for_multiplication: u8,
) -> f64 {
let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper(
params,
single_cks,
single_sks,
msg,
scalar_for_multiplication,
);
let (_decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) =
classic_pbs_atomic_pattern_inner_helper(
params,
single_cks,
single_sks,
msg,
scalar_for_multiplication,
);
match decryption_and_noise_result {
match decryption_and_noise_result_after_ms {
DecryptionAndNoiseResult::DecryptionSucceeded { .. } => 0.0,
DecryptionAndNoiseResult::DecryptionFailed => 1.0,
}
@@ -614,9 +635,11 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
let expected_variance_after_ms = Variance(expected_variance_after_ks.0 + ms_additive_var.0);
let cleartext_modulus = params.message_modulus().0 * params.carry_modulus().0;
let mut noise_samples = vec![];
let mut noise_samples_after_ks = vec![];
let mut noise_samples_after_ms = vec![];
for msg in 0..cleartext_modulus {
let current_noise_samples: Vec<_> = (0..1000)
let (current_noise_samples_after_ks, current_noise_samples_after_ms): (Vec<_>, Vec<_>) = (0
..1000)
.into_par_iter()
.map(|_| {
classic_pbs_atomic_pattern_noise_helper(
@@ -626,41 +649,44 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
msg,
scalar_for_multiplication.try_into().unwrap(),
)
.value
})
.collect();
.unzip();
noise_samples.extend(current_noise_samples);
noise_samples_after_ks.extend(current_noise_samples_after_ks.into_iter().map(|x| x.value));
noise_samples_after_ms.extend(current_noise_samples_after_ms.into_iter().map(|x| x.value));
}
let measured_mean = arithmetic_mean(&noise_samples);
let measured_variance = variance(&noise_samples);
let measured_mean_after_ms = arithmetic_mean(&noise_samples_after_ms);
let measured_variance_after_ms = variance(&noise_samples_after_ms);
let mean_ci = mean_confidence_interval(
noise_samples.len() as f64,
measured_mean,
measured_variance.get_standard_dev(),
noise_samples_after_ms.len() as f64,
measured_mean_after_ms,
measured_variance_after_ms.get_standard_dev(),
0.99,
);
let variance_ci =
variance_confidence_interval(noise_samples.len() as f64, measured_variance, 0.99);
let variance_ci = variance_confidence_interval(
noise_samples_after_ms.len() as f64,
measured_variance_after_ms,
0.99,
);
let expected_mean = 0.0;
let expected_mean_after_ms = 0.0;
println!("measured_variance={measured_variance:?}");
println!("measured_variance_after_ms={measured_variance_after_ms:?}");
println!("expected_variance_after_ms={expected_variance_after_ms:?}");
println!("variance_lower_bound={:?}", variance_ci.lower_bound());
println!("variance_upper_bound={:?}", variance_ci.upper_bound());
println!("measured_mean={measured_mean:?}");
println!("expected_mean={expected_mean:?}");
println!("measured_mean_after_ms={measured_mean_after_ms:?}");
println!("expected_mean_after_ms={expected_mean_after_ms:?}");
println!("mean_lower_bound={:?}", mean_ci.lower_bound());
println!("mean_upper_bound={:?}", mean_ci.upper_bound());
// Expected mean is 0
assert!(mean_ci.mean_is_in_interval(expected_mean));
assert!(mean_ci.mean_is_in_interval(expected_mean_after_ms));
// We want to be smaller but secure or in the interval
if measured_variance <= expected_variance_after_ms {
if measured_variance_after_ms <= expected_variance_after_ms {
let noise_for_security = match params.lwe_noise_distribution() {
DynamicDistribution::Gaussian(_) => {
minimal_lwe_variance_for_132_bits_security_gaussian(
@@ -684,11 +710,22 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
);
}
assert!(measured_variance >= noise_for_security);
assert!(measured_variance_after_ms >= noise_for_security);
} else {
assert!(variance_ci.variance_is_in_interval(expected_variance_after_ms));
}
let normality_check = normality_test_f64(
&noise_samples_after_ks[..5000.min(noise_samples_after_ks.len())],
0.01,
);
if normality_check.null_hypothesis_is_valid {
println!("Normality check after KS is OK\n");
} else {
panic!("Normality check after KS failed");
}
// Normality check of heavily discretized gaussian does not seem to work
// let normality_check = normality_test_f64(&noise_samples[..5000.min(noise_samples.len())],
// 0.05); assert!(normality_check.null_hypothesis_is_valid);