feat(optimizer): precompute mins in cb_pareto

This commit is contained in:
Mayeul@Zama
2022-12-23 14:37:33 +01:00
committed by mayeul-zama
parent b9678a37b4
commit d6e69f878c
2 changed files with 50 additions and 32 deletions

View File

@@ -87,15 +87,24 @@ pub fn pareto_quantities(
quantities
}
pub type Cache = CacheHashMap<GlweParameters, Vec<CbComplexityNoise>>;
pub type Cache = CacheHashMap<GlweParameters, CbPareto>;
impl Cache {
pub fn pareto_quantities(&mut self, glwe_params: GlweParameters) -> &[CbComplexityNoise] {
pub fn pareto_quantities(&mut self, glwe_params: GlweParameters) -> &CbPareto {
self.get(glwe_params)
}
}
pub type PersistDecompCache = PersistentCacheHashMap<GlweParameters, Vec<CbComplexityNoise>>;
pub type PersistDecompCache = PersistentCacheHashMap<GlweParameters, CbPareto>;
#[derive(Clone, Serialize, Deserialize)]
pub struct CbPareto {
pub pareto: Vec<CbComplexityNoise>,
pub lower_pareto_cb_bias: f64,
pub lower_pareto_cb_slope: f64,
pub lower_bound_cost_cb_complexity_1_cmux_hp: f64,
pub lower_bound_cost_cb_complexity_1_ggsw_to_fft: f64,
}
pub fn cache(
security_level: u64,
@@ -108,11 +117,40 @@ pub fn cache(
let path =
format!("{cache_dir}/cb-decomp-{hardware}-{ciphertext_modulus_log}-{security_level}");
let function = move |glwe_params| {
pareto_quantities(
let pareto = pareto_quantities(
complexity_model.as_ref(),
ciphertext_modulus_log,
glwe_params,
)
);
let lower_pareto_cb_bias = pareto
.iter()
.map(|cb| cb.variance_bias)
.reduce(f64::min)
.unwrap();
let lower_pareto_cb_slope = pareto
.iter()
.map(|cb| cb.variance_ggsw_factor)
.reduce(f64::min)
.unwrap();
let lower_bound_cost_cb_complexity_1_cmux_hp = pareto
.iter()
.map(|cb| cb.complexity_one_cmux_hp)
.reduce(f64::min)
.unwrap();
let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto
.iter()
.map(|cb| cb.complexity_one_ggsw_to_fft)
.reduce(f64::min)
.unwrap();
CbPareto {
pareto,
lower_pareto_cb_bias,
lower_pareto_cb_slope,
lower_bound_cost_cb_complexity_1_cmux_hp,
lower_bound_cost_cb_complexity_1_ggsw_to_fft,
}
};
PersistentCacheHashMap::new_no_read(&path, VERSION, function)
}

View File

@@ -11,7 +11,7 @@ use crate::optimization::atomic_pattern;
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise;
use crate::optimization::decomposition::circuit_bootstrap::{CbComplexityNoise, CbPareto};
use crate::optimization::decomposition::cmux::CmuxComplexityNoise;
use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::optimization::decomposition::pp_switch::PpSwitchComplexityNoise;
@@ -197,7 +197,7 @@ fn update_state_with_best_decompositions(
pareto_cmux: &[CmuxComplexityNoise],
pareto_keyswitch: &[KsComplexityNoise],
pp_switch: &[PpSwitchComplexityNoise],
pareto_cb: &[CbComplexityNoise],
pareto_cb: &CbPareto,
) {
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
let precisions_sum = partitionning.iter().copied().sum();
@@ -231,38 +231,18 @@ fn update_state_with_best_decompositions(
let lower_bound_variance_br = pareto_cmux.last().unwrap().noise_br(internal_dim);
let lower_bound_variance_ks = pareto_keyswitch.last().unwrap().noise(input_lwe_dimension);
let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise;
let lower_pareto_cb_bias = pareto_cb
.iter()
.map(|cb| cb.variance_bias)
.reduce(f64::min)
.unwrap();
let lower_pareto_cb_slope = pareto_cb
.iter()
.map(|cb| cb.variance_ggsw_factor)
.reduce(f64::min)
.unwrap();
let lower_bound_cost_br = pareto_cmux[0].complexity_br(internal_dim);
let lower_bound_cost_ks = pareto_keyswitch[0].complexity(input_lwe_dimension);
let lower_bound_cost_pp = pp_switch[0].complexity;
let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb
.iter()
.map(|cb| cb.complexity_one_cmux_hp)
.reduce(f64::min)
.unwrap_or(0.0);
let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb
.iter()
.map(|cb| cb.complexity_one_ggsw_to_fft)
.reduce(f64::min)
.unwrap_or(0.0);
let lower_bound_cb = CbComplexityNoise {
decomp: BrDecompositionParameters {
level: 1,
log2_base: 1,
},
complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp,
complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft,
variance_bias: lower_pareto_cb_bias,
variance_ggsw_factor: lower_pareto_cb_slope,
complexity_one_cmux_hp: pareto_cb.lower_bound_cost_cb_complexity_1_cmux_hp,
complexity_one_ggsw_to_fft: pareto_cb.lower_bound_cost_cb_complexity_1_ggsw_to_fft,
variance_bias: pareto_cb.lower_pareto_cb_bias,
variance_ggsw_factor: pareto_cb.lower_pareto_cb_slope,
};
let variance = |cmux_quantity: Option<CmuxComplexityNoise>,
@@ -348,7 +328,7 @@ fn update_state_with_best_decompositions(
// CircuitBootstrap: new parameters l,b
// for &circuit_pbs_decomposition in pareto_circuit_pbs {
for cb_decomp in pareto_cb {
for cb_decomp in &pareto_cb.pareto {
let lower_bound_variance = variance(
Some(cmux_decomp),
Some(pp_switching.noise),