use super::crt_decomposition; use crate::computing_cost::complexity::Complexity; use crate::computing_cost::operators::cmux; use crate::dag::operator::Precision; use crate::noise_estimator::error::{ error_probability_of_sigma_scale, safe_variance_bound_1bit_1padbit, sigma_scale_of_error_probability, }; use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::noise_estimator::operators::wop_atomic_pattern::estimate_packing_private_keyswitch; use crate::optimization::atomic_pattern; use crate::optimization::atomic_pattern::{ cutted_blind_rotate, pareto_keyswitch, ComplexityNoise, OptimizationDecompositionsConsts, }; use crate::optimization::config::{Config, SearchSpace}; use crate::optimization::wop_atomic_pattern::pareto::{ BR_CIRCUIT_BOOTSTRAP_PARETO_DECOMP, BR_PARETO_DECOMP, KS_CIRCUIT_BOOTSTRAP_PARETO_DECOMP, KS_PARETO_DECOMP, }; use crate::parameters::{ GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension, PbsParameters, }; use crate::security; use crate::utils::square; use concrete_commons::dispersion::{DispersionParameter, Variance}; use concrete_commons::numeric::UnsignedInteger; pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 { let sigma = Variance(variance_bound).get_standard_dev() * kappa; let sigma_scale = sigma / Variance(current_maximum_noise).get_standard_dev(); error_probability_of_sigma_scale(sigma_scale) } #[derive(Clone, Debug)] pub struct OptimizationState { pub best_solution: Option, pub count_domain: usize, } #[derive(Clone, Debug)] pub struct Solution { pub input_lwe_dimension: u64, //n_big pub internal_ks_output_lwe_dimension: u64, //n_small pub ks_decomposition_level_count: u64, //l(KS) pub ks_decomposition_base_log: u64, //b(KS) pub glwe_polynomial_size: u64, //N pub glwe_dimension: u64, //k pub br_decomposition_level_count: u64, //l(BR) pub br_decomposition_base_log: u64, //b(BR) pub complexity: f64, pub noise_max: f64, pub p_error: f64, pub global_p_error: f64, // error probability pub cb_decomposition_level_count: u64, pub cb_decomposition_base_log: u64, pub crt_decomposition: Vec, } impl Solution { pub fn init() -> Self { Self { input_lwe_dimension: 0, internal_ks_output_lwe_dimension: 0, ks_decomposition_level_count: 0, ks_decomposition_base_log: 0, glwe_polynomial_size: 0, glwe_dimension: 0, br_decomposition_level_count: 0, br_decomposition_base_log: 0, complexity: 0., noise_max: 0.0, p_error: 0.0, global_p_error: 0.0, cb_decomposition_level_count: 0, cb_decomposition_base_log: 0, crt_decomposition: vec![], } } } impl From for atomic_pattern::Solution { fn from(sol: Solution) -> Self { Self { input_lwe_dimension: sol.input_lwe_dimension, internal_ks_output_lwe_dimension: sol.internal_ks_output_lwe_dimension, ks_decomposition_level_count: sol.ks_decomposition_level_count, ks_decomposition_base_log: sol.ks_decomposition_base_log, glwe_polynomial_size: sol.glwe_polynomial_size, glwe_dimension: sol.glwe_dimension, br_decomposition_level_count: sol.br_decomposition_level_count, br_decomposition_base_log: sol.br_decomposition_base_log, complexity: sol.complexity, noise_max: sol.noise_max, p_error: sol.p_error, global_p_error: sol.global_p_error, } } } #[derive(Debug)] struct NoiseCostByMicroParam { cutted_blind_rotate: Vec, pareto_keyswitch: Vec, pp_switching: Vec<(f64, Complexity)>, } fn compute_noise_cost_by_micro_param( consts: &OptimizationDecompositionsConsts, glwe_params: GlweParameters, internal_dim: u64, best_complexity: f64, variance_modulus_switching: f64, precision: u64, n_inputs: u64, ) -> Option { let security_level = consts.config.security_level; let variance_coeff = square(consts.noise_factor) / 2.0; let complexity_coeff = (n_inputs * (2 * precision - 1)) as f64; let cut_complexity = best_complexity / complexity_coeff; // saves 0% let cut_variance = (consts.safe_variance - variance_modulus_switching) / variance_coeff; // saves 40% let cutted_blind_rotate = cutted_blind_rotate::( consts, internal_dim, glwe_params, cut_complexity, cut_variance, ); if cutted_blind_rotate.is_empty() { return None; } let variance_coeff_br = variance_coeff; let variance_coeff = 1.0; let complexity_coeff = (precision * n_inputs) as f64; let cut_complexity = best_complexity / complexity_coeff; // saves 0% let cut_variance = (consts.safe_variance - variance_modulus_switching - variance_coeff_br * cutted_blind_rotate.last().unwrap().noise) / variance_coeff; // saves 25% let input_dim = glwe_params.sample_extract_lwe_dimension(); let pareto_keyswitch = pareto_keyswitch::( consts, input_dim, internal_dim, cut_complexity, cut_variance, ); if pareto_keyswitch.is_empty() { return None; } let ciphertext_modulus_log = W::BITS as u64; let variance_bsk = security::glwe::minimal_variance(glwe_params, ciphertext_modulus_log, security_level); let mut variance_cost_pp_switching = vec![(f64::NAN, f64::NAN); BR_PARETO_DECOMP.len()]; for br in &cutted_blind_rotate { // saves 0% let pp_ks_decomposition_parameter = BR_PARETO_DECOMP[br.index]; let ppks_parameter = PbsParameters { internal_lwe_dimension: LweDimension( glwe_params.glwe_dimension * glwe_params.polynomial_size(), ), br_decomposition_parameter: pp_ks_decomposition_parameter, output_glwe_params: glwe_params, }; // We assume the packing KS and the external product in a PBSto have // the same parameters (base, level) let variance_private_packing_ks = estimate_packing_private_keyswitch::(Variance(0.), variance_bsk, ppks_parameter) .get_variance(); let ppks_parameter_complexity = KeyswitchParameters { input_lwe_dimension: LweDimension( glwe_params.glwe_dimension * glwe_params.polynomial_size(), ), output_lwe_dimension: LweDimension( glwe_params.glwe_dimension * glwe_params.polynomial_size(), ), ks_decomposition_parameter: KsDecompositionParameters { level: pp_ks_decomposition_parameter.level, log2_base: pp_ks_decomposition_parameter.log2_base, }, }; let complexity_ppks = consts .config .complexity_model .ks_complexity(ppks_parameter_complexity, ciphertext_modulus_log); variance_cost_pp_switching[br.index] = (variance_private_packing_ks, complexity_ppks); } Some(NoiseCostByMicroParam { cutted_blind_rotate, pareto_keyswitch, pp_switching: variance_cost_pp_switching, }) } #[allow(clippy::too_many_lines)] fn update_state_with_best_decompositions( state: &mut OptimizationState, consts: &OptimizationDecompositionsConsts, glwe_params: GlweParameters, internal_dim: u64, n_functions: u64, precision: u64, n_inputs: u64, // Tau ) { let ciphertext_modulus_log = consts.config.ciphertext_modulus_log; let global_precision = n_inputs * precision; let safe_variance_bound = consts.safe_variance; let log_norm = consts.noise_factor.log2(); let variance_modulus_switching = noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::( internal_dim, glwe_params.polynomial_size(), ) .get_variance(); if variance_modulus_switching > consts.safe_variance { return; } let mut best_complexity = state .best_solution .as_ref() .map_or(f64::INFINITY, |s| s.complexity); let mut best_variance = state .best_solution .as_ref() .map_or(f64::INFINITY, |s| s.noise_max); let variance_cost_opt = compute_noise_cost_by_micro_param::( consts, glwe_params, internal_dim, best_complexity, variance_modulus_switching, precision, n_inputs, ); let variance_cost = if let Some(variance_cost) = variance_cost_opt { variance_cost } else { return; }; // pareto keyswitch is sorted by complexity increasing and variance decreasing let lower_bound_variance_keyswitch = variance_cost.pareto_keyswitch[variance_cost.pareto_keyswitch.len() - 1].noise; let lower_bound_complexity_all_ks = (precision * n_inputs) as f64 * variance_cost.pareto_keyswitch[0].complexity; // BlindRotate dans Circuit BS for shared_br_decomp in &variance_cost.cutted_blind_rotate { // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) let br_decomposition_parameter = consts.blind_rotate_decompositions[shared_br_decomp.index]; let pbs_parameters = PbsParameters { internal_lwe_dimension: LweDimension(internal_dim), br_decomposition_parameter, output_glwe_params: glwe_params, }; // BitExtract use this pbs let complexity_bit_extract_1_pbs = shared_br_decomp.complexity; let complexity_bit_extract_wo_ks = (n_inputs * (precision - 1)) as f64 * complexity_bit_extract_1_pbs; if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks > best_complexity { // saves 0% // next br_decomp are at least as costly break; } // Circuit Boostrap // private packing keyswitch, <=> FP-KS let pp_switching_index = shared_br_decomp.index; let (base_variance_private_packing_ks, complexity_ppks) = variance_cost.pp_switching[pp_switching_index]; let variance_ggsw = base_variance_private_packing_ks + shared_br_decomp.noise / 2.; let variance_coeff_1_cmux_tree = 2_f64.powf(2. * log_norm as f64) // variance_coeff for the multisum * (global_precision // for hybrid packing << (2 * (precision - 1))) as f64 // for left shift ; // CircuitBootstrap: new parameters l,b for &circuit_pbs_decomposition_parameter in BR_CIRCUIT_BOOTSTRAP_PARETO_DECOMP.iter() { // Hybrid packing let nb_cmux = 1_u64; let cmux_tree_blind_rotate_parameters = PbsParameters { internal_lwe_dimension: LweDimension(nb_cmux), // complexity for 1 cmux br_decomposition_parameter: circuit_pbs_decomposition_parameter, output_glwe_params: pbs_parameters.output_glwe_params, }; // Circuit bs: fp-ks let complexity_all_ppks = ((pbs_parameters.output_glwe_params.glwe_dimension + 1) * circuit_pbs_decomposition_parameter.level * global_precision) as f64 * complexity_ppks; // Circuit bs: pbs let complexity_all_pbs = (global_precision * circuit_pbs_decomposition_parameter.level) as f64 * shared_br_decomp.complexity; let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks; if complexity_bit_extract_wo_ks + lower_bound_complexity_all_ks + complexity_circuit_bs > best_complexity { // saves 50% // next circuit_pbs_decomposition_parameter are at least as costly break; } // Hybrid packing let complexity_1_cmux_hp = consts .config .complexity_model .pbs_complexity(cmux_tree_blind_rotate_parameters, ciphertext_modulus_log); // TODO: missing fft transform // Hybrid packing (Do we have 1 or 2 groups) let log2_polynomial_size = pbs_parameters.output_glwe_params.log2_polynomial_size; // Size of cmux_group, can be zero let cmux_group_count = if global_precision > log2_polynomial_size { 2f64.powi((global_precision - log2_polynomial_size - 1) as i32) } else { 0.0 }; let complexity_cmux_tree = cmux_group_count as f64 * complexity_1_cmux_hp; let cmux_complexity = cmux::SimpleWithFactors::default(); let f_glwe_poly_size = glwe_params.polynomial_size() as f64; let f_glwe_size = (glwe_params.glwe_dimension + 1) as f64; let complexity_one_ggsw_to_fft = square(f_glwe_size) * circuit_pbs_decomposition_parameter.level as f64 * cmux_complexity.fft_complexity(f_glwe_poly_size, ciphertext_modulus_log); let complexity_all_ggsw_to_fft = (1 << global_precision) as f64 * complexity_one_ggsw_to_fft; // Hybrid packing blind rotate let complexity_g_br = complexity_1_cmux_hp * u64::min( pbs_parameters.output_glwe_params.log2_polynomial_size, global_precision, ) as f64; let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br; let complexity_multi_hybrid_packing = n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft; // Cutting on complexity here is counter-productive probably because complexity_multi_hybrid_packing is small let variance_one_external_product_for_cmux_tree = noise_atomic_pattern::variance_bootstrap::( cmux_tree_blind_rotate_parameters, ciphertext_modulus_log, Variance::from_variance(variance_ggsw), ) .get_variance(); // final out noise hybrid packing let variance_after_1st_bit_extract = variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree; let variance_wo_ks = variance_modulus_switching + variance_after_1st_bit_extract; if variance_wo_ks + lower_bound_variance_keyswitch > safe_variance_bound { // saves 40% continue; } // Shared by all pbs (like brs) for ks_decomp in &variance_cost.pareto_keyswitch { let variance_keyswitch = ks_decomp.noise; let variance_max = variance_wo_ks + variance_keyswitch; if variance_max > safe_variance_bound { continue; } let complexity_all_ks = (precision * n_inputs) as f64 * ks_decomp.complexity; let complexity = complexity_bit_extract_wo_ks + complexity_circuit_bs + complexity_multi_hybrid_packing + complexity_all_ks; if complexity > best_complexity { // next ks.level will be even more costly break; } #[allow(clippy::float_cmp)] if complexity == best_complexity && variance_max > best_variance { continue; } let kappa = consts.kappa; best_complexity = complexity; best_variance = variance_max; let p_error = find_p_error(kappa, safe_variance_bound, variance_max); let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension(); let glwe_polynomial_size = glwe_params.polynomial_size(); let glwe_dimension = glwe_params.glwe_dimension; let ks_decomposition_parameter = consts.keyswitch_decompositions[ks_decomp.index]; state.best_solution = Some(Solution { input_lwe_dimension, internal_ks_output_lwe_dimension: internal_dim, ks_decomposition_level_count: ks_decomposition_parameter.level, ks_decomposition_base_log: ks_decomposition_parameter.log2_base, glwe_polynomial_size, glwe_dimension, br_decomposition_level_count: br_decomposition_parameter.level, br_decomposition_base_log: br_decomposition_parameter.log2_base, noise_max: variance_max, complexity, p_error, global_p_error: f64::NAN, cb_decomposition_level_count: circuit_pbs_decomposition_parameter.level, cb_decomposition_base_log: circuit_pbs_decomposition_parameter.log2_base, crt_decomposition: vec![], }); } } } } fn optimize_raw( max_word_precision: u64, // max precision of a word log_norm: f64, // ?? norm2 of noise multisum, complexity of multisum is neglected config: Config, search_space: &SearchSpace, n_functions: u64, // Many functions at the same time, stay at 1 for start n_inputs: u64, // Tau (nb blocks) ) -> OptimizationState { assert!(0.0 < config.maximum_acceptable_error_probability); assert!(config.maximum_acceptable_error_probability < 1.0); let ciphertext_modulus_log = W::BITS as u64; // Circuit BS bound // 1 bit of message only here =) // Bound for first bit extract in BitExtract (dominate others) let safe_variance_bound = safe_variance_bound_1bit_1padbit( ciphertext_modulus_log, config.maximum_acceptable_error_probability, ); let kappa: f64 = sigma_scale_of_error_probability(config.maximum_acceptable_error_probability); let mut state = OptimizationState { best_solution: None, count_domain: search_space.glwe_dimensions.len() * search_space.glwe_log_polynomial_sizes.len() * search_space.internal_lwe_dimensions.len() * KS_PARETO_DECOMP.len() * BR_PARETO_DECOMP.len(), }; let consts = OptimizationDecompositionsConsts { config, kappa, sum_size: 1, // Ignored noise_factor: log_norm.exp2(), keyswitch_decompositions: KS_CIRCUIT_BOOTSTRAP_PARETO_DECOMP.to_vec(), blind_rotate_decompositions: BR_PARETO_DECOMP.to_vec(), safe_variance: safe_variance_bound, }; for &glwe_dim in &search_space.glwe_dimensions { for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { let input_lwe_dimension = glwe_dim << glwe_log_poly_size; // Manual experimental CUT if input_lwe_dimension > 1 << 13 { continue; } let glwe_params = GlweParameters { log2_polynomial_size: glwe_log_poly_size, glwe_dimension: glwe_dim, }; for &internal_dim in &search_space.internal_lwe_dimensions { update_state_with_best_decompositions::( &mut state, &consts, glwe_params, internal_dim, n_functions, max_word_precision, n_inputs, ); } } } state } pub fn optimize_one( precision: u64, config: Config, log_norm: f64, search_space: &SearchSpace, ) -> OptimizationState { let coprimes = crt_decomposition::default_coprimes(precision as Precision); let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes); let nb_words = partitionning.len() as u64; let max_word_precision = *partitionning.iter().max().unwrap() as u64; let n_functions = 1; let mut state = optimize_raw::( max_word_precision, log_norm, config, search_space, n_functions, nb_words, // Tau ); state.best_solution = state.best_solution.map(|mut sol| -> Solution { sol.crt_decomposition = coprimes; sol }); state } pub fn optimize_one_compat( _sum_size: u64, precision: u64, config: Config, noise_factor: f64, search_space: &SearchSpace, ) -> atomic_pattern::OptimizationState { let log_norm = noise_factor.log2(); let result = optimize_one::(precision, config, log_norm, search_space); atomic_pattern::OptimizationState { best_solution: result.best_solution.map(Solution::into), count_domain: result.count_domain, } }