From f2148d50bee85d906fe6648ffc3b1c4e9ef8a441 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 13 Dec 2024 16:16:03 +0000 Subject: [PATCH] test(shortint): add normality check after KS in classic AP --- .../noise_distribution/atomic_pattern.rs | 137 +++++++++++------- 1 file changed, 87 insertions(+), 50 deletions(-) diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/atomic_pattern.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/atomic_pattern.rs index bc28e7635..21732133e 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/atomic_pattern.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/atomic_pattern.rs @@ -30,7 +30,8 @@ use crate::core_crypto::commons::noise_formulas::lwe_programmable_bootstrap_128: use crate::core_crypto::commons::noise_formulas::modulus_switch::modulus_switch_additive_variance; use crate::core_crypto::commons::noise_formulas::secure_noise::{ minimal_lwe_variance_for_132_bits_security_gaussian, - minimal_lwe_variance_for_132_bits_security_tuniform, variance_to_tuniform_bound_log2, + minimal_lwe_variance_for_132_bits_security_tuniform, + //variance_to_tuniform_bound_log2, }; use crate::core_crypto::commons::parameters::{ CiphertextModulus as CoreCiphertextModulus, DecompositionBaseLog, DecompositionLevelCount, @@ -38,7 +39,8 @@ use crate::core_crypto::commons::parameters::{ }; use crate::core_crypto::commons::test_tools::{ arithmetic_mean, clopper_pearson_exact_confidence_interval, equivalent_pfail_gaussian_noise, - mean_confidence_interval, torus_modular_diff, variance, variance_confidence_interval, + mean_confidence_interval, normality_test_f64, torus_modular_diff, variance, + variance_confidence_interval, }; use crate::core_crypto::commons::traits::{Container, UnsignedInteger}; use crate::core_crypto::entities::{GlweSecretKey, LweCiphertext, LweSecretKey, Plaintext}; @@ -343,7 +345,7 @@ fn classic_pbs_atomic_pattern_inner_helper( single_sks: &ServerKey, msg: u64, scalar_for_multiplication: u8, -) -> DecryptionAndNoiseResult { +) -> (DecryptionAndNoiseResult, DecryptionAndNoiseResult) { assert!(params.pbs_only()); assert!( matches!(params.encryption_key_choice(), EncryptionKeyChoice::Big), @@ -460,12 +462,21 @@ fn classic_pbs_atomic_pattern_inner_helper( *dst = modulus_switch(*src, br_input_modulus_log) << shift_to_map_to_native; } - DecryptionAndNoiseResult::new( - &after_ms, - &cks.small_lwe_secret_key(), - msg, - delta, - cleartext_modulus, + ( + DecryptionAndNoiseResult::new( + &after_ks_lwe, + &cks.small_lwe_secret_key(), + msg, + delta, + cleartext_modulus, + ), + DecryptionAndNoiseResult::new( + &after_ms, + &cks.small_lwe_secret_key(), + msg, + delta, + cleartext_modulus, + ), ) } @@ -475,21 +486,30 @@ fn classic_pbs_atomic_pattern_noise_helper( single_sks: &ServerKey, msg: u64, scalar_for_multiplication: u8, -) -> NoiseSample { - let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper( - params, - single_cks, - single_sks, - msg, - scalar_for_multiplication, - ); +) -> (NoiseSample, NoiseSample) { + let (decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) = + classic_pbs_atomic_pattern_inner_helper( + params, + single_cks, + single_sks, + msg, + scalar_for_multiplication, + ); - match decryption_and_noise_result { - DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise, - DecryptionAndNoiseResult::DecryptionFailed => { - panic!("Failed decryption, noise measurement will be wrong.") - } - } + ( + match decryption_and_noise_result_after_ks { + DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise, + DecryptionAndNoiseResult::DecryptionFailed => { + panic!("Failed decryption, noise measurement will be wrong.") + } + }, + match decryption_and_noise_result_after_ms { + DecryptionAndNoiseResult::DecryptionSucceeded { noise } => noise, + DecryptionAndNoiseResult::DecryptionFailed => { + panic!("Failed decryption, noise measurement will be wrong.") + } + }, + ) } /// Return 1 if the decryption failed, otherwise 0, allowing to sum the results of threads to get @@ -501,15 +521,16 @@ fn classic_pbs_atomic_pattern_pfail_helper( msg: u64, scalar_for_multiplication: u8, ) -> f64 { - let decryption_and_noise_result = classic_pbs_atomic_pattern_inner_helper( - params, - single_cks, - single_sks, - msg, - scalar_for_multiplication, - ); + let (_decryption_and_noise_result_after_ks, decryption_and_noise_result_after_ms) = + classic_pbs_atomic_pattern_inner_helper( + params, + single_cks, + single_sks, + msg, + scalar_for_multiplication, + ); - match decryption_and_noise_result { + match decryption_and_noise_result_after_ms { DecryptionAndNoiseResult::DecryptionSucceeded { .. } => 0.0, DecryptionAndNoiseResult::DecryptionFailed => 1.0, } @@ -614,9 +635,11 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam let expected_variance_after_ms = Variance(expected_variance_after_ks.0 + ms_additive_var.0); let cleartext_modulus = params.message_modulus().0 * params.carry_modulus().0; - let mut noise_samples = vec![]; + let mut noise_samples_after_ks = vec![]; + let mut noise_samples_after_ms = vec![]; for msg in 0..cleartext_modulus { - let current_noise_samples: Vec<_> = (0..1000) + let (current_noise_samples_after_ks, current_noise_samples_after_ms): (Vec<_>, Vec<_>) = (0 + ..1000) .into_par_iter() .map(|_| { classic_pbs_atomic_pattern_noise_helper( @@ -626,41 +649,44 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam msg, scalar_for_multiplication.try_into().unwrap(), ) - .value }) - .collect(); + .unzip(); - noise_samples.extend(current_noise_samples); + noise_samples_after_ks.extend(current_noise_samples_after_ks.into_iter().map(|x| x.value)); + noise_samples_after_ms.extend(current_noise_samples_after_ms.into_iter().map(|x| x.value)); } - let measured_mean = arithmetic_mean(&noise_samples); - let measured_variance = variance(&noise_samples); + let measured_mean_after_ms = arithmetic_mean(&noise_samples_after_ms); + let measured_variance_after_ms = variance(&noise_samples_after_ms); let mean_ci = mean_confidence_interval( - noise_samples.len() as f64, - measured_mean, - measured_variance.get_standard_dev(), + noise_samples_after_ms.len() as f64, + measured_mean_after_ms, + measured_variance_after_ms.get_standard_dev(), 0.99, ); - let variance_ci = - variance_confidence_interval(noise_samples.len() as f64, measured_variance, 0.99); + let variance_ci = variance_confidence_interval( + noise_samples_after_ms.len() as f64, + measured_variance_after_ms, + 0.99, + ); - let expected_mean = 0.0; + let expected_mean_after_ms = 0.0; - println!("measured_variance={measured_variance:?}"); + println!("measured_variance_after_ms={measured_variance_after_ms:?}"); println!("expected_variance_after_ms={expected_variance_after_ms:?}"); println!("variance_lower_bound={:?}", variance_ci.lower_bound()); println!("variance_upper_bound={:?}", variance_ci.upper_bound()); - println!("measured_mean={measured_mean:?}"); - println!("expected_mean={expected_mean:?}"); + println!("measured_mean_after_ms={measured_mean_after_ms:?}"); + println!("expected_mean_after_ms={expected_mean_after_ms:?}"); println!("mean_lower_bound={:?}", mean_ci.lower_bound()); println!("mean_upper_bound={:?}", mean_ci.upper_bound()); // Expected mean is 0 - assert!(mean_ci.mean_is_in_interval(expected_mean)); + assert!(mean_ci.mean_is_in_interval(expected_mean_after_ms)); // We want to be smaller but secure or in the interval - if measured_variance <= expected_variance_after_ms { + if measured_variance_after_ms <= expected_variance_after_ms { let noise_for_security = match params.lwe_noise_distribution() { DynamicDistribution::Gaussian(_) => { minimal_lwe_variance_for_132_bits_security_gaussian( @@ -684,11 +710,22 @@ fn noise_check_shortint_classic_pbs_atomic_pattern_noise(params: ClassicPBSParam ); } - assert!(measured_variance >= noise_for_security); + assert!(measured_variance_after_ms >= noise_for_security); } else { assert!(variance_ci.variance_is_in_interval(expected_variance_after_ms)); } + let normality_check = normality_test_f64( + &noise_samples_after_ks[..5000.min(noise_samples_after_ks.len())], + 0.01, + ); + + if normality_check.null_hypothesis_is_valid { + println!("Normality check after KS is OK\n"); + } else { + panic!("Normality check after KS failed"); + } + // Normality check of heavily discretized gaussian does not seem to work // let normality_check = normality_test_f64(&noise_samples[..5000.min(noise_samples.len())], // 0.05); assert!(normality_check.null_hypothesis_is_valid);