chore: woppbs, variance and complexity functions

This commit is contained in:
rudy
2022-10-26 14:15:27 +02:00
committed by rudy-6-4
parent 15237c4550
commit b8e7c04469

View File

@@ -9,9 +9,9 @@ use crate::optimization::atomic_pattern;
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise;
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
use crate::parameters::{GlweParameters, LweDimension, PbsParameters};
use crate::utils::square;
use crate::parameters::{BrDecompositionParameters, GlweParameters};
use concrete_commons::dispersion::{DispersionParameter, Variance};
pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {
@@ -94,6 +94,89 @@ impl From<Solution> for atomic_pattern::Solution {
}
}
fn estimate_variance(
br_variance: f64,
pp_variance: f64,
cb_decomp: &CbComplexityNoise,
ks_variance: f64,
variance_modulus_switching: f64,
log_norm: f64,
precisions_sum: u64,
max_precision: u64,
) -> f64 {
assert!(max_precision <= precisions_sum);
let variance_ggsw = pp_variance + br_variance / 2.;
let variance_coeff_1_cmux_tree =
2_f64.powf(2. * log_norm) // variance_coeff for the multisum
* (precisions_sum // for hybrid packing
<< (2 * (max_precision - 1))) as f64 // for left shift
;
let variance_one_external_product_for_cmux_tree = cb_decomp.variance_from_ggsw(variance_ggsw);
variance_modulus_switching
+ variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree
+ ks_variance
}
fn estimate_complexity(
glwe_params: &GlweParameters,
br_cost: f64,
pp_cost: f64,
cb_decomp: &CbComplexityNoise,
ks_cost: f64,
precisions_sum: u64,
nb_blocks: u64,
n_functions: u64,
) -> f64 {
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
// Hybrid packing
let cb_level = cb_decomp.decomp.level;
let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp;
let complexity_1_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft;
// BitExtract use br
let complexity_bit_extract_1_pbs = br_cost;
let complexity_bit_extract_wo_ks =
(precisions_sum - nb_blocks) as f64 * complexity_bit_extract_1_pbs;
// Hybrid packing
// Circuit bs: fp-ks
let complexity_ppks = pp_cost;
let complexity_all_ppks =
((glwe_params.glwe_dimension + 1) * cb_level * precisions_sum) as f64 * complexity_ppks;
// Circuit bs: pbs
let complexity_all_pbs = (precisions_sum * cb_level) as f64 * br_cost;
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
// Hybrid packing (Do we have 1 or 2 groups)
let log2_polynomial_size = glwe_params.log2_polynomial_size;
// Size of cmux_group, can be zero
let cmux_group_count = if precisions_sum > log2_polynomial_size {
2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32)
} else {
0.0
};
let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp;
let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_1_ggsw_to_fft;
// Hybrid packing blind rotate
let complexity_g_br =
complexity_1_cmux_hp * u64::min(glwe_params.log2_polynomial_size, precisions_sum) as f64;
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
let complexity_multi_hybrid_packing =
n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft;
let complexity_all_ks = precisions_sum as f64 * ks_cost;
complexity_bit_extract_wo_ks
+ complexity_circuit_bs
+ complexity_multi_hybrid_packing
+ complexity_all_ks
}
#[allow(clippy::too_many_lines)]
fn update_state_with_best_decompositions(
state: &mut OptimizationState,
@@ -107,6 +190,7 @@ fn update_state_with_best_decompositions(
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
let precisions_sum = partitionning.iter().copied().sum();
let max_precision = partitionning.iter().copied().max().unwrap();
let nb_blocks = partitionning.len() as u64;
let safe_variance_bound = consts.safe_variance;
let log_norm = consts.noise_factor.log2();
@@ -140,146 +224,173 @@ fn update_state_with_best_decompositions(
.keyswitch
.pareto_quantities(glwe_params, internal_dim);
let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise;
let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise;
let lower_bound_complexity_all_ks = precisions_sum as f64 * pareto_keyswitch[0].complexity;
let variance_coeff_br = square(consts.noise_factor) / 2.0;
let simple_variance = |br_variance: Option<_>, ks_variance: Option<_>| {
variance_modulus_switching
+ variance_coeff_br * br_variance.unwrap_or(lower_bound_variance_blind_rotate)
+ ks_variance.unwrap_or(lower_bound_variance_keyswitch)
};
let lower_bound_variance = simple_variance(None, None);
if lower_bound_variance > consts.safe_variance {
// saves 20%
return;
}
let pp_switch = caches
.pp_switch
.pareto_quantities(glwe_params, internal_dim);
let pareto_cb = caches.cb_pbs.pareto_quantities(glwe_params);
let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise;
let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise;
let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise;
let lower_pareto_cb_bias = pareto_cb
.iter()
.map(|cb| cb.variance_bias)
.reduce(f64::min)
.unwrap();
let lower_pareto_cb_slope = pareto_cb
.iter()
.map(|cb| cb.variance_ggsw_factor)
.reduce(f64::min)
.unwrap();
let lower_bound_cost_blind_rotate = pareto_blind_rotate[0].complexity;
let lower_bound_cost_keyswitch = pareto_keyswitch[0].complexity;
let lower_bound_cost_pp = pp_switch[0].complexity;
let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb
.iter()
.map(|cb| cb.complexity_one_cmux_hp)
.reduce(f64::min)
.unwrap_or(0.0);
let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb
.iter()
.map(|cb| cb.complexity_one_ggsw_to_fft)
.reduce(f64::min)
.unwrap_or(0.0);
let lower_bound_cb = CbComplexityNoise {
decomp: BrDecompositionParameters {
level: 1,
log2_base: 1,
},
complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp,
complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft,
variance_bias: lower_pareto_cb_bias,
variance_ggsw_factor: lower_pareto_cb_slope,
};
let variance = |br_variance: Option<_>,
pp_variance: Option<_>,
cb_decomp: Option<&CbComplexityNoise>,
ks_variance: Option<_>| {
let br_variance = br_variance.unwrap_or(lower_bound_variance_blind_rotate);
let pp_variance = pp_variance.unwrap_or(lower_bound_variance_private_packing);
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
let ks_variance = ks_variance.unwrap_or(lower_bound_variance_keyswitch);
estimate_variance(
br_variance,
pp_variance,
cb_decomp,
ks_variance,
variance_modulus_switching,
log_norm,
precisions_sum,
max_precision,
)
};
let lower_bound_variance = variance(None, None, None, None);
if lower_bound_variance > consts.safe_variance {
// saves 20%
return;
}
let complexity = |br_cost: Option<_>,
pp_cost: Option<_>,
cb_decomp: Option<&CbComplexityNoise>,
ks_cost: Option<_>| {
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
let br_cost = br_cost.unwrap_or(lower_bound_cost_blind_rotate);
let ks_cost = ks_cost.unwrap_or(lower_bound_cost_keyswitch);
let pp_cost = pp_cost.unwrap_or(lower_bound_cost_pp);
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
estimate_complexity(
&glwe_params,
br_cost,
pp_cost,
cb_decomp,
ks_cost,
precisions_sum,
nb_blocks,
n_functions,
)
};
// BlindRotate dans Circuit BS
for (br_i, shared_br_decomp) in pareto_blind_rotate.iter().enumerate() {
if simple_variance(Some(shared_br_decomp.noise), None) > consts.safe_variance {
let lower_bound_variance = variance(Some(shared_br_decomp.noise), None, None, None);
if lower_bound_variance > consts.safe_variance {
// saves 20%
continue;
}
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
let br_decomposition_parameter = shared_br_decomp.decomp;
let pbs_parameters = PbsParameters {
internal_lwe_dimension: LweDimension(internal_dim),
br_decomposition_parameter,
output_glwe_params: glwe_params,
};
// BitExtract use this pbs
let complexity_bit_extract_1_pbs = shared_br_decomp.complexity;
let complexity_bit_extract_wo_ks =
(precisions_sum - partitionning.len() as u64) as f64 * complexity_bit_extract_1_pbs;
if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks > best_complexity {
// saves 0%
// next br_decomp are at least as costly
break;
}
// Circuit Boostrap
// private packing keyswitch, <=> FP-KS
let pp_switching_index = br_i;
let base_variance_private_packing_ks = pp_switch[pp_switching_index].noise;
let complexity_ppks = pp_switch[pp_switching_index].complexity;
let pp_switching = pp_switch[pp_switching_index];
let variance_ggsw = base_variance_private_packing_ks + shared_br_decomp.noise / 2.;
let variance_coeff_1_cmux_tree =
2_f64.powf(2. * log_norm) // variance_coeff for the multisum
* (precisions_sum // for hybrid packing
<< (2 * (max_precision - 1))) as f64 // for left shift
;
let lower_bound_variance = variance(
Some(shared_br_decomp.noise),
Some(pp_switching.noise),
None,
None,
);
if lower_bound_variance > safe_variance_bound {
continue;
}
let lower_bound_complexity = complexity(
Some(shared_br_decomp.complexity),
Some(pp_switching.complexity),
None,
None,
);
if lower_bound_complexity > best_complexity {
// saves ?? TODO
// next br_decomp are at least as costly
break;
}
// CircuitBootstrap: new parameters l,b
// for &circuit_pbs_decomposition in pareto_circuit_pbs {
for cb_decomp in pareto_cb {
// Hybrid packing
let cb_level = cb_decomp.decomp.level;
// Circuit bs: fp-ks
let complexity_all_ppks = ((pbs_parameters.output_glwe_params.glwe_dimension + 1)
* cb_level
* precisions_sum) as f64
* complexity_ppks;
// Circuit bs: pbs
let complexity_all_pbs =
(precisions_sum * cb_level) as f64 * shared_br_decomp.complexity;
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks + complexity_circuit_bs
> best_complexity
{
let lower_bound_variance = variance(
Some(shared_br_decomp.noise),
Some(pp_switching.noise),
Some(cb_decomp),
None,
);
if lower_bound_variance > safe_variance_bound {
continue;
}
let lower_bound_complexity = complexity(
Some(shared_br_decomp.complexity),
Some(pp_switching.complexity),
Some(cb_decomp),
None,
);
if lower_bound_complexity > best_complexity {
// saves 50%
// next circuit_pbs_decomposition_parameter are at least as costly
break;
}
// Hybrid packing
let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp;
// Hybrid packing (Do we have 1 or 2 groups)
let log2_polynomial_size = pbs_parameters.output_glwe_params.log2_polynomial_size;
// Size of cmux_group, can be zero
let cmux_group_count = if precisions_sum > log2_polynomial_size {
2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32)
} else {
0.0
};
let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp;
let complexity_one_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft;
let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_one_ggsw_to_fft;
// Hybrid packing blind rotate
let complexity_g_br = complexity_1_cmux_hp
* u64::min(
pbs_parameters.output_glwe_params.log2_polynomial_size,
precisions_sum,
) as f64;
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
let complexity_multi_hybrid_packing =
n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft;
// Cutting on complexity here is counter-productive probably because complexity_multi_hybrid_packing is small
let variance_one_external_product_for_cmux_tree =
cb_decomp.variance_from_ggsw(variance_ggsw);
// final out noise hybrid packing
let variance_after_1st_bit_extract =
variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree;
let variance_wo_ks = variance_modulus_switching + variance_after_1st_bit_extract;
// Shared by all pbs (like brs)
for ks_decomp in pareto_keyswitch.iter().rev() {
let variance_keyswitch = ks_decomp.noise;
let variance_max = variance_wo_ks + variance_keyswitch;
let variance_max = variance(
Some(shared_br_decomp.noise),
Some(pp_switching.noise),
Some(cb_decomp),
Some(ks_decomp.noise),
);
if variance_max > safe_variance_bound {
// saves 40%
break;
}
let complexity_all_ks = precisions_sum as f64 * ks_decomp.complexity;
let complexity = complexity_bit_extract_wo_ks
+ complexity_circuit_bs
+ complexity_multi_hybrid_packing
+ complexity_all_ks;
let complexity = complexity(
Some(shared_br_decomp.complexity),
Some(pp_switching.complexity),
Some(cb_decomp),
Some(ks_decomp.complexity),
);
if complexity > best_complexity {
continue;
@@ -294,19 +405,15 @@ fn update_state_with_best_decompositions(
best_complexity = complexity;
best_variance = variance_max;
let p_error = find_p_error(kappa, safe_variance_bound, variance_max);
let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
let glwe_polynomial_size = glwe_params.polynomial_size();
let glwe_dimension = glwe_params.glwe_dimension;
let ks_decomposition_parameter = ks_decomp.decomp;
state.best_solution = Some(Solution {
input_lwe_dimension,
input_lwe_dimension: glwe_params.sample_extract_lwe_dimension(),
internal_ks_output_lwe_dimension: internal_dim,
ks_decomposition_level_count: ks_decomposition_parameter.level,
ks_decomposition_base_log: ks_decomposition_parameter.log2_base,
glwe_polynomial_size,
glwe_dimension,
br_decomposition_level_count: br_decomposition_parameter.level,
br_decomposition_base_log: br_decomposition_parameter.log2_base,
ks_decomposition_level_count: ks_decomp.decomp.level,
ks_decomposition_base_log: ks_decomp.decomp.log2_base,
glwe_polynomial_size: glwe_params.polynomial_size(),
glwe_dimension: glwe_params.glwe_dimension,
br_decomposition_level_count: shared_br_decomp.decomp.level,
br_decomposition_base_log: shared_br_decomp.decomp.log2_base,
noise_max: variance_max,
complexity,
p_error,