refactor(test): make modulus switch config system make more sense

- The config type can hold any type for the drift technique variant because
the bounds are too weird to set on the type, the functions making use of
the config type should properly declare the bounds
This commit is contained in:
Arthur Meyre
2025-10-27 16:23:13 +01:00
committed by IceTDrinker
parent d95b46cb9b
commit a41cd47b9e
6 changed files with 211 additions and 248 deletions

View File

@@ -1,8 +1,7 @@
use super::dp_ks_ms::dp_ks_any_ms;
use super::utils::noise_simulation::{
NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey,
NoiseSimulationModulusSwitchConfig,
NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
use super::utils::{
@@ -49,8 +48,7 @@ pub fn br_dp_ks_any_ms<
bsk: &PBSKey,
scalar: DPScalar,
ksk: &KsKey,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
accumulator: &Accumulator,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
@@ -96,7 +94,6 @@ where
scalar,
ksk,
modulus_switch_configuration,
mod_switch_noise_reduction_key,
br_input_modulus_log,
side_resources,
);
@@ -126,12 +123,7 @@ where
let id_lut = sks.generate_lookup_table(|x| x);
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
for _ in 0..10 {
let input_zero_as_lwe = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0);
@@ -142,8 +134,7 @@ where
&sks,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&id_lut,
br_input_modulus_log,
&mut (),
@@ -210,12 +201,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper(
};
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let ct = cks.encrypt_noiseless_pbs_input_dyn_lwe(br_input_modulus_log, 0);
@@ -226,8 +212,7 @@ fn encrypt_br_dp_ks_any_ms_inner_helper(
sks,
scalar_for_multiplication,
sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&id_lut,
br_input_modulus_log,
&mut (),
@@ -418,30 +403,19 @@ where
let noise_simulation_ksk =
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk =
NoiseSimulationLweFourierBsk::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let br_input_modulus_log = sks.br_input_modulus_log();
let expected_average_after_ms =
noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size());
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
modulus_switch_config.expected_average_after_ms(params.polynomial_size());
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
match (noise_simulation_drift_key, drift_key) {
(Some(noise_simulation_drift_key), Some(drift_key)) => {
assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key))
}
(None, None) => (),
_ => panic!("Inconsistent Drift Key configuration"),
}
assert!(noise_simulation_modulus_switch_config
.matches_shortint_server_key_modulus_switch_config(modulus_switch_config));
assert!(noise_simulation_bsk.matches_actual_shortint_server_key(&sks));
let max_scalar_mul = sks.max_noise_level.get();
@@ -464,8 +438,7 @@ where
&noise_simulation_bsk,
max_scalar_mul,
&noise_simulation_ksk,
noise_simulation_modulus_switch_config,
noise_simulation_drift_key.as_ref(),
noise_simulation_modulus_switch_config.as_ref(),
&noise_simulation_accumulator,
br_input_modulus_log,
&mut (),
@@ -483,8 +456,7 @@ where
&sks,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&id_lut,
br_input_modulus_log,
&mut (),

View File

