From a41cd47b9e353ea4fc14a176044edb84af689e72 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Mon, 27 Oct 2025 16:23:13 +0100 Subject: [PATCH] refactor(test): make modulus switch config system make more sense - The config type can hold any type for the drift technique variant because the bounds are too weird to set on the type, the functions making use of the config type should properly declare the bounds --- .../tests/noise_distribution/br_dp_ks_ms.rs | 58 ++---- .../noise_distribution/br_rerand_dp_ks_ms.rs | 59 ++----- .../tests/noise_distribution/cpk_ks_ms.rs | 46 ++--- .../tests/noise_distribution/dp_ks_ms.rs | 68 ++------ .../dp_ks_pbs128_packingks.rs | 63 ++----- .../utils/noise_simulation.rs | 165 +++++++++++++++--- 6 files changed, 211 insertions(+), 248 deletions(-) diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs index 5e1296196..6b0a0e717 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::dp_ks_any_ms; use super::utils::noise_simulation::{ - NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -49,8 +48,7 @@ pub fn br_dp_ks_any_ms< bsk: &PBSKey, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -96,7 +94,6 @@ where scalar, ksk, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -126,12 +123,7 @@ where let id_lut = sks.generate_lookup_table(|x| x); let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); for _ in 0..10 { let input_zero_as_lwe = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -142,8 +134,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -210,12 +201,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -226,8 +212,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( sks, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -418,30 +403,19 @@ where let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_bsk = NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_bsk.matches_actual_shortint_server_key(&sks)); let max_scalar_mul = sks.max_noise_level.get(); @@ -464,8 +438,7 @@ where &noise_simulation_bsk, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, br_input_modulus_log, &mut (), @@ -483,8 +456,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs index 302f8e9c7..d51c7c2ec 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::any_ms; use super::utils::noise_simulation::{ - DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -73,8 +72,7 @@ pub fn br_rerand_dp_ks_any_ms< ksk_rerand: &KsKeyRerand, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, decomp_accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -139,7 +137,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -234,12 +231,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = comp_private_key.encrypt_noiseless_decompression_input_dyn_lwe(cks, 0, &mut engine); @@ -280,8 +272,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( ksk_rerand, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, br_input_modulus_log, &mut (), @@ -631,31 +622,20 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); let noise_simulation_ksk_rerand = NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, rerand_ksk_params, params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_decomp_bsk = NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); assert!(noise_simulation_ksk_rerand.matches_actual_shortint_keyswitching_key(&ksk_rerand)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_decomp_bsk.matches_actual_shortint_decomp_key(&decomp_key)); let max_scalar_mul = sks.max_noise_level.get(); @@ -691,8 +671,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( &noise_simulation_ksk_rerand, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, compute_br_input_modulus_log, &mut (), @@ -738,8 +717,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( &ksk_rerand, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, compute_br_input_modulus_log, &mut (), @@ -980,15 +958,9 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs

( KeySwitchingKeyBuildHelper::new((&cpk_private_key, None), (&cks, &sks), rerand_ksk_params); let ksk_rerand: KeySwitchingKeyView<'_> = ksk_rerand_builder.as_key_switching_key_view(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; - let max_scalar_mul = sks.max_noise_level.get(); let decomp_rescale_lut = decomp_key.rescaling_lut( @@ -1095,8 +1067,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs

( &ksk_rerand, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, compute_br_input_modulus_log, &mut (), diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs index cbaf312c6..e8f865902 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs @@ -1,7 +1,6 @@ use super::dp_ks_ms::any_ms; use super::utils::noise_simulation::{ - DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + DynLwe, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -44,8 +43,7 @@ pub fn cpk_ks_any_ms< >( input: InputCt, ksk_ds: &KsKeyDs, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> (InputCt, KsResult, Option, MsResult) @@ -76,7 +74,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -141,12 +138,7 @@ fn cpk_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = { let compact_list = cpk.encrypt_iter_with_modulus_with_engine( @@ -165,8 +157,7 @@ fn cpk_ks_any_ms_inner_helper( let (input, after_ks_ds, after_drift, after_ms) = cpk_ks_any_ms( ct, ksk_ds, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -365,37 +356,25 @@ fn noise_check_encrypt_cpk_ks_ms_noise

( NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); let noise_simulation_ksk_ds = NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, ksk_ds_params, params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); assert!(noise_simulation_ksk_ds.matches_actual_shortint_keyswitching_key(&ksk_ds)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); let (_input_sim, _after_ks_ds_sim, _after_drift_sim, after_ms_sim) = { let noise_simulation_input = NoiseSimulationLwe::encrypt_with_cpk(&cpk); cpk_ks_any_ms( noise_simulation_input, &noise_simulation_ksk_ds, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), compute_br_input_modulus_log, &mut (), ) @@ -417,8 +396,7 @@ fn noise_check_encrypt_cpk_ks_ms_noise

( let (_input, _after_ks_ds, _before_ms, after_ms) = cpk_ks_any_ms( sample_input, &ksk_ds, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, compute_br_input_modulus_log, &mut (), ); diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs index 329fcacd4..bb8c0c65f 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs @@ -22,8 +22,7 @@ use rayon::prelude::*; pub fn any_ms( input: &InputCt, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> (Option, MsResult) @@ -46,10 +45,9 @@ where SideResources = Resources, >, { - match (modulus_switch_configuration, mod_switch_noise_reduction_key) { - ( - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction, - Some(mod_switch_noise_reduction_key), + match modulus_switch_configuration { + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction( + mod_switch_noise_reduction_key, ) => { let (mut drift_technique_result, mut ms_result) = mod_switch_noise_reduction_key .allocate_drift_technique_standard_mod_switch_result(side_resources); @@ -63,13 +61,13 @@ where (Some(drift_technique_result), ms_result) } - (NoiseSimulationModulusSwitchConfig::Standard, None) => { + NoiseSimulationModulusSwitchConfig::Standard => { let mut ms_result = input.allocate_standard_mod_switch_result(side_resources); input.standard_mod_switch(br_input_modulus_log, &mut ms_result, side_resources); (None, ms_result) } - (NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, None) => { + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => { let mut ms_result = input.allocate_centered_binary_shifted_standard_mod_switch_result(side_resources); input.centered_binary_shifted_and_standard_mod_switch( @@ -80,7 +78,6 @@ where (None, ms_result) } - _ => panic!("Inconsistent modulus switch and drift key configuration"), } } @@ -98,8 +95,7 @@ pub fn dp_ks_any_ms< input: InputCt, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> ( @@ -140,7 +136,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -169,12 +164,7 @@ where let id_lut = sks.generate_lookup_table(|x| x); let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); for _ in 0..10 { let input_zero = cks.encrypt(0); @@ -184,8 +174,7 @@ where input_zero_as_lwe, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -234,12 +223,7 @@ fn encrypt_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = DynLwe::U64(cks.unchecked_encrypt(msg).ct); @@ -247,8 +231,7 @@ fn encrypt_dp_ks_any_ms_inner_helper( ct, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -419,28 +402,17 @@ where let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); let max_scalar_mul = sks.max_noise_level.get(); @@ -450,8 +422,7 @@ where noise_simulation, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), br_input_modulus_log, &mut (), ) @@ -466,8 +437,7 @@ where sample_input, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs index 4a0cdd6d9..c1c703470 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs @@ -43,8 +43,7 @@ fn dp_ks_any_ms_standard_pbs128< input: InputCt, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key_128: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, bsk_128: &Bsk, br_input_modulus_log: CiphertextModulusLog, accumulator: &Accumulator, @@ -93,7 +92,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key_128, br_input_modulus_log, side_resources, ); @@ -131,8 +129,7 @@ fn dp_ks_any_ms_standard_pbs128_packing_ks< input: Vec, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key_128: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, bsk_128: &Bsk, br_input_modulus_log: CiphertextModulusLog, accumulator: &Accumulator, @@ -199,7 +196,6 @@ where scalar, ksk, modulus_switch_configuration, - mod_switch_noise_reduction_key_128, bsk_128, br_input_modulus_log, accumulator, @@ -260,15 +256,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks

( let lwe_per_glwe = noise_squashing_compression_key.lwe_per_glwe(); - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(&noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); let br_input_modulus_log = noise_squashing_key.br_input_modulus_log(); @@ -304,8 +292,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks

( input_zero_as_lwe, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &noise_squashing_key, br_input_modulus_log, &id_lut, @@ -427,15 +414,8 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper( ) }; - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); + let bsk_polynomial_size = noise_squashing_key.polynomial_size(); let bsk_glwe_size = noise_squashing_key.glwe_size(); let br_input_modulus_log = noise_squashing_key.br_input_modulus_log(); @@ -470,8 +450,7 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper( inputs, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, noise_squashing_key, br_input_modulus_log, &id_lut, @@ -712,8 +691,8 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_bsk128 = NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params); let noise_simulation_packing_key = NoiseSimulationLwePackingKeyswitchKey::new_from_params( @@ -721,26 +700,11 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( noise_squashing_compression_params, ); - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(&noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!( - noise_simulation_drift_key.matches_actual_shortint_noise_squashing_key(drift_key) - ) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_noise_squashing_modulus_switch_config(modulus_switch_config)); assert!( noise_simulation_bsk128.matches_actual_shortint_noise_squashing_key(&noise_squashing_key) ); @@ -766,8 +730,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( vec![noise_simulation; noise_squashing_compression_key.lwe_per_glwe().0], max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_bsk128, br_input_modulus_log, &noise_simulation_accumulator, diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs index 592e73b10..354f66360 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs @@ -462,36 +462,129 @@ impl AllocateLweKeyswitchResult for ServerKey { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum NoiseSimulationModulusSwitchConfig { +pub enum NoiseSimulationModulusSwitchConfig { Standard, - DriftTechniqueNoiseReduction, + DriftTechniqueNoiseReduction(DriftKey), CenteredMeanNoiseReduction, } -impl NoiseSimulationModulusSwitchConfig { - pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 { - match self { - Self::Standard => 0.0f64, - Self::DriftTechniqueNoiseReduction => 0.0f64, - Self::CenteredMeanNoiseReduction => { - // Half case subtracted before entering the blind rotate - -1.0f64 / (4.0 * polynomial_size.0 as f64) +impl NoiseSimulationModulusSwitchConfig { + pub fn new_from_atomic_pattern_parameters(params: AtomicPatternParameters) -> Self { + let drift_key = + NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + + match params { + AtomicPatternParameters::Standard(pbsparameters) => match pbsparameters { + PBSParameters::PBS(classic_pbsparameters) => { + match classic_pbsparameters.modulus_switch_noise_reduction_params { + ModulusSwitchType::Standard => Self::Standard, + ModulusSwitchType::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction( + drift_key.expect("Invalid drift key configuration"), + ) + } + ModulusSwitchType::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction + } + } + } + PBSParameters::MultiBitPBS(_) => { + panic!( + "Unsupported ShortintBootstrappingKey::MultiBit \ + for NoiseSimulationModulusSwitchConfig" + ) + } + }, + AtomicPatternParameters::KeySwitch32(key_switch32_pbsparameters) => { + match &key_switch32_pbsparameters.modulus_switch_noise_reduction_params { + ModulusSwitchType::Standard => Self::Standard, + ModulusSwitchType::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction( + drift_key.expect("Invalid drift key configuration"), + ) + } + ModulusSwitchType::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction + } + } + } + } + } + + pub fn matches_shortint_server_key_modulus_switch_config( + &self, + shortint_config: NoiseSimulationModulusSwitchConfig<&ServerKey>, + ) -> bool { + match (self, shortint_config) { + (Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true, + ( + Self::DriftTechniqueNoiseReduction(noise_sim), + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sks), + ) => noise_sim.matches_actual_shortint_server_key(sks), + ( + Self::CenteredMeanNoiseReduction, + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, + ) => true, + _ => false, + } + } + + pub fn matches_shortint_noise_squashing_modulus_switch_config( + &self, + shortint_config: NoiseSimulationModulusSwitchConfig<&NoiseSquashingKey>, + ) -> bool { + match (self, shortint_config) { + (Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true, + ( + Self::DriftTechniqueNoiseReduction(noise_sim), + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sns), + ) => noise_sim.matches_actual_shortint_noise_squashing_key(sns), + ( + Self::CenteredMeanNoiseReduction, + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, + ) => true, + _ => false, + } + } +} + +impl NoiseSimulationModulusSwitchConfig { + fn new_from_config_and_key( + config: &ModulusSwitchConfiguration, + key: DriftKey, + ) -> Self { + match config { + ModulusSwitchConfiguration::Standard => Self::Standard, + ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction(key) + } + ModulusSwitchConfiguration::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction } } } } -impl From<&ModulusSwitchConfiguration> - for NoiseSimulationModulusSwitchConfig -{ - fn from(value: &ModulusSwitchConfiguration) -> Self { - match value { - ModulusSwitchConfiguration::Standard => Self::Standard, - ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => { - Self::DriftTechniqueNoiseReduction +impl NoiseSimulationModulusSwitchConfig { + pub fn as_ref(&self) -> NoiseSimulationModulusSwitchConfig<&DriftKey> { + match self { + Self::Standard => NoiseSimulationModulusSwitchConfig::Standard, + Self::DriftTechniqueNoiseReduction(key) => { + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(key) } - ModulusSwitchConfiguration::CenteredMeanNoiseReduction => { - Self::CenteredMeanNoiseReduction + Self::CenteredMeanNoiseReduction => { + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction + } + } + } + + pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 { + match self { + Self::Standard => 0.0f64, + Self::DriftTechniqueNoiseReduction(_) => 0.0f64, + Self::CenteredMeanNoiseReduction => { + // Half case subtracted before entering the blind rotate + -1.0f64 / (4.0 * polynomial_size.0 as f64) } } } @@ -518,16 +611,21 @@ impl ServerKey { } } - pub fn noise_simulation_modulus_switch_config(&self) -> NoiseSimulationModulusSwitchConfig { + pub fn noise_simulation_modulus_switch_config( + &self, + ) -> NoiseSimulationModulusSwitchConfig<&Self> { match &self.atomic_pattern { AtomicPatternServerKey::Standard(standard_atomic_pattern_server_key) => { match &standard_atomic_pattern_server_key.bootstrapping_key { ShortintBootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), ShortintBootstrappingKey::MultiBit { .. } => { - todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation") + panic!("MultiBit ServerKey does not support the drift technique") } } } @@ -536,9 +634,12 @@ impl ServerKey { ShortintBootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), ShortintBootstrappingKey::MultiBit { .. } => { - todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation") + panic!("MultiBit ServerKey does not support the drift technique") } } } @@ -839,7 +940,9 @@ impl> LweClassicFftBootstrap NoiseSimulationModulusSwitchConfig { + pub fn noise_simulation_modulus_switch_config( + &self, + ) -> NoiseSimulationModulusSwitchConfig<&Self> { match self.atomic_pattern() { AtomicPatternNoiseSquashingKey::Standard( standard_atomic_pattern_noise_squashing_key, @@ -847,7 +950,10 @@ impl NoiseSquashingKey { Shortint128BootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), Shortint128BootstrappingKey::MultiBit { .. } => { panic!("MultiBit ServerKey does not support the drift technique") } @@ -858,7 +964,10 @@ impl NoiseSquashingKey { Shortint128BootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), Shortint128BootstrappingKey::MultiBit { .. } => { panic!("MultiBit ServerKey does not support the drift technique") }