diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs index 5e1296196..6b0a0e717 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::dp_ks_any_ms; use super::utils::noise_simulation::{ - NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -49,8 +48,7 @@ pub fn br_dp_ks_any_ms< bsk: &PBSKey, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -96,7 +94,6 @@ where scalar, ksk, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -126,12 +123,7 @@ where let id_lut = sks.generate_lookup_table(|x| x); let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); for _ in 0..10 { let input_zero_as_lwe = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -142,8 +134,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -210,12 +201,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -226,8 +212,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( sks, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -418,30 +403,19 @@ where let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_bsk = NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_bsk.matches_actual_shortint_server_key(&sks)); let max_scalar_mul = sks.max_noise_level.get(); @@ -464,8 +438,7 @@ where &noise_simulation_bsk, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, br_input_modulus_log, &mut (), @@ -483,8 +456,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs index 302f8e9c7..d51c7c2ec 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::any_ms; use super::utils::noise_simulation::{ - DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -73,8 +72,7 @@ pub fn br_rerand_dp_ks_any_ms< ksk_rerand: &KsKeyRerand, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, decomp_accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -139,7 +137,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -234,12 +231,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = comp_private_key.encrypt_noiseless_decompression_input_dyn_lwe(cks, 0, &mut engine); @@ -280,8 +272,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( ksk_rerand, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, br_input_modulus_log, &mut (), @@ -631,31 +622,20 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); let noise_simulation_ksk_rerand = NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, rerand_ksk_params, params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_decomp_bsk = NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); assert!(noise_simulation_ksk_rerand.matches_actual_shortint_keyswitching_key(&ksk_rerand)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_decomp_bsk.matches_actual_shortint_decomp_key(&decomp_key)); let max_scalar_mul = sks.max_noise_level.get(); @@ -691,8 +671,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( &noise_simulation_ksk_rerand, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, compute_br_input_modulus_log, &mut (), @@ -738,8 +717,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise

( &ksk_rerand, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, compute_br_input_modulus_log, &mut (), @@ -980,15 +958,9 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs

( KeySwitchingKeyBuildHelper::new((&cpk_private_key, None), (&cks, &sks), rerand_ksk_params); let ksk_rerand: KeySwitchingKeyView<'_> = ksk_rerand_builder.as_key_switching_key_view(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; - let max_scalar_mul = sks.max_noise_level.get(); let decomp_rescale_lut = decomp_key.rescaling_lut( @@ -1095,8 +1067,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs

( &ksk_rerand, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, compute_br_input_modulus_log, &mut (), diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs index cbaf312c6..e8f865902 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs @@ -1,7 +1,6 @@ use super::dp_ks_ms::any_ms; use super::utils::noise_simulation::{ - DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + DynLwe, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -44,8 +43,7 @@ pub fn cpk_ks_any_ms< >( input: InputCt, ksk_ds: &KsKeyDs, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> (InputCt, KsResult, Option, MsResult) @@ -76,7 +74,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -141,12 +138,7 @@ fn cpk_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = { let compact_list = cpk.encrypt_iter_with_modulus_with_engine( @@ -165,8 +157,7 @@ fn cpk_ks_any_ms_inner_helper( let (input, after_ks_ds, after_drift, after_ms) = cpk_ks_any_ms( ct, ksk_ds, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -365,37 +356,25 @@ fn noise_check_encrypt_cpk_ks_ms_noise

( NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); let noise_simulation_ksk_ds = NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, ksk_ds_params, params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); assert!(noise_simulation_ksk_ds.matches_actual_shortint_keyswitching_key(&ksk_ds)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); let (_input_sim, _after_ks_ds_sim, _after_drift_sim, after_ms_sim) = { let noise_simulation_input = NoiseSimulationLwe::encrypt_with_cpk(&cpk); cpk_ks_any_ms( noise_simulation_input, &noise_simulation_ksk_ds, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), compute_br_input_modulus_log, &mut (), ) @@ -417,8 +396,7 @@ fn noise_check_encrypt_cpk_ks_ms_noise

( let (_input, _after_ks_ds, _before_ms, after_ms) = cpk_ks_any_ms( sample_input, &ksk_ds, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, compute_br_input_modulus_log, &mut (), ); diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs index 329fcacd4..bb8c0c65f 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs @@ -22,8 +22,7 @@ use rayon::prelude::*; pub fn any_ms( input: &InputCt, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> (Option, MsResult) @@ -46,10 +45,9 @@ where SideResources = Resources, >, { - match (modulus_switch_configuration, mod_switch_noise_reduction_key) { - ( - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction, - Some(mod_switch_noise_reduction_key), + match modulus_switch_configuration { + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction( + mod_switch_noise_reduction_key, ) => { let (mut drift_technique_result, mut ms_result) = mod_switch_noise_reduction_key .allocate_drift_technique_standard_mod_switch_result(side_resources); @@ -63,13 +61,13 @@ where (Some(drift_technique_result), ms_result) } - (NoiseSimulationModulusSwitchConfig::Standard, None) => { + NoiseSimulationModulusSwitchConfig::Standard => { let mut ms_result = input.allocate_standard_mod_switch_result(side_resources); input.standard_mod_switch(br_input_modulus_log, &mut ms_result, side_resources); (None, ms_result) } - (NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, None) => { + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => { let mut ms_result = input.allocate_centered_binary_shifted_standard_mod_switch_result(side_resources); input.centered_binary_shifted_and_standard_mod_switch( @@ -80,7 +78,6 @@ where (None, ms_result) } - _ => panic!("Inconsistent modulus switch and drift key configuration"), } } @@ -98,8 +95,7 @@ pub fn dp_ks_any_ms< input: InputCt, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, ) -> ( @@ -140,7 +136,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -169,12 +164,7 @@ where let id_lut = sks.generate_lookup_table(|x| x); let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); for _ in 0..10 { let input_zero = cks.encrypt(0); @@ -184,8 +174,7 @@ where input_zero_as_lwe, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -234,12 +223,7 @@ fn encrypt_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = DynLwe::U64(cks.unchecked_encrypt(msg).ct); @@ -247,8 +231,7 @@ fn encrypt_dp_ks_any_ms_inner_helper( ct, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); @@ -419,28 +402,17 @@ where let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); let max_scalar_mul = sks.max_noise_level.get(); @@ -450,8 +422,7 @@ where noise_simulation, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), br_input_modulus_log, &mut (), ) @@ -466,8 +437,7 @@ where sample_input, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, br_input_modulus_log, &mut (), ); diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs index 4a0cdd6d9..c1c703470 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_pbs128_packingks.rs @@ -43,8 +43,7 @@ fn dp_ks_any_ms_standard_pbs128< input: InputCt, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key_128: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, bsk_128: &Bsk, br_input_modulus_log: CiphertextModulusLog, accumulator: &Accumulator, @@ -93,7 +92,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key_128, br_input_modulus_log, side_resources, ); @@ -131,8 +129,7 @@ fn dp_ks_any_ms_standard_pbs128_packing_ks< input: Vec, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key_128: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, bsk_128: &Bsk, br_input_modulus_log: CiphertextModulusLog, accumulator: &Accumulator, @@ -199,7 +196,6 @@ where scalar, ksk, modulus_switch_configuration, - mod_switch_noise_reduction_key_128, bsk_128, br_input_modulus_log, accumulator, @@ -260,15 +256,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks

( let lwe_per_glwe = noise_squashing_compression_key.lwe_per_glwe(); - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(&noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); let br_input_modulus_log = noise_squashing_key.br_input_modulus_log(); @@ -304,8 +292,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks

( input_zero_as_lwe, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &noise_squashing_key, br_input_modulus_log, &id_lut, @@ -427,15 +414,8 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper( ) }; - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); + let bsk_polynomial_size = noise_squashing_key.polynomial_size(); let bsk_glwe_size = noise_squashing_key.glwe_size(); let br_input_modulus_log = noise_squashing_key.br_input_modulus_log(); @@ -470,8 +450,7 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper( inputs, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, noise_squashing_key, br_input_modulus_log, &id_lut, @@ -712,8 +691,8 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_bsk128 = NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params); let noise_simulation_packing_key = NoiseSimulationLwePackingKeyswitchKey::new_from_params( @@ -721,26 +700,11 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( noise_squashing_compression_params, ); - let noise_simulation_modulus_switch_config = - noise_squashing_key.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => { - Some(&noise_squashing_key) - } - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config(); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!( - noise_simulation_drift_key.matches_actual_shortint_noise_squashing_key(drift_key) - ) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_noise_squashing_modulus_switch_config(modulus_switch_config)); assert!( noise_simulation_bsk128.matches_actual_shortint_noise_squashing_key(&noise_squashing_key) ); @@ -766,8 +730,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise

( vec![noise_simulation; noise_squashing_compression_key.lwe_per_glwe().0], max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_bsk128, br_input_modulus_log, &noise_simulation_accumulator, diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs index 592e73b10..354f66360 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs @@ -462,36 +462,129 @@ impl AllocateLweKeyswitchResult for ServerKey { } #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum NoiseSimulationModulusSwitchConfig { +pub enum NoiseSimulationModulusSwitchConfig { Standard, - DriftTechniqueNoiseReduction, + DriftTechniqueNoiseReduction(DriftKey), CenteredMeanNoiseReduction, } -impl NoiseSimulationModulusSwitchConfig { - pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 { - match self { - Self::Standard => 0.0f64, - Self::DriftTechniqueNoiseReduction => 0.0f64, - Self::CenteredMeanNoiseReduction => { - // Half case subtracted before entering the blind rotate - -1.0f64 / (4.0 * polynomial_size.0 as f64) +impl NoiseSimulationModulusSwitchConfig { + pub fn new_from_atomic_pattern_parameters(params: AtomicPatternParameters) -> Self { + let drift_key = + NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + + match params { + AtomicPatternParameters::Standard(pbsparameters) => match pbsparameters { + PBSParameters::PBS(classic_pbsparameters) => { + match classic_pbsparameters.modulus_switch_noise_reduction_params { + ModulusSwitchType::Standard => Self::Standard, + ModulusSwitchType::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction( + drift_key.expect("Invalid drift key configuration"), + ) + } + ModulusSwitchType::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction + } + } + } + PBSParameters::MultiBitPBS(_) => { + panic!( + "Unsupported ShortintBootstrappingKey::MultiBit \ + for NoiseSimulationModulusSwitchConfig" + ) + } + }, + AtomicPatternParameters::KeySwitch32(key_switch32_pbsparameters) => { + match &key_switch32_pbsparameters.modulus_switch_noise_reduction_params { + ModulusSwitchType::Standard => Self::Standard, + ModulusSwitchType::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction( + drift_key.expect("Invalid drift key configuration"), + ) + } + ModulusSwitchType::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction + } + } + } + } + } + + pub fn matches_shortint_server_key_modulus_switch_config( + &self, + shortint_config: NoiseSimulationModulusSwitchConfig<&ServerKey>, + ) -> bool { + match (self, shortint_config) { + (Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true, + ( + Self::DriftTechniqueNoiseReduction(noise_sim), + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sks), + ) => noise_sim.matches_actual_shortint_server_key(sks), + ( + Self::CenteredMeanNoiseReduction, + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, + ) => true, + _ => false, + } + } + + pub fn matches_shortint_noise_squashing_modulus_switch_config( + &self, + shortint_config: NoiseSimulationModulusSwitchConfig<&NoiseSquashingKey>, + ) -> bool { + match (self, shortint_config) { + (Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true, + ( + Self::DriftTechniqueNoiseReduction(noise_sim), + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sns), + ) => noise_sim.matches_actual_shortint_noise_squashing_key(sns), + ( + Self::CenteredMeanNoiseReduction, + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, + ) => true, + _ => false, + } + } +} + +impl NoiseSimulationModulusSwitchConfig { + fn new_from_config_and_key( + config: &ModulusSwitchConfiguration, + key: DriftKey, + ) -> Self { + match config { + ModulusSwitchConfiguration::Standard => Self::Standard, + ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => { + Self::DriftTechniqueNoiseReduction(key) + } + ModulusSwitchConfiguration::CenteredMeanNoiseReduction => { + Self::CenteredMeanNoiseReduction } } } } -impl From<&ModulusSwitchConfiguration> - for NoiseSimulationModulusSwitchConfig -{ - fn from(value: &ModulusSwitchConfiguration) -> Self { - match value { - ModulusSwitchConfiguration::Standard => Self::Standard, - ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => { - Self::DriftTechniqueNoiseReduction +impl NoiseSimulationModulusSwitchConfig { + pub fn as_ref(&self) -> NoiseSimulationModulusSwitchConfig<&DriftKey> { + match self { + Self::Standard => NoiseSimulationModulusSwitchConfig::Standard, + Self::DriftTechniqueNoiseReduction(key) => { + NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(key) } - ModulusSwitchConfiguration::CenteredMeanNoiseReduction => { - Self::CenteredMeanNoiseReduction + Self::CenteredMeanNoiseReduction => { + NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction + } + } + } + + pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 { + match self { + Self::Standard => 0.0f64, + Self::DriftTechniqueNoiseReduction(_) => 0.0f64, + Self::CenteredMeanNoiseReduction => { + // Half case subtracted before entering the blind rotate + -1.0f64 / (4.0 * polynomial_size.0 as f64) } } } @@ -518,16 +611,21 @@ impl ServerKey { } } - pub fn noise_simulation_modulus_switch_config(&self) -> NoiseSimulationModulusSwitchConfig { + pub fn noise_simulation_modulus_switch_config( + &self, + ) -> NoiseSimulationModulusSwitchConfig<&Self> { match &self.atomic_pattern { AtomicPatternServerKey::Standard(standard_atomic_pattern_server_key) => { match &standard_atomic_pattern_server_key.bootstrapping_key { ShortintBootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), ShortintBootstrappingKey::MultiBit { .. } => { - todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation") + panic!("MultiBit ServerKey does not support the drift technique") } } } @@ -536,9 +634,12 @@ impl ServerKey { ShortintBootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), ShortintBootstrappingKey::MultiBit { .. } => { - todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation") + panic!("MultiBit ServerKey does not support the drift technique") } } } @@ -839,7 +940,9 @@ impl> LweClassicFftBootstrap NoiseSimulationModulusSwitchConfig { + pub fn noise_simulation_modulus_switch_config( + &self, + ) -> NoiseSimulationModulusSwitchConfig<&Self> { match self.atomic_pattern() { AtomicPatternNoiseSquashingKey::Standard( standard_atomic_pattern_noise_squashing_key, @@ -847,7 +950,10 @@ impl NoiseSquashingKey { Shortint128BootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), Shortint128BootstrappingKey::MultiBit { .. } => { panic!("MultiBit ServerKey does not support the drift technique") } @@ -858,7 +964,10 @@ impl NoiseSquashingKey { Shortint128BootstrappingKey::Classic { bsk: _, modulus_switch_noise_reduction_key, - } => modulus_switch_noise_reduction_key.into(), + } => NoiseSimulationModulusSwitchConfig::new_from_config_and_key( + modulus_switch_noise_reduction_key, + self, + ), Shortint128BootstrappingKey::MultiBit { .. } => { panic!("MultiBit ServerKey does not support the drift technique") }