@@ -1,8 +1,7 @@
use super::dp_ks_ms::any_ms;
use super::utils::noise_simulation::{
DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationGlwe, NoiseSimulationLwe,
NoiseSimulationLweFourierBsk, NoiseSimulationLweKeyswitchKey,
NoiseSimulationModulusSwitchConfig,
DynLwe, NoiseSimulationGlwe, NoiseSimulationLwe, NoiseSimulationLweFourierBsk,
NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
use super::utils::{
@@ -73,8 +72,7 @@ pub fn br_rerand_dp_ks_any_ms<
ksk_rerand: &KsKeyRerand,
scalar: DPScalar,
ksk: &KsKey,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
decomp_accumulator: &Accumulator,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
@@ -139,7 +137,6 @@ where
let (drift_technique_result, ms_result) = any_ms(
&ks_result,
modulus_switch_configuration,
mod_switch_noise_reduction_key,
br_input_modulus_log,
side_resources,
);
@@ -234,12 +231,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper(
};
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let ct = comp_private_key.encrypt_noiseless_decompression_input_dyn_lwe(cks, 0, &mut engine);
@@ -280,8 +272,7 @@ fn encrypt_decomp_br_rerand_dp_ks_any_ms_inner_helper(
ksk_rerand,
scalar_for_multiplication,
sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&decomp_rescale_lut,
br_input_modulus_log,
&mut (),
@@ -631,31 +622,20 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise<P>(
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_ksk_rerand =
NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, rerand_ksk_params, params);
let noise_simulation_drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_decomp_bsk =
NoiseSimulationLweFourierBsk::new_from_comp_parameters(params, compression_params);
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let compute_br_input_modulus_log = sks.br_input_modulus_log();
let expected_average_after_ms =
noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size());
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
modulus_switch_config.expected_average_after_ms(params.polynomial_size());
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
assert!(noise_simulation_ksk_rerand.matches_actual_shortint_keyswitching_key(&ksk_rerand));
match (noise_simulation_drift_key, drift_key) {
(Some(noise_simulation_drift_key), Some(drift_key)) => {
assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key))
}
(None, None) => (),
_ => panic!("Inconsistent Drift Key configuration"),
}
assert!(noise_simulation_modulus_switch_config
.matches_shortint_server_key_modulus_switch_config(modulus_switch_config));
assert!(noise_simulation_decomp_bsk.matches_actual_shortint_decomp_key(&decomp_key));
let max_scalar_mul = sks.max_noise_level.get();
@@ -691,8 +671,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise<P>(
&noise_simulation_ksk_rerand,
max_scalar_mul,
&noise_simulation_ksk,
noise_simulation_modulus_switch_config,
noise_simulation_drift_key.as_ref(),
noise_simulation_modulus_switch_config.as_ref(),
&noise_simulation_accumulator,
compute_br_input_modulus_log,
&mut (),
@@ -738,8 +717,7 @@ fn noise_check_encrypt_br_rerand_dp_ks_ms_noise<P>(
&ksk_rerand,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&decomp_rescale_lut,
compute_br_input_modulus_log,
&mut (),
@@ -980,15 +958,9 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs<P>(
KeySwitchingKeyBuildHelper::new((&cpk_private_key, None), (&cks, &sks), rerand_ksk_params);
let ksk_rerand: KeySwitchingKeyView<'_> = ksk_rerand_builder.as_key_switching_key_view();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let compute_br_input_modulus_log = sks.br_input_modulus_log();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let max_scalar_mul = sks.max_noise_level.get();
let decomp_rescale_lut = decomp_key.rescaling_lut(
@@ -1095,8 +1067,7 @@ fn sanity_check_encrypt_br_rerand_dp_ks_ms_pbs<P>(
&ksk_rerand,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&decomp_rescale_lut,
compute_br_input_modulus_log,
&mut (),

View File

@@ -1,7 +1,6 @@
use super::dp_ks_ms::any_ms;
use super::utils::noise_simulation::{
DynLwe, NoiseSimulationDriftTechniqueKey, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey,
NoiseSimulationModulusSwitchConfig,
DynLwe, NoiseSimulationLwe, NoiseSimulationLweKeyswitchKey, NoiseSimulationModulusSwitchConfig,
};
use super::utils::traits::*;
use super::utils::{
@@ -44,8 +43,7 @@ pub fn cpk_ks_any_ms<
>(
input: InputCt,
ksk_ds: &KsKeyDs,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
) -> (InputCt, KsResult, Option<DriftTechniqueResult>, MsResult)
@@ -76,7 +74,6 @@ where
let (drift_technique_result, ms_result) = any_ms(
&ks_result,
modulus_switch_configuration,
mod_switch_noise_reduction_key,
br_input_modulus_log,
side_resources,
);
@@ -141,12 +138,7 @@ fn cpk_ks_any_ms_inner_helper(
};
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let ct = {
let compact_list = cpk.encrypt_iter_with_modulus_with_engine(
@@ -165,8 +157,7 @@ fn cpk_ks_any_ms_inner_helper(
let (input, after_ks_ds, after_drift, after_ms) = cpk_ks_any_ms(
ct,
ksk_ds,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
br_input_modulus_log,
&mut (),
);
@@ -365,37 +356,25 @@ fn noise_check_encrypt_cpk_ks_ms_noise<P>(
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_ksk_ds =
NoiseSimulationLweKeyswitchKey::new_from_cpk_params(cpk_params, ksk_ds_params, params);
let noise_simulation_drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let compute_br_input_modulus_log = sks.br_input_modulus_log();
let expected_average_after_ms =
noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size());
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
modulus_switch_config.expected_average_after_ms(params.polynomial_size());
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
assert!(noise_simulation_ksk_ds.matches_actual_shortint_keyswitching_key(&ksk_ds));
match (noise_simulation_drift_key, drift_key) {
(Some(noise_simulation_drift_key), Some(drift_key)) => {
assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key))
}
(None, None) => (),
_ => panic!("Inconsistent Drift Key configuration"),
}
assert!(noise_simulation_modulus_switch_config
.matches_shortint_server_key_modulus_switch_config(modulus_switch_config));
let (_input_sim, _after_ks_ds_sim, _after_drift_sim, after_ms_sim) = {
let noise_simulation_input = NoiseSimulationLwe::encrypt_with_cpk(&cpk);
cpk_ks_any_ms(
noise_simulation_input,
&noise_simulation_ksk_ds,
noise_simulation_modulus_switch_config,
noise_simulation_drift_key.as_ref(),
noise_simulation_modulus_switch_config.as_ref(),
compute_br_input_modulus_log,
&mut (),
)
@@ -417,8 +396,7 @@ fn noise_check_encrypt_cpk_ks_ms_noise<P>(
let (_input, _after_ks_ds, _before_ms, after_ms) = cpk_ks_any_ms(
sample_input,
&ksk_ds,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
compute_br_input_modulus_log,
&mut (),
);

View File

@@ -22,8 +22,7 @@ use rayon::prelude::*;
pub fn any_ms<InputCt, DriftTechniqueResult, MsResult, DriftKey, Resources>(
input: &InputCt,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
) -> (Option<DriftTechniqueResult>, MsResult)
@@ -46,10 +45,9 @@ where
SideResources = Resources,
>,
{
match (modulus_switch_configuration, mod_switch_noise_reduction_key) {
(
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction,
Some(mod_switch_noise_reduction_key),
match modulus_switch_configuration {
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(
mod_switch_noise_reduction_key,
) => {
let (mut drift_technique_result, mut ms_result) = mod_switch_noise_reduction_key
.allocate_drift_technique_standard_mod_switch_result(side_resources);
@@ -63,13 +61,13 @@ where
(Some(drift_technique_result), ms_result)
}
(NoiseSimulationModulusSwitchConfig::Standard, None) => {
NoiseSimulationModulusSwitchConfig::Standard => {
let mut ms_result = input.allocate_standard_mod_switch_result(side_resources);
input.standard_mod_switch(br_input_modulus_log, &mut ms_result, side_resources);
(None, ms_result)
}
(NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction, None) => {
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => {
let mut ms_result =
input.allocate_centered_binary_shifted_standard_mod_switch_result(side_resources);
input.centered_binary_shifted_and_standard_mod_switch(
@@ -80,7 +78,6 @@ where
(None, ms_result)
}
_ => panic!("Inconsistent modulus switch and drift key configuration"),
}
}
@@ -98,8 +95,7 @@ pub fn dp_ks_any_ms<
input: InputCt,
scalar: DPScalar,
ksk: &KsKey,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
br_input_modulus_log: CiphertextModulusLog,
side_resources: &mut Resources,
) -> (
@@ -140,7 +136,6 @@ where
let (drift_technique_result, ms_result) = any_ms(
&ks_result,
modulus_switch_configuration,
mod_switch_noise_reduction_key,
br_input_modulus_log,
side_resources,
);
@@ -169,12 +164,7 @@ where
let id_lut = sks.generate_lookup_table(|x| x);
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
for _ in 0..10 {
let input_zero = cks.encrypt(0);
@@ -184,8 +174,7 @@ where
input_zero_as_lwe,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
br_input_modulus_log,
&mut (),
);
@@ -234,12 +223,7 @@ fn encrypt_dp_ks_any_ms_inner_helper(
};
let br_input_modulus_log = sks.br_input_modulus_log();
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let ct = DynLwe::U64(cks.unchecked_encrypt(msg).ct);
@@ -247,8 +231,7 @@ fn encrypt_dp_ks_any_ms_inner_helper(
ct,
scalar_for_multiplication,
sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
br_input_modulus_log,
&mut (),
);
@@ -419,28 +402,17 @@ where
let noise_simulation_ksk =
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let modulus_switch_config = sks.noise_simulation_modulus_switch_config();
let br_input_modulus_log = sks.br_input_modulus_log();
let expected_average_after_ms =
noise_simulation_modulus_switch_config.expected_average_after_ms(params.polynomial_size());
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => Some(&sks),
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
modulus_switch_config.expected_average_after_ms(params.polynomial_size());
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
match (noise_simulation_drift_key, drift_key) {
(Some(noise_simulation_drift_key), Some(drift_key)) => {
assert!(noise_simulation_drift_key.matches_actual_shortint_server_key(drift_key))
}
(None, None) => (),
_ => panic!("Inconsistent Drift Key configuration"),
}
assert!(noise_simulation_modulus_switch_config
.matches_shortint_server_key_modulus_switch_config(modulus_switch_config));
let max_scalar_mul = sks.max_noise_level.get();
@@ -450,8 +422,7 @@ where
noise_simulation,
max_scalar_mul,
&noise_simulation_ksk,
noise_simulation_modulus_switch_config,
noise_simulation_drift_key.as_ref(),
noise_simulation_modulus_switch_config.as_ref(),
br_input_modulus_log,
&mut (),
)
@@ -466,8 +437,7 @@ where
sample_input,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
br_input_modulus_log,
&mut (),
);

View File

@@ -43,8 +43,7 @@ fn dp_ks_any_ms_standard_pbs128<
input: InputCt,
scalar: DPScalar,
ksk: &KsKey,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key_128: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
bsk_128: &Bsk,
br_input_modulus_log: CiphertextModulusLog,
accumulator: &Accumulator,
@@ -93,7 +92,6 @@ where
let (drift_technique_result, ms_result) = any_ms(
&ks_result,
modulus_switch_configuration,
mod_switch_noise_reduction_key_128,
br_input_modulus_log,
side_resources,
);
@@ -131,8 +129,7 @@ fn dp_ks_any_ms_standard_pbs128_packing_ks<
input: Vec<InputCt>,
scalar: DPScalar,
ksk: &KsKey,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig,
mod_switch_noise_reduction_key_128: Option<&DriftKey>,
modulus_switch_configuration: NoiseSimulationModulusSwitchConfig<&DriftKey>,
bsk_128: &Bsk,
br_input_modulus_log: CiphertextModulusLog,
accumulator: &Accumulator,
@@ -199,7 +196,6 @@ where
scalar,
ksk,
modulus_switch_configuration,
mod_switch_noise_reduction_key_128,
bsk_128,
br_input_modulus_log,
accumulator,
@@ -260,15 +256,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks<P>(
let lwe_per_glwe = noise_squashing_compression_key.lwe_per_glwe();
let noise_simulation_modulus_switch_config =
noise_squashing_key.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
Some(&noise_squashing_key)
}
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
let br_input_modulus_log = noise_squashing_key.br_input_modulus_log();
@@ -304,8 +292,7 @@ fn sanity_check_encrypt_dp_ks_standard_pbs128_packing_ks<P>(
input_zero_as_lwe,
max_scalar_mul,
&sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
&noise_squashing_key,
br_input_modulus_log,
&id_lut,
@@ -427,15 +414,8 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
)
};
let noise_simulation_modulus_switch_config =
noise_squashing_key.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
Some(noise_squashing_key)
}
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
let bsk_polynomial_size = noise_squashing_key.polynomial_size();
let bsk_glwe_size = noise_squashing_key.glwe_size();
let br_input_modulus_log = noise_squashing_key.br_input_modulus_log();
@@ -470,8 +450,7 @@ fn encrypt_dp_ks_standard_pbs128_packing_ks_inner_helper(
inputs,
scalar_for_multiplication,
sks,
noise_simulation_modulus_switch_config,
drift_key,
modulus_switch_config,
noise_squashing_key,
br_input_modulus_log,
&id_lut,
@@ -712,8 +691,8 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise<P>(
let noise_simulation_ksk =
NoiseSimulationLweKeyswitchKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
let noise_simulation_modulus_switch_config =
NoiseSimulationModulusSwitchConfig::new_from_atomic_pattern_parameters(params);
let noise_simulation_bsk128 =
NoiseSimulationLweFourier128Bsk::new_from_parameters(params, noise_squashing_params);
let noise_simulation_packing_key = NoiseSimulationLwePackingKeyswitchKey::new_from_params(
@@ -721,26 +700,11 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise<P>(
noise_squashing_compression_params,
);
let noise_simulation_modulus_switch_config =
noise_squashing_key.noise_simulation_modulus_switch_config();
let drift_key = match noise_simulation_modulus_switch_config {
NoiseSimulationModulusSwitchConfig::Standard => None,
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction => {
Some(&noise_squashing_key)
}
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction => None,
};
let modulus_switch_config = noise_squashing_key.noise_simulation_modulus_switch_config();
assert!(noise_simulation_ksk.matches_actual_shortint_server_key(&sks));
match (noise_simulation_drift_key, drift_key) {
(Some(noise_simulation_drift_key), Some(drift_key)) => {
assert!(
noise_simulation_drift_key.matches_actual_shortint_noise_squashing_key(drift_key)
)
}
(None, None) => (),
_ => panic!("Inconsistent Drift Key configuration"),
}
assert!(noise_simulation_modulus_switch_config
.matches_shortint_noise_squashing_modulus_switch_config(modulus_switch_config));
assert!(
noise_simulation_bsk128.matches_actual_shortint_noise_squashing_key(&noise_squashing_key)
);
@@ -766,8 +730,7 @@ fn noise_check_encrypt_dp_ks_standard_pbs128_packing_ks_noise<P>(
vec![noise_simulation; noise_squashing_compression_key.lwe_per_glwe().0],
max_scalar_mul,
&noise_simulation_ksk,
noise_simulation_modulus_switch_config,
noise_simulation_drift_key.as_ref(),
noise_simulation_modulus_switch_config.as_ref(),
&noise_simulation_bsk128,
br_input_modulus_log,
&noise_simulation_accumulator,

View File

@@ -462,36 +462,129 @@ impl AllocateLweKeyswitchResult for ServerKey {
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NoiseSimulationModulusSwitchConfig {
pub enum NoiseSimulationModulusSwitchConfig<DriftKey> {
Standard,
DriftTechniqueNoiseReduction,
DriftTechniqueNoiseReduction(DriftKey),
CenteredMeanNoiseReduction,
}
impl NoiseSimulationModulusSwitchConfig {
pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 {
match self {
Self::Standard => 0.0f64,
Self::DriftTechniqueNoiseReduction => 0.0f64,
Self::CenteredMeanNoiseReduction => {
// Half case subtracted before entering the blind rotate
-1.0f64 / (4.0 * polynomial_size.0 as f64)
impl NoiseSimulationModulusSwitchConfig<NoiseSimulationDriftTechniqueKey> {
pub fn new_from_atomic_pattern_parameters(params: AtomicPatternParameters) -> Self {
let drift_key =
NoiseSimulationDriftTechniqueKey::new_from_atomic_pattern_parameters(params);
match params {
AtomicPatternParameters::Standard(pbsparameters) => match pbsparameters {
PBSParameters::PBS(classic_pbsparameters) => {
match classic_pbsparameters.modulus_switch_noise_reduction_params {
ModulusSwitchType::Standard => Self::Standard,
ModulusSwitchType::DriftTechniqueNoiseReduction(_) => {
Self::DriftTechniqueNoiseReduction(
drift_key.expect("Invalid drift key configuration"),
)
}
ModulusSwitchType::CenteredMeanNoiseReduction => {
Self::CenteredMeanNoiseReduction
}
}
}
PBSParameters::MultiBitPBS(_) => {
panic!(
"Unsupported ShortintBootstrappingKey::MultiBit \
for NoiseSimulationModulusSwitchConfig"
)
}
},
AtomicPatternParameters::KeySwitch32(key_switch32_pbsparameters) => {
match &key_switch32_pbsparameters.modulus_switch_noise_reduction_params {
ModulusSwitchType::Standard => Self::Standard,
ModulusSwitchType::DriftTechniqueNoiseReduction(_) => {
Self::DriftTechniqueNoiseReduction(
drift_key.expect("Invalid drift key configuration"),
)
}
ModulusSwitchType::CenteredMeanNoiseReduction => {
Self::CenteredMeanNoiseReduction
}
}
}
}
}
pub fn matches_shortint_server_key_modulus_switch_config(
&self,
shortint_config: NoiseSimulationModulusSwitchConfig<&ServerKey>,
) -> bool {
match (self, shortint_config) {
(Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true,
(
Self::DriftTechniqueNoiseReduction(noise_sim),
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sks),
) => noise_sim.matches_actual_shortint_server_key(sks),
(
Self::CenteredMeanNoiseReduction,
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction,
) => true,
_ => false,
}
}
pub fn matches_shortint_noise_squashing_modulus_switch_config(
&self,
shortint_config: NoiseSimulationModulusSwitchConfig<&NoiseSquashingKey>,
) -> bool {
match (self, shortint_config) {
(Self::Standard, NoiseSimulationModulusSwitchConfig::Standard) => true,
(
Self::DriftTechniqueNoiseReduction(noise_sim),
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(sns),
) => noise_sim.matches_actual_shortint_noise_squashing_key(sns),
(
Self::CenteredMeanNoiseReduction,
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction,
) => true,
_ => false,
}
}
}
impl<DriftKey> NoiseSimulationModulusSwitchConfig<DriftKey> {
fn new_from_config_and_key<Scalar: UnsignedInteger>(
config: &ModulusSwitchConfiguration<Scalar>,
key: DriftKey,
) -> Self {
match config {
ModulusSwitchConfiguration::Standard => Self::Standard,
ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => {
Self::DriftTechniqueNoiseReduction(key)
}
ModulusSwitchConfiguration::CenteredMeanNoiseReduction => {
Self::CenteredMeanNoiseReduction
}
}
}
}
impl<Scalar: UnsignedInteger> From<&ModulusSwitchConfiguration<Scalar>>
for NoiseSimulationModulusSwitchConfig
{
fn from(value: &ModulusSwitchConfiguration<Scalar>) -> Self {
match value {
ModulusSwitchConfiguration::Standard => Self::Standard,
ModulusSwitchConfiguration::DriftTechniqueNoiseReduction(_) => {
Self::DriftTechniqueNoiseReduction
impl<DriftKey> NoiseSimulationModulusSwitchConfig<DriftKey> {
pub fn as_ref(&self) -> NoiseSimulationModulusSwitchConfig<&DriftKey> {
match self {
Self::Standard => NoiseSimulationModulusSwitchConfig::Standard,
Self::DriftTechniqueNoiseReduction(key) => {
NoiseSimulationModulusSwitchConfig::DriftTechniqueNoiseReduction(key)
}
ModulusSwitchConfiguration::CenteredMeanNoiseReduction => {
Self::CenteredMeanNoiseReduction
Self::CenteredMeanNoiseReduction => {
NoiseSimulationModulusSwitchConfig::CenteredMeanNoiseReduction
}
}
}
pub fn expected_average_after_ms(self, polynomial_size: PolynomialSize) -> f64 {
match self {
Self::Standard => 0.0f64,
Self::DriftTechniqueNoiseReduction(_) => 0.0f64,
Self::CenteredMeanNoiseReduction => {
// Half case subtracted before entering the blind rotate
-1.0f64 / (4.0 * polynomial_size.0 as f64)
}
}
}
@@ -518,16 +611,21 @@ impl ServerKey {
}
}
pub fn noise_simulation_modulus_switch_config(&self) -> NoiseSimulationModulusSwitchConfig {
pub fn noise_simulation_modulus_switch_config(
&self,
) -> NoiseSimulationModulusSwitchConfig<&Self> {
match &self.atomic_pattern {
AtomicPatternServerKey::Standard(standard_atomic_pattern_server_key) => {
match &standard_atomic_pattern_server_key.bootstrapping_key {
ShortintBootstrappingKey::Classic {
bsk: _,
modulus_switch_noise_reduction_key,
} => modulus_switch_noise_reduction_key.into(),
} => NoiseSimulationModulusSwitchConfig::new_from_config_and_key(
modulus_switch_noise_reduction_key,
self,
),
ShortintBootstrappingKey::MultiBit { .. } => {
todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation")
panic!("MultiBit ServerKey does not support the drift technique")
}
}
}
@@ -536,9 +634,12 @@ impl ServerKey {
ShortintBootstrappingKey::Classic {
bsk: _,
modulus_switch_noise_reduction_key,
} => modulus_switch_noise_reduction_key.into(),
} => NoiseSimulationModulusSwitchConfig::new_from_config_and_key(
modulus_switch_noise_reduction_key,
self,
),
ShortintBootstrappingKey::MultiBit { .. } => {
todo!("Unsupported ShortintBootstrappingKey::MultiBit for noise simulation")
panic!("MultiBit ServerKey does not support the drift technique")
}
}
}
@@ -839,7 +940,9 @@ impl<C: Container<Element = u64>> LweClassicFftBootstrap<DynLwe, DynLwe, LookupT
}
impl NoiseSquashingKey {
pub fn noise_simulation_modulus_switch_config(&self) -> NoiseSimulationModulusSwitchConfig {
pub fn noise_simulation_modulus_switch_config(
&self,
) -> NoiseSimulationModulusSwitchConfig<&Self> {
match self.atomic_pattern() {
AtomicPatternNoiseSquashingKey::Standard(
standard_atomic_pattern_noise_squashing_key,
@@ -847,7 +950,10 @@ impl NoiseSquashingKey {
Shortint128BootstrappingKey::Classic {
bsk: _,
modulus_switch_noise_reduction_key,
} => modulus_switch_noise_reduction_key.into(),
} => NoiseSimulationModulusSwitchConfig::new_from_config_and_key(
modulus_switch_noise_reduction_key,
self,
),
Shortint128BootstrappingKey::MultiBit { .. } => {
panic!("MultiBit ServerKey does not support the drift technique")
}
@@ -858,7 +964,10 @@ impl NoiseSquashingKey {
Shortint128BootstrappingKey::Classic {
bsk: _,
modulus_switch_noise_reduction_key,
} => modulus_switch_noise_reduction_key.into(),
} => NoiseSimulationModulusSwitchConfig::new_from_config_and_key(
modulus_switch_noise_reduction_key,
self,
),
Shortint128BootstrappingKey::MultiBit { .. } => {
panic!("MultiBit ServerKey does not support the drift technique")
}