diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs index 5e1296196..6b0a0e717 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::dp_ks_any_ms; use super::utils::noise_simulation::{ - NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -49,8 +48,7 @@ pub fn br_dp_ks_any_ms< bsk: &PBSKey, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -96,7 +94,6 @@ where scalar, ksk, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -126,12 +123,7 @@ where let id_lut = sks.generate_lookup_table(|x| x); let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); for _ in 0..10 { let input_zero_as_lwe = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -142,8 +134,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -210,12 +201,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0); @@ -226,8 +212,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper( sks, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), @@ -418,30 +403,19 @@ where let noise_simulation_ksk = NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_bsk = NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_bsk.matches_actual_shortint_server_key(&sks)); let max_scalar_mul = sks.max_noise_level.get(); @@ -464,8 +438,7 @@ where &noise_simulation_bsk, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, br_input_modulus_log, &mut (), @@ -483,8 +456,7 @@ where &sks, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &id_lut, br_input_modulus_log, &mut (), diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs index 302f8e9c7..d51c7c2ec 100644 --- a/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs +++ b/tfhe/src/shortint/server_key/tests/noise_distribution/br_rerand_dp_ks_ms.rs @@ -1,8 +1,7 @@ use super::dp_ks_ms::any_ms; use super::utils::noise_simulation::{ - DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe, - NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey, - NoiseSimulationModulusSwitchConfig, + DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk, + NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig, }; use super::utils::traits::*; use super::utils::{ @@ -73,8 +72,7 @@ pub fn br_rerand_dp_ks_any_ms< ksk_rerand: &KsKeyRerand, scalar: DPScalar, ksk: &KsKey, - modulus_switch_configuration: NoiseSimulationModulusSwitchConfig, - mod_switch_noise_reduction_key: Option<&DriftKey>, + modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>, decomp_accumulator: &Accumulator, br_input_modulus_log: CiphertextModulusLog, side_resources: &mut Resources, @@ -139,7 +137,6 @@ where let (drift_technique_result, ms_result) = any_ms( &ks_result, modulus_switch_configuration, - mod_switch_noise_reduction_key, br_input_modulus_log, side_resources, ); @@ -234,12 +231,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( }; let br_input_modulus_log = sks.br_input_modulus_log(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let ct = comp_private_key.encrypt_noiseless_decompression_input_dyn_lwe(cks, 0, &mut engine); @@ -280,8 +272,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper( ksk_rerand, scalar_for_multiplication, sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, br_input_modulus_log, &mut (), @@ -631,31 +622,20 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise
( NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params); let noise_simulation_ksk_rerand = NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, rerand_ksk_params, params); - let noise_simulation_drift_key = - NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params); + let noise_simulation_modulus_switch_config = + NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params); let noise_simulation_decomp_bsk = NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); let expected_average_after_ms = - noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size()); - - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; + modulus_switch_config.expected_average_after_ms(params.polynomial_size()); assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks)); assert!(noise_simulation_ksk_rerand.matches_actual_shortint_keyswitching_key(&ksk_rerand)); - match (noise_simulation_drift_key, drift_key) { - (Some(noise_simulation_drift_key), Some(drift_key)) => { - assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key)) - } - (None, None) => (), - _ => panic!("Inconsistent Drift Key configuration"), - } + assert!(noise_simulation_modulus_switch_config + .matches_shortint_server_key_modulus_switch_config(modulus_switch_config)); assert!(noise_simulation_decomp_bsk.matches_actual_shortint_decomp_key(&decomp_key)); let max_scalar_mul = sks.max_noise_level.get(); @@ -691,8 +671,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise
( &noise_simulation_ksk_rerand, max_scalar_mul, &noise_simulation_ksk, - noise_simulation_modulus_switch_config, - noise_simulation_drift_key.as_ref(), + noise_simulation_modulus_switch_config.as_ref(), &noise_simulation_accumulator, compute_br_input_modulus_log, &mut (), @@ -738,8 +717,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise
( &ksk_rerand, max_scalar_mul, &sks, - noise_simulation_modulus_switch_config, - drift_key, + modulus_switch_config, &decomp_rescale_lut, compute_br_input_modulus_log, &mut (), @@ -980,15 +958,9 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs
( KeySwitchingKeyBuildHelper::new((&cpk_private_key, None), (&cks, &sks), rerand_ksk_params); let ksk_rerand: KeySwitchingKeyView<'_> = ksk_rerand_builder.as_key_switching_key_view(); - let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config(); + let modulus_switch_config = sks.noise_simulation_modulus_switch_config(); let compute_br_input_modulus_log = sks.br_input_modulus_log(); - let drift_key = match noise_simulation_modulus_switch_config { - NoiseSimulationModulusSwitchConfig::Standard => None, - NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks), - NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None, - }; - let max_scalar_mul = sks.max_noise_level.get(); let decomp_rescale_lut = decomp_key.rescaling_lut( @@ -1095,8 +1067,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs
(
&ksk_rerand,
max_scalar_mul,
&sks,
- noise_simulation_modulus_switch_config,
- drift_key,
+ modulus_switch_config,
&decomp_rescale_lut,
compute_br_input_modulus_log,
&mut (),
diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs
index cbaf312c6..e8f865902 100644
--- a/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs
+++ b/tfhe/src/shortint/server_key/tests/noise_distribution/cpk_ks_ms.rs
@@ -1,7 +1,6 @@
use super::dp_ks_ms::any_ms;
use super::utils::noise_simulation::{
- DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey,
- NoiseSimulationModulusSwitchConfig,
+ DynLwe, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
use super::utils::{
@@ -44,8 +43,7 @@ pub fn cpk_ks_any_ms<
>(
input: InputCt,
ksk_ds: &KsKeyDs,
- modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
- mod_switch_noise_reduction_key: Option<&DriftKey>,
+ modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
) -> (InputCt, KsResult, Option (
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_ksk_ds =
NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, ksk_ds_params, params);
- let noise_simulation_drift_key =
- NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
+ let noise_simulation_modulus_switch_config =
+ NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
- let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
+ let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let compute_br_input_modulus_log = sks.br_input_modulus_log();
let expected_average_after_ms =
- noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size());
-
- let drift_key = match noise_simulation_modulus_switch_config {
- NoiseSimulationModulusSwitchConfig::Standard => None,
- NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
- NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
- };
+ modulus_switch_config.expected_average_after_ms(params.polynomial_size());
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
assert!(noise_simulation_ksk_ds.matches_actual_shortint_keyswitching_key(&ksk_ds));
- match (noise_simulation_drift_key, drift_key) {
- (Some(noise_simulation_drift_key), Some(drift_key)) => {
- assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key))
- }
- (None, None) => (),
- _ => panic!("Inconsistent Drift Key configuration"),
- }
+ assert!(noise_simulation_modulus_switch_config
+ .matches_shortint_server_key_modulus_switch_config(modulus_switch_config));
let (_input_sim, _after_ks_ds_sim, _after_drift_sim, after_ms_sim) = {
let noise_simulation_input = NoiseSimulationLwe::encrypt_with_cpk(&cpk);
cpk_ks_any_ms(
noise_simulation_input,
&noise_simulation_ksk_ds,
- noise_simulation_modulus_switch_config,
- noise_simulation_drift_key.as_ref(),
+ noise_simulation_modulus_switch_config.as_ref(),
compute_br_input_modulus_log,
&mut (),
)
@@ -417,8 +396,7 @@ fn noise_check_encrypt_cpk_ks_ms_noise (
let (_input, _after_ks_ds, _before_ms, after_ms) = cpk_ks_any_ms(
sample_input,
&ksk_ds,
- noise_simulation_modulus_switch_config,
- drift_key,
+ modulus_switch_config,
compute_br_input_modulus_log,
&mut (),
);
diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs
index 329fcacd4..bb8c0c65f 100644
--- a/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs
+++ b/tfhe/src/shortint/server_key/tests/noise_distribution/dp_ks_ms.rs
@@ -22,8 +22,7 @@ use rayon::prelude::*;
pub fn any_ms (
let lwe_per_glwe = noise_squashing_compression_key.lwe_per_glwe();
- let noise_simulation_modulus_switch_config =
- noise_squashing_key.noise_simulation_modulus_switch_config();
- let drift_key = match noise_simulation_modulus_switch_config {
- NoiseSimulationModulusSwitchConfig::Standard => None,
- NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
- Some(&noise_squashing_key)
- }
- NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
- };
+ let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
let br_input_modulus_log = noise_squashing_key.br_input_modulus_log();
@@ -304,8 +292,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks (
input_zero_as_lwe,
max_scalar_mul,
&sks,
- noise_simulation_modulus_switch_config,
- drift_key,
+ modulus_switch_config,
&noise_squashing_key,
br_input_modulus_log,
&id_lut,
@@ -427,15 +414,8 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
)
};
- let noise_simulation_modulus_switch_config =
- noise_squashing_key.noise_simulation_modulus_switch_config();
- let drift_key = match noise_simulation_modulus_switch_config {
- NoiseSimulationModulusSwitchConfig::Standard => None,
- NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
- Some(noise_squashing_key)
- }
- NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
- };
+ let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
+
let bsk_polynomial_size = noise_squashing_key.polynomial_size();
let bsk_glwe_size = noise_squashing_key.glwe_size();
let br_input_modulus_log = noise_squashing_key.br_input_modulus_log();
@@ -470,8 +450,7 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
inputs,
scalar_for_multiplication,
sks,
- noise_simulation_modulus_switch_config,
- drift_key,
+ modulus_switch_config,
noise_squashing_key,
br_input_modulus_log,
&id_lut,
@@ -712,8 +691,8 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise (
let noise_simulation_ksk =
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
- let noise_simulation_drift_key =
- NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
+ let noise_simulation_modulus_switch_config =
+ NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk128 =
NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params);
let noise_simulation_packing_key = NoiseSimulationLwePackingKeyswitchKey::new_from_params(
@@ -721,26 +700,11 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise (
noise_squashing_compression_params,
);
- let noise_simulation_modulus_switch_config =
- noise_squashing_key.noise_simulation_modulus_switch_config();
- let drift_key = match noise_simulation_modulus_switch_config {
- NoiseSimulationModulusSwitchConfig::Standard => None,
- NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
- Some(&noise_squashing_key)
- }
- NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
- };
+ let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
- match (noise_simulation_drift_key, drift_key) {
- (Some(noise_simulation_drift_key), Some(drift_key)) => {
- assert!(
- noise_simulation_drift_key.matches_actual_shortint_noise_squashing_key(drift_key)
- )
- }
- (None, None) => (),
- _ => panic!("Inconsistent Drift Key configuration"),
- }
+ assert!(noise_simulation_modulus_switch_config
+ .matches_shortint_noise_squashing_modulus_switch_config(modulus_switch_config));
assert!(
noise_simulation_bsk128.matches_actual_shortint_noise_squashing_key(&noise_squashing_key)
);
@@ -766,8 +730,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise (
vec![noise_simulation; noise_squashing_compression_key.lwe_per_glwe().0],
max_scalar_mul,
&noise_simulation_ksk,
- noise_simulation_modulus_switch_config,
- noise_simulation_drift_key.as_ref(),
+ noise_simulation_modulus_switch_config.as_ref(),
&noise_simulation_bsk128,
br_input_modulus_log,
&noise_simulation_accumulator,
diff --git a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs
index 592e73b10..354f66360 100644
--- a/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs
+++ b/tfhe/src/shortint/server_key/tests/noise_distribution/utils/noise_simulation.rs
@@ -462,36 +462,129 @@ impl AllocateLweKeyswitchResult for ServerKey {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
-pub enum NoiseSimulationModulusSwitchConfig {
+pub enum NoiseSimulationModulusSwitchConfig