diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index 973600477..981006835 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -9,9 +9,9 @@ use crate::optimization::atomic_pattern; use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts; use crate::optimization::config::{Config, SearchSpace}; +use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise; use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches}; -use crate::parameters::{GlweParameters, LweDimension, PbsParameters}; -use crate::utils::square; +use crate::parameters::{BrDecompositionParameters, GlweParameters}; use concrete_commons::dispersion::{DispersionParameter, Variance}; pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 { @@ -94,6 +94,89 @@ impl From for atomic_pattern::Solution { } } +fn estimate_variance( + br_variance: f64, + pp_variance: f64, + cb_decomp: &CbComplexityNoise, + ks_variance: f64, + variance_modulus_switching: f64, + log_norm: f64, + precisions_sum: u64, + max_precision: u64, +) -> f64 { + assert!(max_precision <= precisions_sum); + let variance_ggsw = pp_variance + br_variance / 2.; + let variance_coeff_1_cmux_tree = + 2_f64.powf(2. * log_norm) // variance_coeff for the multisum + * (precisions_sum // for hybrid packing + << (2 * (max_precision - 1))) as f64 // for left shift + ; + let variance_one_external_product_for_cmux_tree = cb_decomp.variance_from_ggsw(variance_ggsw); + variance_modulus_switching + + variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree + + ks_variance +} + +fn estimate_complexity( + glwe_params: &GlweParameters, + br_cost: f64, + pp_cost: f64, + cb_decomp: &CbComplexityNoise, + ks_cost: f64, + precisions_sum: u64, + nb_blocks: u64, + n_functions: u64, +) -> f64 { + // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) + // Hybrid packing + let cb_level = cb_decomp.decomp.level; + let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp; + let complexity_1_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft; + // BitExtract use br + let complexity_bit_extract_1_pbs = br_cost; + let complexity_bit_extract_wo_ks = + (precisions_sum - nb_blocks) as f64 * complexity_bit_extract_1_pbs; + + // Hybrid packing + // Circuit bs: fp-ks + let complexity_ppks = pp_cost; + let complexity_all_ppks = + ((glwe_params.glwe_dimension + 1) * cb_level * precisions_sum) as f64 * complexity_ppks; + + // Circuit bs: pbs + let complexity_all_pbs = (precisions_sum * cb_level) as f64 * br_cost; + + let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks; + + // Hybrid packing (Do we have 1 or 2 groups) + let log2_polynomial_size = glwe_params.log2_polynomial_size; + // Size of cmux_group, can be zero + let cmux_group_count = if precisions_sum > log2_polynomial_size { + 2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32) + } else { + 0.0 + }; + let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp; + + let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_1_ggsw_to_fft; + + // Hybrid packing blind rotate + let complexity_g_br = + complexity_1_cmux_hp * u64::min(glwe_params.log2_polynomial_size, precisions_sum) as f64; + + let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br; + + let complexity_multi_hybrid_packing = + n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft; + + let complexity_all_ks = precisions_sum as f64 * ks_cost; + + complexity_bit_extract_wo_ks + + complexity_circuit_bs + + complexity_multi_hybrid_packing + + complexity_all_ks +} + #[allow(clippy::too_many_lines)] fn update_state_with_best_decompositions( state: &mut OptimizationState, @@ -107,6 +190,7 @@ fn update_state_with_best_decompositions( let ciphertext_modulus_log = consts.config.ciphertext_modulus_log; let precisions_sum = partitionning.iter().copied().sum(); let max_precision = partitionning.iter().copied().max().unwrap(); + let nb_blocks = partitionning.len() as u64; let safe_variance_bound = consts.safe_variance; let log_norm = consts.noise_factor.log2(); @@ -140,146 +224,173 @@ fn update_state_with_best_decompositions( .keyswitch .pareto_quantities(glwe_params, internal_dim); - let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise; - let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise; - let lower_bound_complexity_all_ks = precisions_sum as f64 * pareto_keyswitch[0].complexity; - - let variance_coeff_br = square(consts.noise_factor) / 2.0; - let simple_variance = |br_variance: Option<_>, ks_variance: Option<_>| { - variance_modulus_switching - + variance_coeff_br * br_variance.unwrap_or(lower_bound_variance_blind_rotate) - + ks_variance.unwrap_or(lower_bound_variance_keyswitch) - }; - - let lower_bound_variance = simple_variance(None, None); - if lower_bound_variance > consts.safe_variance { - // saves 20% - return; - } - let pp_switch = caches .pp_switch .pareto_quantities(glwe_params, internal_dim); let pareto_cb = caches.cb_pbs.pareto_quantities(glwe_params); + let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise; + let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise; + let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise; + let lower_pareto_cb_bias = pareto_cb + .iter() + .map(|cb| cb.variance_bias) + .reduce(f64::min) + .unwrap(); + let lower_pareto_cb_slope = pareto_cb + .iter() + .map(|cb| cb.variance_ggsw_factor) + .reduce(f64::min) + .unwrap(); + let lower_bound_cost_blind_rotate = pareto_blind_rotate[0].complexity; + let lower_bound_cost_keyswitch = pareto_keyswitch[0].complexity; + let lower_bound_cost_pp = pp_switch[0].complexity; + let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb + .iter() + .map(|cb| cb.complexity_one_cmux_hp) + .reduce(f64::min) + .unwrap_or(0.0); + let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb + .iter() + .map(|cb| cb.complexity_one_ggsw_to_fft) + .reduce(f64::min) + .unwrap_or(0.0); + let lower_bound_cb = CbComplexityNoise { + decomp: BrDecompositionParameters { + level: 1, + log2_base: 1, + }, + complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp, + complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft, + variance_bias: lower_pareto_cb_bias, + variance_ggsw_factor: lower_pareto_cb_slope, + }; + + let variance = |br_variance: Option<_>, + pp_variance: Option<_>, + cb_decomp: Option<&CbComplexityNoise>, + ks_variance: Option<_>| { + let br_variance = br_variance.unwrap_or(lower_bound_variance_blind_rotate); + let pp_variance = pp_variance.unwrap_or(lower_bound_variance_private_packing); + let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb); + let ks_variance = ks_variance.unwrap_or(lower_bound_variance_keyswitch); + estimate_variance( + br_variance, + pp_variance, + cb_decomp, + ks_variance, + variance_modulus_switching, + log_norm, + precisions_sum, + max_precision, + ) + }; + + let lower_bound_variance = variance(None, None, None, None); + if lower_bound_variance > consts.safe_variance { + // saves 20% + return; + } + + let complexity = |br_cost: Option<_>, + pp_cost: Option<_>, + cb_decomp: Option<&CbComplexityNoise>, + ks_cost: Option<_>| { + // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) + let br_cost = br_cost.unwrap_or(lower_bound_cost_blind_rotate); + let ks_cost = ks_cost.unwrap_or(lower_bound_cost_keyswitch); + let pp_cost = pp_cost.unwrap_or(lower_bound_cost_pp); + let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb); + estimate_complexity( + &glwe_params, + br_cost, + pp_cost, + cb_decomp, + ks_cost, + precisions_sum, + nb_blocks, + n_functions, + ) + }; + // BlindRotate dans Circuit BS for (br_i, shared_br_decomp) in pareto_blind_rotate.iter().enumerate() { - if simple_variance(Some(shared_br_decomp.noise), None) > consts.safe_variance { + let lower_bound_variance = variance(Some(shared_br_decomp.noise), None, None, None); + + if lower_bound_variance > consts.safe_variance { // saves 20% continue; } - // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) - let br_decomposition_parameter = shared_br_decomp.decomp; - let pbs_parameters = PbsParameters { - internal_lwe_dimension: LweDimension(internal_dim), - br_decomposition_parameter, - output_glwe_params: glwe_params, - }; - - // BitExtract use this pbs - let complexity_bit_extract_1_pbs = shared_br_decomp.complexity; - let complexity_bit_extract_wo_ks = - (precisions_sum - partitionning.len() as u64) as f64 * complexity_bit_extract_1_pbs; - - if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks > best_complexity { - // saves 0% - // next br_decomp are at least as costly - break; - } // Circuit Boostrap // private packing keyswitch, <=> FP-KS let pp_switching_index = br_i; - let base_variance_private_packing_ks = pp_switch[pp_switching_index].noise; - let complexity_ppks = pp_switch[pp_switching_index].complexity; + let pp_switching = pp_switch[pp_switching_index]; - let variance_ggsw = base_variance_private_packing_ks + shared_br_decomp.noise / 2.; - - let variance_coeff_1_cmux_tree = - 2_f64.powf(2. * log_norm) // variance_coeff for the multisum - * (precisions_sum // for hybrid packing - << (2 * (max_precision - 1))) as f64 // for left shift - ; + let lower_bound_variance = variance( + Some(shared_br_decomp.noise), + Some(pp_switching.noise), + None, + None, + ); + if lower_bound_variance > safe_variance_bound { + continue; + } + let lower_bound_complexity = complexity( + Some(shared_br_decomp.complexity), + Some(pp_switching.complexity), + None, + None, + ); + if lower_bound_complexity > best_complexity { + // saves ?? TODO + // next br_decomp are at least as costly + break; + } // CircuitBootstrap: new parameters l,b // for &circuit_pbs_decomposition in pareto_circuit_pbs { for cb_decomp in pareto_cb { - // Hybrid packing - let cb_level = cb_decomp.decomp.level; - // Circuit bs: fp-ks - let complexity_all_ppks = ((pbs_parameters.output_glwe_params.glwe_dimension + 1) - * cb_level - * precisions_sum) as f64 - * complexity_ppks; - - // Circuit bs: pbs - let complexity_all_pbs = - (precisions_sum * cb_level) as f64 * shared_br_decomp.complexity; - - let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks; - - if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks + complexity_circuit_bs - > best_complexity - { + let lower_bound_variance = variance( + Some(shared_br_decomp.noise), + Some(pp_switching.noise), + Some(cb_decomp), + None, + ); + if lower_bound_variance > safe_variance_bound { + continue; + } + let lower_bound_complexity = complexity( + Some(shared_br_decomp.complexity), + Some(pp_switching.complexity), + Some(cb_decomp), + None, + ); + if lower_bound_complexity > best_complexity { // saves 50% // next circuit_pbs_decomposition_parameter are at least as costly break; } - - // Hybrid packing - let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp; - - // Hybrid packing (Do we have 1 or 2 groups) - let log2_polynomial_size = pbs_parameters.output_glwe_params.log2_polynomial_size; - // Size of cmux_group, can be zero - let cmux_group_count = if precisions_sum > log2_polynomial_size { - 2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32) - } else { - 0.0 - }; - let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp; - - let complexity_one_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft; - - let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_one_ggsw_to_fft; - - // Hybrid packing blind rotate - let complexity_g_br = complexity_1_cmux_hp - * u64::min( - pbs_parameters.output_glwe_params.log2_polynomial_size, - precisions_sum, - ) as f64; - - let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br; - - let complexity_multi_hybrid_packing = - n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft; - // Cutting on complexity here is counter-productive probably because complexity_multi_hybrid_packing is small - - let variance_one_external_product_for_cmux_tree = - cb_decomp.variance_from_ggsw(variance_ggsw); - - // final out noise hybrid packing - let variance_after_1st_bit_extract = - variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree; - - let variance_wo_ks = variance_modulus_switching + variance_after_1st_bit_extract; - // Shared by all pbs (like brs) for ks_decomp in pareto_keyswitch.iter().rev() { - let variance_keyswitch = ks_decomp.noise; - let variance_max = variance_wo_ks + variance_keyswitch; + let variance_max = variance( + Some(shared_br_decomp.noise), + Some(pp_switching.noise), + Some(cb_decomp), + Some(ks_decomp.noise), + ); if variance_max > safe_variance_bound { // saves 40% break; } - let complexity_all_ks = precisions_sum as f64 * ks_decomp.complexity; - let complexity = complexity_bit_extract_wo_ks - + complexity_circuit_bs - + complexity_multi_hybrid_packing - + complexity_all_ks; + let complexity = complexity( + Some(shared_br_decomp.complexity), + Some(pp_switching.complexity), + Some(cb_decomp), + Some(ks_decomp.complexity), + ); if complexity > best_complexity { continue; @@ -294,19 +405,15 @@ fn update_state_with_best_decompositions( best_complexity = complexity; best_variance = variance_max; let p_error = find_p_error(kappa, safe_variance_bound, variance_max); - let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension(); - let glwe_polynomial_size = glwe_params.polynomial_size(); - let glwe_dimension = glwe_params.glwe_dimension; - let ks_decomposition_parameter = ks_decomp.decomp; state.best_solution = Some(Solution { - input_lwe_dimension, + input_lwe_dimension: glwe_params.sample_extract_lwe_dimension(), internal_ks_output_lwe_dimension: internal_dim, - ks_decomposition_level_count: ks_decomposition_parameter.level, - ks_decomposition_base_log: ks_decomposition_parameter.log2_base, - glwe_polynomial_size, - glwe_dimension, - br_decomposition_level_count: br_decomposition_parameter.level, - br_decomposition_base_log: br_decomposition_parameter.log2_base, + ks_decomposition_level_count: ks_decomp.decomp.level, + ks_decomposition_base_log: ks_decomp.decomp.log2_base, + glwe_polynomial_size: glwe_params.polynomial_size(), + glwe_dimension: glwe_params.glwe_dimension, + br_decomposition_level_count: shared_br_decomp.decomp.level, + br_decomposition_base_log: shared_br_decomp.decomp.log2_base, noise_max: variance_max, complexity, p_error,