mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-04-28 03:01:21 -04:00
test(shortint): add normality check after KS in classic AP
This commit is contained in:
committed by
Nicolas Sarlin
parent
474e17f6ad
commit
f2148d50be
@@ -30,7 +30,8 @@ use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap_128:
|
||||
use crate::core_crypto::commons::noise_formulas::modulus_switch::modulus_switch_additive_variance;
|
||||
use crate::core_crypto::commons::noise_formulas::secure_noise::{
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian,
|
||||
minimal_lwe_variance_for_132_bits_security_tuniform, variance_to_tuniform_bound_log2,
|
||||
minimal_lwe_variance_for_132_bits_security_tuniform,
|
||||
//variance_to_tuniform_bound_log2,
|
||||
};
|
||||
use crate::core_crypto::commons::parameters::{
|
||||
CiphertextModulus as CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount,
|
||||
@@ -38,7 +39,8 @@ use crate::core_crypto::commons::parameters::{
|
||||
};
|
||||
use crate::core_crypto::commons::test_tools::{
|
||||
arithmetic_mean, clopper_pearson_exact_confidence_interval, equivalent_pfail_gaussian_noise,
|
||||
mean_confidence_interval, torus_modular_diff, variance, variance_confidence_interval,
|
||||
mean_confidence_interval, normality_test_f64, torus_modular_diff, variance,
|
||||
variance_confidence_interval,
|
||||
};
|
||||
use crate::core_crypto::commons::traits::{Container, UnsignedInteger};
|
||||
use crate::core_crypto::entities::{GlweSecretKey, LweCiphertext, LweSecretKey, Plaintext};
|
||||
@@ -343,7 +345,7 @@ fn classic_pbs_atomic_pattern_inner_helper(
|
||||
single_sks: &ServerKey,
|
||||
msg: u64,
|
||||
scalar_for_multiplication: u8,
|
||||
) -> DecryptionAndNoiseResult {
|
||||
) -> (DecryptionAndNoiseResult, DecryptionAndNoiseResult) {
|
||||
assert!(params.pbs_only());
|
||||
assert!(
|
||||
matches!(params.encryption_key_choice(), EncryptionKeyChoice::Big),
|
||||
@@ -460,12 +462,21 @@ fn classic_pbs_atomic_pattern_inner_helper(
|
||||
*dst = modulus_switch(*src, br_input_modulus_log) << shift_to_map_to_native;
|
||||
}
|
||||
|
||||
DecryptionAndNoiseResult::new(
|
||||
&after_ms,
|
||||
&cks.small_lwe_secret_key(),
|
||||
msg,
|
||||
delta,
|
||||
cleartext_modulus,
|
||||
(
|
||||
DecryptionAndNoiseResult::new(
|
||||
&after_ks_lwe,
|
||||
&cks.small_lwe_secret_key(),
|
||||
msg,
|
||||
delta,
|
||||
cleartext_modulus,
|
||||
),
|
||||
DecryptionAndNoiseResult::new(
|
||||
&after_ms,
|
||||
&cks.small_lwe_secret_key(),
|
||||
msg,
|
||||
delta,
|
||||
cleartext_modulus,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -475,21 +486,30 @@ fn classic_pbs_atomic_pattern_noise_helper(
|
||||
single_sks: &ServerKey,
|
||||
msg: u64,
|
||||
scalar_for_multiplication: u8,
|
||||
) -> NoiseSample {
|
||||
let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper(
|
||||
params,
|
||||
single_cks,
|
||||
single_sks,
|
||||
msg,
|
||||
scalar_for_multiplication,
|
||||
);
|
||||
) -> (NoiseSample, NoiseSample) {
|
||||
let (decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) =
|
||||
classic_pbs_atomic_pattern_inner_helper(
|
||||
params,
|
||||
single_cks,
|
||||
single_sks,
|
||||
msg,
|
||||
scalar_for_multiplication,
|
||||
);
|
||||
|
||||
match decryption_and_noise_result {
|
||||
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
|
||||
DecryptionAndNoiseResult::DecryptionFailed => {
|
||||
panic!("Failed decryption, noise measurement will be wrong.")
|
||||
}
|
||||
}
|
||||
(
|
||||
match decryption_and_noise_result_after_ks {
|
||||
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
|
||||
DecryptionAndNoiseResult::DecryptionFailed => {
|
||||
panic!("Failed decryption, noise measurement will be wrong.")
|
||||
}
|
||||
},
|
||||
match decryption_and_noise_result_after_ms {
|
||||
DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise,
|
||||
DecryptionAndNoiseResult::DecryptionFailed => {
|
||||
panic!("Failed decryption, noise measurement will be wrong.")
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Return 1 if the decryption failed, otherwise 0, allowing to sum the results of threads to get
|
||||
@@ -501,15 +521,16 @@ fn classic_pbs_atomic_pattern_pfail_helper(
|
||||
msg: u64,
|
||||
scalar_for_multiplication: u8,
|
||||
) -> f64 {
|
||||
let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper(
|
||||
params,
|
||||
single_cks,
|
||||
single_sks,
|
||||
msg,
|
||||
scalar_for_multiplication,
|
||||
);
|
||||
let (_decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) =
|
||||
classic_pbs_atomic_pattern_inner_helper(
|
||||
params,
|
||||
single_cks,
|
||||
single_sks,
|
||||
msg,
|
||||
scalar_for_multiplication,
|
||||
);
|
||||
|
||||
match decryption_and_noise_result {
|
||||
match decryption_and_noise_result_after_ms {
|
||||
DecryptionAndNoiseResult::DecryptionSucceeded { .. } => 0.0,
|
||||
DecryptionAndNoiseResult::DecryptionFailed => 1.0,
|
||||
}
|
||||
@@ -614,9 +635,11 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
|
||||
let expected_variance_after_ms = Variance(expected_variance_after_ks.0 + ms_additive_var.0);
|
||||
|
||||
let cleartext_modulus = params.message_modulus().0 * params.carry_modulus().0;
|
||||
let mut noise_samples = vec![];
|
||||
let mut noise_samples_after_ks = vec![];
|
||||
let mut noise_samples_after_ms = vec![];
|
||||
for msg in 0..cleartext_modulus {
|
||||
let current_noise_samples: Vec<_> = (0..1000)
|
||||
let (current_noise_samples_after_ks, current_noise_samples_after_ms): (Vec<_>, Vec<_>) = (0
|
||||
..1000)
|
||||
.into_par_iter()
|
||||
.map(|_| {
|
||||
classic_pbs_atomic_pattern_noise_helper(
|
||||
@@ -626,41 +649,44 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
|
||||
msg,
|
||||
scalar_for_multiplication.try_into().unwrap(),
|
||||
)
|
||||
.value
|
||||
})
|
||||
.collect();
|
||||
.unzip();
|
||||
|
||||
noise_samples.extend(current_noise_samples);
|
||||
noise_samples_after_ks.extend(current_noise_samples_after_ks.into_iter().map(|x| x.value));
|
||||
noise_samples_after_ms.extend(current_noise_samples_after_ms.into_iter().map(|x| x.value));
|
||||
}
|
||||
|
||||
let measured_mean = arithmetic_mean(&noise_samples);
|
||||
let measured_variance = variance(&noise_samples);
|
||||
let measured_mean_after_ms = arithmetic_mean(&noise_samples_after_ms);
|
||||
let measured_variance_after_ms = variance(&noise_samples_after_ms);
|
||||
|
||||
let mean_ci = mean_confidence_interval(
|
||||
noise_samples.len() as f64,
|
||||
measured_mean,
|
||||
measured_variance.get_standard_dev(),
|
||||
noise_samples_after_ms.len() as f64,
|
||||
measured_mean_after_ms,
|
||||
measured_variance_after_ms.get_standard_dev(),
|
||||
0.99,
|
||||
);
|
||||
|
||||
let variance_ci =
|
||||
variance_confidence_interval(noise_samples.len() as f64, measured_variance, 0.99);
|
||||
let variance_ci = variance_confidence_interval(
|
||||
noise_samples_after_ms.len() as f64,
|
||||
measured_variance_after_ms,
|
||||
0.99,
|
||||
);
|
||||
|
||||
let expected_mean = 0.0;
|
||||
let expected_mean_after_ms = 0.0;
|
||||
|
||||
println!("measured_variance={measured_variance:?}");
|
||||
println!("measured_variance_after_ms={measured_variance_after_ms:?}");
|
||||
println!("expected_variance_after_ms={expected_variance_after_ms:?}");
|
||||
println!("variance_lower_bound={:?}", variance_ci.lower_bound());
|
||||
println!("variance_upper_bound={:?}", variance_ci.upper_bound());
|
||||
println!("measured_mean={measured_mean:?}");
|
||||
println!("expected_mean={expected_mean:?}");
|
||||
println!("measured_mean_after_ms={measured_mean_after_ms:?}");
|
||||
println!("expected_mean_after_ms={expected_mean_after_ms:?}");
|
||||
println!("mean_lower_bound={:?}", mean_ci.lower_bound());
|
||||
println!("mean_upper_bound={:?}", mean_ci.upper_bound());
|
||||
|
||||
// Expected mean is 0
|
||||
assert!(mean_ci.mean_is_in_interval(expected_mean));
|
||||
assert!(mean_ci.mean_is_in_interval(expected_mean_after_ms));
|
||||
// We want to be smaller but secure or in the interval
|
||||
if measured_variance <= expected_variance_after_ms {
|
||||
if measured_variance_after_ms <= expected_variance_after_ms {
|
||||
let noise_for_security = match params.lwe_noise_distribution() {
|
||||
DynamicDistribution::Gaussian(_) => {
|
||||
minimal_lwe_variance_for_132_bits_security_gaussian(
|
||||
@@ -684,11 +710,22 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam
|
||||
);
|
||||
}
|
||||
|
||||
assert!(measured_variance >= noise_for_security);
|
||||
assert!(measured_variance_after_ms >= noise_for_security);
|
||||
} else {
|
||||
assert!(variance_ci.variance_is_in_interval(expected_variance_after_ms));
|
||||
}
|
||||
|
||||
let normality_check = normality_test_f64(
|
||||
&noise_samples_after_ks[..5000.min(noise_samples_after_ks.len())],
|
||||
0.01,
|
||||
);
|
||||
|
||||
if normality_check.null_hypothesis_is_valid {
|
||||
println!("Normality check after KS is OK\n");
|
||||
} else {
|
||||
panic!("Normality check after KS failed");
|
||||
}
|
||||
|
||||
// Normality check of heavily discretized gaussian does not seem to work
|
||||
// let normality_check = normality_test_f64(&noise_samples[..5000.min(noise_samples.len())],
|
||||
// 0.05); assert!(normality_check.null_hypothesis_is_valid);
|
||||
|
||||
Reference in New Issue
Block a user