use super::crt_decomposition; use crate::dag::operator::Precision; use crate::noise_estimator::error::{ error_probability_of_sigma_scale, safe_variance_bound_product_1padbit, sigma_scale_of_error_probability, }; use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::optimization::atomic_pattern; use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts; use crate::optimization::config::{Config, SearchSpace}; use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise; use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches}; use crate::parameters::{BrDecompositionParameters, GlweParameters}; use concrete_commons::dispersion::{DispersionParameter, Variance}; pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 { let sigma = Variance(variance_bound).get_standard_dev() * kappa; let sigma_scale = sigma / Variance(current_maximum_noise).get_standard_dev(); error_probability_of_sigma_scale(sigma_scale) } #[derive(Clone, Debug)] pub struct OptimizationState { pub best_solution: Option, } #[derive(Clone, Debug)] pub struct Solution { pub input_lwe_dimension: u64, //n_big pub internal_ks_output_lwe_dimension: u64, //n_small pub ks_decomposition_level_count: u64, //l(KS) pub ks_decomposition_base_log: u64, //b(KS) pub glwe_polynomial_size: u64, //N pub glwe_dimension: u64, //k pub br_decomposition_level_count: u64, //l(BR) pub br_decomposition_base_log: u64, //b(BR) pub complexity: f64, pub noise_max: f64, pub p_error: f64, pub global_p_error: f64, // error probability pub cb_decomposition_level_count: u64, pub cb_decomposition_base_log: u64, pub crt_decomposition: Vec, } impl Solution { pub fn init() -> Self { Self { input_lwe_dimension: 0, internal_ks_output_lwe_dimension: 0, ks_decomposition_level_count: 0, ks_decomposition_base_log: 0, glwe_polynomial_size: 0, glwe_dimension: 0, br_decomposition_level_count: 0, br_decomposition_base_log: 0, complexity: 0., noise_max: 0.0, p_error: 0.0, global_p_error: 0.0, cb_decomposition_level_count: 0, cb_decomposition_base_log: 0, crt_decomposition: vec![], } } } impl From for atomic_pattern::Solution { fn from(sol: Solution) -> Self { Self { input_lwe_dimension: sol.input_lwe_dimension, internal_ks_output_lwe_dimension: sol.internal_ks_output_lwe_dimension, ks_decomposition_level_count: sol.ks_decomposition_level_count, ks_decomposition_base_log: sol.ks_decomposition_base_log, glwe_polynomial_size: sol.glwe_polynomial_size, glwe_dimension: sol.glwe_dimension, br_decomposition_level_count: sol.br_decomposition_level_count, br_decomposition_base_log: sol.br_decomposition_base_log, complexity: sol.complexity, noise_max: sol.noise_max, p_error: sol.p_error, global_p_error: sol.global_p_error, } } } fn estimate_variance( br_variance: f64, pp_variance: f64, cb_decomp: &CbComplexityNoise, ks_variance: f64, variance_modulus_switching: f64, log_norm: f64, precisions_sum: u64, max_precision: u64, ) -> f64 { assert!(max_precision <= precisions_sum); let variance_ggsw = pp_variance + br_variance / 2.; let variance_coeff_1_cmux_tree = 2_f64.powf(2. * log_norm) // variance_coeff for the multisum * (precisions_sum // for hybrid packing << (2 * (max_precision - 1))) as f64 // for left shift ; let variance_one_external_product_for_cmux_tree = cb_decomp.variance_from_ggsw(variance_ggsw); variance_modulus_switching + variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree + ks_variance } fn estimate_complexity( glwe_params: &GlweParameters, br_cost: f64, pp_cost: f64, cb_decomp: &CbComplexityNoise, ks_cost: f64, precisions_sum: u64, nb_blocks: u64, n_functions: u64, ) -> f64 { // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) // Hybrid packing let cb_level = cb_decomp.decomp.level; let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp; let complexity_1_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft; // BitExtract use br let complexity_bit_extract_1_pbs = br_cost; let complexity_bit_extract_wo_ks = (precisions_sum - nb_blocks) as f64 * complexity_bit_extract_1_pbs; // Hybrid packing // Circuit bs: fp-ks let complexity_ppks = pp_cost; let complexity_all_ppks = ((glwe_params.glwe_dimension + 1) * cb_level * precisions_sum) as f64 * complexity_ppks; // Circuit bs: pbs let complexity_all_pbs = (precisions_sum * cb_level) as f64 * br_cost; let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks; // Hybrid packing (Do we have 1 or 2 groups) let log2_polynomial_size = glwe_params.log2_polynomial_size; // Size of cmux_group, can be zero let cmux_group_count = if precisions_sum > log2_polynomial_size { 2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32) } else { 0.0 }; let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp; let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_1_ggsw_to_fft; // Hybrid packing blind rotate let complexity_g_br = complexity_1_cmux_hp * u64::min(glwe_params.log2_polynomial_size, precisions_sum) as f64; let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br; let complexity_multi_hybrid_packing = n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft; let complexity_all_ks = precisions_sum as f64 * ks_cost; complexity_bit_extract_wo_ks + complexity_circuit_bs + complexity_multi_hybrid_packing + complexity_all_ks } #[allow(clippy::too_many_lines)] fn update_state_with_best_decompositions( state: &mut OptimizationState, consts: &OptimizationDecompositionsConsts, glwe_params: GlweParameters, internal_dim: u64, n_functions: u64, partitionning: &[u64], caches: &mut DecompCaches, ) { let ciphertext_modulus_log = consts.config.ciphertext_modulus_log; let precisions_sum = partitionning.iter().copied().sum(); let max_precision = partitionning.iter().copied().max().unwrap(); let nb_blocks = partitionning.len() as u64; let safe_variance_bound = consts.safe_variance; let log_norm = consts.noise_factor.log2(); let variance_modulus_switching = noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key( internal_dim, glwe_params.polynomial_size(), ciphertext_modulus_log, ) .get_variance(); if variance_modulus_switching > consts.safe_variance { return; } let mut best_complexity = state .best_solution .as_ref() .map_or(f64::INFINITY, |s| s.complexity); let mut best_variance = state .best_solution .as_ref() .map_or(f64::INFINITY, |s| s.noise_max); let pareto_blind_rotate = caches .blind_rotate .pareto_quantities(glwe_params, internal_dim); let pareto_keyswitch = caches .keyswitch .pareto_quantities(glwe_params, internal_dim); let pp_switch = caches .pp_switch .pareto_quantities(glwe_params, internal_dim); let pareto_cb = caches.cb_pbs.pareto_quantities(glwe_params); let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise; let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise; let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise; let lower_pareto_cb_bias = pareto_cb .iter() .map(|cb| cb.variance_bias) .reduce(f64::min) .unwrap(); let lower_pareto_cb_slope = pareto_cb .iter() .map(|cb| cb.variance_ggsw_factor) .reduce(f64::min) .unwrap(); let lower_bound_cost_blind_rotate = pareto_blind_rotate[0].complexity; let lower_bound_cost_keyswitch = pareto_keyswitch[0].complexity; let lower_bound_cost_pp = pp_switch[0].complexity; let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb .iter() .map(|cb| cb.complexity_one_cmux_hp) .reduce(f64::min) .unwrap_or(0.0); let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb .iter() .map(|cb| cb.complexity_one_ggsw_to_fft) .reduce(f64::min) .unwrap_or(0.0); let lower_bound_cb = CbComplexityNoise { decomp: BrDecompositionParameters { level: 1, log2_base: 1, }, complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp, complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft, variance_bias: lower_pareto_cb_bias, variance_ggsw_factor: lower_pareto_cb_slope, }; let variance = |br_variance: Option<_>, pp_variance: Option<_>, cb_decomp: Option<&CbComplexityNoise>, ks_variance: Option<_>| { let br_variance = br_variance.unwrap_or(lower_bound_variance_blind_rotate); let pp_variance = pp_variance.unwrap_or(lower_bound_variance_private_packing); let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb); let ks_variance = ks_variance.unwrap_or(lower_bound_variance_keyswitch); estimate_variance( br_variance, pp_variance, cb_decomp, ks_variance, variance_modulus_switching, log_norm, precisions_sum, max_precision, ) }; let lower_bound_variance = variance(None, None, None, None); if lower_bound_variance > consts.safe_variance { // saves 20% return; } let complexity = |br_cost: Option<_>, pp_cost: Option<_>, cb_decomp: Option<&CbComplexityNoise>, ks_cost: Option<_>| { // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) let br_cost = br_cost.unwrap_or(lower_bound_cost_blind_rotate); let ks_cost = ks_cost.unwrap_or(lower_bound_cost_keyswitch); let pp_cost = pp_cost.unwrap_or(lower_bound_cost_pp); let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb); estimate_complexity( &glwe_params, br_cost, pp_cost, cb_decomp, ks_cost, precisions_sum, nb_blocks, n_functions, ) }; // BlindRotate dans Circuit BS for (br_i, shared_br_decomp) in pareto_blind_rotate.iter().enumerate() { let lower_bound_variance = variance(Some(shared_br_decomp.noise), None, None, None); if lower_bound_variance > consts.safe_variance { // saves 20% continue; } // Circuit Boostrap // private packing keyswitch, <=> FP-KS let pp_switching_index = br_i; let pp_switching = pp_switch[pp_switching_index]; let lower_bound_variance = variance( Some(shared_br_decomp.noise), Some(pp_switching.noise), None, None, ); if lower_bound_variance > safe_variance_bound { continue; } let lower_bound_complexity = complexity( Some(shared_br_decomp.complexity), Some(pp_switching.complexity), None, None, ); if lower_bound_complexity > best_complexity { // saves ?? TODO // next br_decomp are at least as costly break; } // CircuitBootstrap: new parameters l,b // for &circuit_pbs_decomposition in pareto_circuit_pbs { for cb_decomp in pareto_cb { let lower_bound_variance = variance( Some(shared_br_decomp.noise), Some(pp_switching.noise), Some(cb_decomp), None, ); if lower_bound_variance > safe_variance_bound { continue; } let lower_bound_complexity = complexity( Some(shared_br_decomp.complexity), Some(pp_switching.complexity), Some(cb_decomp), None, ); if lower_bound_complexity > best_complexity { // saves 50% // next circuit_pbs_decomposition_parameter are at least as costly break; } // Shared by all pbs (like brs) for ks_decomp in pareto_keyswitch.iter().rev() { let variance_max = variance( Some(shared_br_decomp.noise), Some(pp_switching.noise), Some(cb_decomp), Some(ks_decomp.noise), ); if variance_max > safe_variance_bound { // saves 40% break; } let complexity = complexity( Some(shared_br_decomp.complexity), Some(pp_switching.complexity), Some(cb_decomp), Some(ks_decomp.complexity), ); if complexity > best_complexity { continue; } #[allow(clippy::float_cmp)] if complexity == best_complexity && variance_max > best_variance { continue; } let kappa = consts.kappa; best_complexity = complexity; best_variance = variance_max; let p_error = find_p_error(kappa, safe_variance_bound, variance_max); state.best_solution = Some(Solution { input_lwe_dimension: glwe_params.sample_extract_lwe_dimension(), internal_ks_output_lwe_dimension: internal_dim, ks_decomposition_level_count: ks_decomp.decomp.level, ks_decomposition_base_log: ks_decomp.decomp.log2_base, glwe_polynomial_size: glwe_params.polynomial_size(), glwe_dimension: glwe_params.glwe_dimension, br_decomposition_level_count: shared_br_decomp.decomp.level, br_decomposition_base_log: shared_br_decomp.decomp.log2_base, noise_max: variance_max, complexity, p_error, global_p_error: f64::NAN, cb_decomposition_level_count: cb_decomp.decomp.level, cb_decomposition_base_log: cb_decomp.decomp.log2_base, crt_decomposition: vec![], }); } } } } fn optimize_raw( log_norm: f64, // ?? norm2 of noise multisum, complexity of multisum is neglected config: Config, search_space: &SearchSpace, n_functions: u64, // Many functions at the same time, stay at 1 for start partitionning: &[u64], persistent_caches: &PersistDecompCaches, ) -> OptimizationState { assert!(0.0 < config.maximum_acceptable_error_probability); assert!(config.maximum_acceptable_error_probability < 1.0); assert!(!partitionning.is_empty()); let ciphertext_modulus_log = config.ciphertext_modulus_log; // Circuit BS bound // 1 bit of message only here =) // Bound for first bit extract in BitExtract (dominate others) let max_block_precision = *partitionning.iter().max().unwrap(); let safe_variance_bound = safe_variance_bound_product_1padbit( max_block_precision, ciphertext_modulus_log, config.maximum_acceptable_error_probability, ); let kappa: f64 = sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); let mut state = OptimizationState { best_solution: None, }; let consts = OptimizationDecompositionsConsts { config, kappa, sum_size: 1, // Ignored noise_factor: log_norm.exp2(), safe_variance: safe_variance_bound, }; let mut caches = persistent_caches.caches(); for &glwe_dim in &search_space.glwe_dimensions { for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { let input_lwe_dimension = glwe_dim << glwe_log_poly_size; // Manual experimental CUT if input_lwe_dimension > 1 << 13 { continue; } let glwe_params = GlweParameters { log2_polynomial_size: glwe_log_poly_size, glwe_dimension: glwe_dim, }; for &internal_dim in &search_space.internal_lwe_dimensions { update_state_with_best_decompositions( &mut state, &consts, glwe_params, internal_dim, n_functions, partitionning, &mut caches, ); } } } persistent_caches.backport(caches); state } pub fn optimize_one( precision: u64, config: Config, log_norm: f64, search_space: &SearchSpace, caches: &PersistDecompCaches, ) -> OptimizationState { let coprimes = crt_decomposition::default_coprimes(precision as Precision); let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes); let n_functions = 1; let mut state = optimize_raw( log_norm, config, search_space, n_functions, &partitionning, caches, ); state.best_solution = state.best_solution.map(|mut sol| -> Solution { sol.crt_decomposition = coprimes; sol }); state } pub fn optimize_one_compat( _sum_size: u64, precision: u64, config: Config, noise_factor: f64, search_space: &SearchSpace, cache: &PersistDecompCaches, ) -> atomic_pattern::OptimizationState { let log_norm = noise_factor.log2(); let result = optimize_one(precision, config, log_norm, search_space, cache); atomic_pattern::OptimizationState { best_solution: result.best_solution.map(Solution::into), } }