mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
chore: woppbs, variance and complexity functions
This commit is contained in:
@@ -9,9 +9,9 @@ use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
|
||||
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise;
|
||||
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
|
||||
use crate::parameters::{GlweParameters, LweDimension, PbsParameters};
|
||||
use crate::utils::square;
|
||||
use crate::parameters::{BrDecompositionParameters, GlweParameters};
|
||||
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
||||
|
||||
pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {
|
||||
@@ -94,6 +94,89 @@ impl From<Solution> for atomic_pattern::Solution {
|
||||
}
|
||||
}
|
||||
|
||||
fn estimate_variance(
|
||||
br_variance: f64,
|
||||
pp_variance: f64,
|
||||
cb_decomp: &CbComplexityNoise,
|
||||
ks_variance: f64,
|
||||
variance_modulus_switching: f64,
|
||||
log_norm: f64,
|
||||
precisions_sum: u64,
|
||||
max_precision: u64,
|
||||
) -> f64 {
|
||||
assert!(max_precision <= precisions_sum);
|
||||
let variance_ggsw = pp_variance + br_variance / 2.;
|
||||
let variance_coeff_1_cmux_tree =
|
||||
2_f64.powf(2. * log_norm) // variance_coeff for the multisum
|
||||
* (precisions_sum // for hybrid packing
|
||||
<< (2 * (max_precision - 1))) as f64 // for left shift
|
||||
;
|
||||
let variance_one_external_product_for_cmux_tree = cb_decomp.variance_from_ggsw(variance_ggsw);
|
||||
variance_modulus_switching
|
||||
+ variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree
|
||||
+ ks_variance
|
||||
}
|
||||
|
||||
fn estimate_complexity(
|
||||
glwe_params: &GlweParameters,
|
||||
br_cost: f64,
|
||||
pp_cost: f64,
|
||||
cb_decomp: &CbComplexityNoise,
|
||||
ks_cost: f64,
|
||||
precisions_sum: u64,
|
||||
nb_blocks: u64,
|
||||
n_functions: u64,
|
||||
) -> f64 {
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
// Hybrid packing
|
||||
let cb_level = cb_decomp.decomp.level;
|
||||
let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp;
|
||||
let complexity_1_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft;
|
||||
// BitExtract use br
|
||||
let complexity_bit_extract_1_pbs = br_cost;
|
||||
let complexity_bit_extract_wo_ks =
|
||||
(precisions_sum - nb_blocks) as f64 * complexity_bit_extract_1_pbs;
|
||||
|
||||
// Hybrid packing
|
||||
// Circuit bs: fp-ks
|
||||
let complexity_ppks = pp_cost;
|
||||
let complexity_all_ppks =
|
||||
((glwe_params.glwe_dimension + 1) * cb_level * precisions_sum) as f64 * complexity_ppks;
|
||||
|
||||
// Circuit bs: pbs
|
||||
let complexity_all_pbs = (precisions_sum * cb_level) as f64 * br_cost;
|
||||
|
||||
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
||||
|
||||
// Hybrid packing (Do we have 1 or 2 groups)
|
||||
let log2_polynomial_size = glwe_params.log2_polynomial_size;
|
||||
// Size of cmux_group, can be zero
|
||||
let cmux_group_count = if precisions_sum > log2_polynomial_size {
|
||||
2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp;
|
||||
|
||||
let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_1_ggsw_to_fft;
|
||||
|
||||
// Hybrid packing blind rotate
|
||||
let complexity_g_br =
|
||||
complexity_1_cmux_hp * u64::min(glwe_params.log2_polynomial_size, precisions_sum) as f64;
|
||||
|
||||
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
||||
|
||||
let complexity_multi_hybrid_packing =
|
||||
n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft;
|
||||
|
||||
let complexity_all_ks = precisions_sum as f64 * ks_cost;
|
||||
|
||||
complexity_bit_extract_wo_ks
|
||||
+ complexity_circuit_bs
|
||||
+ complexity_multi_hybrid_packing
|
||||
+ complexity_all_ks
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn update_state_with_best_decompositions(
|
||||
state: &mut OptimizationState,
|
||||
@@ -107,6 +190,7 @@ fn update_state_with_best_decompositions(
|
||||
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
|
||||
let precisions_sum = partitionning.iter().copied().sum();
|
||||
let max_precision = partitionning.iter().copied().max().unwrap();
|
||||
let nb_blocks = partitionning.len() as u64;
|
||||
|
||||
let safe_variance_bound = consts.safe_variance;
|
||||
let log_norm = consts.noise_factor.log2();
|
||||
@@ -140,146 +224,173 @@ fn update_state_with_best_decompositions(
|
||||
.keyswitch
|
||||
.pareto_quantities(glwe_params, internal_dim);
|
||||
|
||||
let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise;
|
||||
let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise;
|
||||
let lower_bound_complexity_all_ks = precisions_sum as f64 * pareto_keyswitch[0].complexity;
|
||||
|
||||
let variance_coeff_br = square(consts.noise_factor) / 2.0;
|
||||
let simple_variance = |br_variance: Option<_>, ks_variance: Option<_>| {
|
||||
variance_modulus_switching
|
||||
+ variance_coeff_br * br_variance.unwrap_or(lower_bound_variance_blind_rotate)
|
||||
+ ks_variance.unwrap_or(lower_bound_variance_keyswitch)
|
||||
};
|
||||
|
||||
let lower_bound_variance = simple_variance(None, None);
|
||||
if lower_bound_variance > consts.safe_variance {
|
||||
// saves 20%
|
||||
return;
|
||||
}
|
||||
|
||||
let pp_switch = caches
|
||||
.pp_switch
|
||||
.pareto_quantities(glwe_params, internal_dim);
|
||||
|
||||
let pareto_cb = caches.cb_pbs.pareto_quantities(glwe_params);
|
||||
|
||||
let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise;
|
||||
let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise;
|
||||
let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise;
|
||||
let lower_pareto_cb_bias = pareto_cb
|
||||
.iter()
|
||||
.map(|cb| cb.variance_bias)
|
||||
.reduce(f64::min)
|
||||
.unwrap();
|
||||
let lower_pareto_cb_slope = pareto_cb
|
||||
.iter()
|
||||
.map(|cb| cb.variance_ggsw_factor)
|
||||
.reduce(f64::min)
|
||||
.unwrap();
|
||||
let lower_bound_cost_blind_rotate = pareto_blind_rotate[0].complexity;
|
||||
let lower_bound_cost_keyswitch = pareto_keyswitch[0].complexity;
|
||||
let lower_bound_cost_pp = pp_switch[0].complexity;
|
||||
let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb
|
||||
.iter()
|
||||
.map(|cb| cb.complexity_one_cmux_hp)
|
||||
.reduce(f64::min)
|
||||
.unwrap_or(0.0);
|
||||
let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb
|
||||
.iter()
|
||||
.map(|cb| cb.complexity_one_ggsw_to_fft)
|
||||
.reduce(f64::min)
|
||||
.unwrap_or(0.0);
|
||||
let lower_bound_cb = CbComplexityNoise {
|
||||
decomp: BrDecompositionParameters {
|
||||
level: 1,
|
||||
log2_base: 1,
|
||||
},
|
||||
complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp,
|
||||
complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft,
|
||||
variance_bias: lower_pareto_cb_bias,
|
||||
variance_ggsw_factor: lower_pareto_cb_slope,
|
||||
};
|
||||
|
||||
let variance = |br_variance: Option<_>,
|
||||
pp_variance: Option<_>,
|
||||
cb_decomp: Option<&CbComplexityNoise>,
|
||||
ks_variance: Option<_>| {
|
||||
let br_variance = br_variance.unwrap_or(lower_bound_variance_blind_rotate);
|
||||
let pp_variance = pp_variance.unwrap_or(lower_bound_variance_private_packing);
|
||||
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
|
||||
let ks_variance = ks_variance.unwrap_or(lower_bound_variance_keyswitch);
|
||||
estimate_variance(
|
||||
br_variance,
|
||||
pp_variance,
|
||||
cb_decomp,
|
||||
ks_variance,
|
||||
variance_modulus_switching,
|
||||
log_norm,
|
||||
precisions_sum,
|
||||
max_precision,
|
||||
)
|
||||
};
|
||||
|
||||
let lower_bound_variance = variance(None, None, None, None);
|
||||
if lower_bound_variance > consts.safe_variance {
|
||||
// saves 20%
|
||||
return;
|
||||
}
|
||||
|
||||
let complexity = |br_cost: Option<_>,
|
||||
pp_cost: Option<_>,
|
||||
cb_decomp: Option<&CbComplexityNoise>,
|
||||
ks_cost: Option<_>| {
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
let br_cost = br_cost.unwrap_or(lower_bound_cost_blind_rotate);
|
||||
let ks_cost = ks_cost.unwrap_or(lower_bound_cost_keyswitch);
|
||||
let pp_cost = pp_cost.unwrap_or(lower_bound_cost_pp);
|
||||
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
|
||||
estimate_complexity(
|
||||
&glwe_params,
|
||||
br_cost,
|
||||
pp_cost,
|
||||
cb_decomp,
|
||||
ks_cost,
|
||||
precisions_sum,
|
||||
nb_blocks,
|
||||
n_functions,
|
||||
)
|
||||
};
|
||||
|
||||
// BlindRotate dans Circuit BS
|
||||
for (br_i, shared_br_decomp) in pareto_blind_rotate.iter().enumerate() {
|
||||
if simple_variance(Some(shared_br_decomp.noise), None) > consts.safe_variance {
|
||||
let lower_bound_variance = variance(Some(shared_br_decomp.noise), None, None, None);
|
||||
|
||||
if lower_bound_variance > consts.safe_variance {
|
||||
// saves 20%
|
||||
continue;
|
||||
}
|
||||
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
||||
let br_decomposition_parameter = shared_br_decomp.decomp;
|
||||
let pbs_parameters = PbsParameters {
|
||||
internal_lwe_dimension: LweDimension(internal_dim),
|
||||
br_decomposition_parameter,
|
||||
output_glwe_params: glwe_params,
|
||||
};
|
||||
|
||||
// BitExtract use this pbs
|
||||
let complexity_bit_extract_1_pbs = shared_br_decomp.complexity;
|
||||
let complexity_bit_extract_wo_ks =
|
||||
(precisions_sum - partitionning.len() as u64) as f64 * complexity_bit_extract_1_pbs;
|
||||
|
||||
if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks > best_complexity {
|
||||
// saves 0%
|
||||
// next br_decomp are at least as costly
|
||||
break;
|
||||
}
|
||||
|
||||
// Circuit Boostrap
|
||||
// private packing keyswitch, <=> FP-KS
|
||||
let pp_switching_index = br_i;
|
||||
let base_variance_private_packing_ks = pp_switch[pp_switching_index].noise;
|
||||
let complexity_ppks = pp_switch[pp_switching_index].complexity;
|
||||
let pp_switching = pp_switch[pp_switching_index];
|
||||
|
||||
let variance_ggsw = base_variance_private_packing_ks + shared_br_decomp.noise / 2.;
|
||||
|
||||
let variance_coeff_1_cmux_tree =
|
||||
2_f64.powf(2. * log_norm) // variance_coeff for the multisum
|
||||
* (precisions_sum // for hybrid packing
|
||||
<< (2 * (max_precision - 1))) as f64 // for left shift
|
||||
;
|
||||
let lower_bound_variance = variance(
|
||||
Some(shared_br_decomp.noise),
|
||||
Some(pp_switching.noise),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
if lower_bound_variance > safe_variance_bound {
|
||||
continue;
|
||||
}
|
||||
let lower_bound_complexity = complexity(
|
||||
Some(shared_br_decomp.complexity),
|
||||
Some(pp_switching.complexity),
|
||||
None,
|
||||
None,
|
||||
);
|
||||
if lower_bound_complexity > best_complexity {
|
||||
// saves ?? TODO
|
||||
// next br_decomp are at least as costly
|
||||
break;
|
||||
}
|
||||
|
||||
// CircuitBootstrap: new parameters l,b
|
||||
// for &circuit_pbs_decomposition in pareto_circuit_pbs {
|
||||
for cb_decomp in pareto_cb {
|
||||
// Hybrid packing
|
||||
let cb_level = cb_decomp.decomp.level;
|
||||
// Circuit bs: fp-ks
|
||||
let complexity_all_ppks = ((pbs_parameters.output_glwe_params.glwe_dimension + 1)
|
||||
* cb_level
|
||||
* precisions_sum) as f64
|
||||
* complexity_ppks;
|
||||
|
||||
// Circuit bs: pbs
|
||||
let complexity_all_pbs =
|
||||
(precisions_sum * cb_level) as f64 * shared_br_decomp.complexity;
|
||||
|
||||
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
||||
|
||||
if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks + complexity_circuit_bs
|
||||
> best_complexity
|
||||
{
|
||||
let lower_bound_variance = variance(
|
||||
Some(shared_br_decomp.noise),
|
||||
Some(pp_switching.noise),
|
||||
Some(cb_decomp),
|
||||
None,
|
||||
);
|
||||
if lower_bound_variance > safe_variance_bound {
|
||||
continue;
|
||||
}
|
||||
let lower_bound_complexity = complexity(
|
||||
Some(shared_br_decomp.complexity),
|
||||
Some(pp_switching.complexity),
|
||||
Some(cb_decomp),
|
||||
None,
|
||||
);
|
||||
if lower_bound_complexity > best_complexity {
|
||||
// saves 50%
|
||||
// next circuit_pbs_decomposition_parameter are at least as costly
|
||||
break;
|
||||
}
|
||||
|
||||
// Hybrid packing
|
||||
let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp;
|
||||
|
||||
// Hybrid packing (Do we have 1 or 2 groups)
|
||||
let log2_polynomial_size = pbs_parameters.output_glwe_params.log2_polynomial_size;
|
||||
// Size of cmux_group, can be zero
|
||||
let cmux_group_count = if precisions_sum > log2_polynomial_size {
|
||||
2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32)
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp;
|
||||
|
||||
let complexity_one_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft;
|
||||
|
||||
let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_one_ggsw_to_fft;
|
||||
|
||||
// Hybrid packing blind rotate
|
||||
let complexity_g_br = complexity_1_cmux_hp
|
||||
* u64::min(
|
||||
pbs_parameters.output_glwe_params.log2_polynomial_size,
|
||||
precisions_sum,
|
||||
) as f64;
|
||||
|
||||
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
||||
|
||||
let complexity_multi_hybrid_packing =
|
||||
n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft;
|
||||
// Cutting on complexity here is counter-productive probably because complexity_multi_hybrid_packing is small
|
||||
|
||||
let variance_one_external_product_for_cmux_tree =
|
||||
cb_decomp.variance_from_ggsw(variance_ggsw);
|
||||
|
||||
// final out noise hybrid packing
|
||||
let variance_after_1st_bit_extract =
|
||||
variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree;
|
||||
|
||||
let variance_wo_ks = variance_modulus_switching + variance_after_1st_bit_extract;
|
||||
|
||||
// Shared by all pbs (like brs)
|
||||
for ks_decomp in pareto_keyswitch.iter().rev() {
|
||||
let variance_keyswitch = ks_decomp.noise;
|
||||
let variance_max = variance_wo_ks + variance_keyswitch;
|
||||
let variance_max = variance(
|
||||
Some(shared_br_decomp.noise),
|
||||
Some(pp_switching.noise),
|
||||
Some(cb_decomp),
|
||||
Some(ks_decomp.noise),
|
||||
);
|
||||
if variance_max > safe_variance_bound {
|
||||
// saves 40%
|
||||
break;
|
||||
}
|
||||
|
||||
let complexity_all_ks = precisions_sum as f64 * ks_decomp.complexity;
|
||||
let complexity = complexity_bit_extract_wo_ks
|
||||
+ complexity_circuit_bs
|
||||
+ complexity_multi_hybrid_packing
|
||||
+ complexity_all_ks;
|
||||
let complexity = complexity(
|
||||
Some(shared_br_decomp.complexity),
|
||||
Some(pp_switching.complexity),
|
||||
Some(cb_decomp),
|
||||
Some(ks_decomp.complexity),
|
||||
);
|
||||
|
||||
if complexity > best_complexity {
|
||||
continue;
|
||||
@@ -294,19 +405,15 @@ fn update_state_with_best_decompositions(
|
||||
best_complexity = complexity;
|
||||
best_variance = variance_max;
|
||||
let p_error = find_p_error(kappa, safe_variance_bound, variance_max);
|
||||
let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
|
||||
let glwe_polynomial_size = glwe_params.polynomial_size();
|
||||
let glwe_dimension = glwe_params.glwe_dimension;
|
||||
let ks_decomposition_parameter = ks_decomp.decomp;
|
||||
state.best_solution = Some(Solution {
|
||||
input_lwe_dimension,
|
||||
input_lwe_dimension: glwe_params.sample_extract_lwe_dimension(),
|
||||
internal_ks_output_lwe_dimension: internal_dim,
|
||||
ks_decomposition_level_count: ks_decomposition_parameter.level,
|
||||
ks_decomposition_base_log: ks_decomposition_parameter.log2_base,
|
||||
glwe_polynomial_size,
|
||||
glwe_dimension,
|
||||
br_decomposition_level_count: br_decomposition_parameter.level,
|
||||
br_decomposition_base_log: br_decomposition_parameter.log2_base,
|
||||
ks_decomposition_level_count: ks_decomp.decomp.level,
|
||||
ks_decomposition_base_log: ks_decomp.decomp.log2_base,
|
||||
glwe_polynomial_size: glwe_params.polynomial_size(),
|
||||
glwe_dimension: glwe_params.glwe_dimension,
|
||||
br_decomposition_level_count: shared_br_decomp.decomp.level,
|
||||
br_decomposition_base_log: shared_br_decomp.decomp.log2_base,
|
||||
noise_max: variance_max,
|
||||
complexity,
|
||||
p_error,
|
||||
|
||||
Reference in New Issue
Block a user