diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index 4a5e418b1..d35f83e33 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -86,7 +86,7 @@ impl Solution { #[derive(Debug)] pub struct Tab { pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>, - modulus_switching: HashMap<(u64, u64), (f64, f64)>, + modulus_switching: HashMap<(u64, u64), f64>, key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>, // NEW VALUE MEMOIZED pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>, @@ -143,10 +143,8 @@ pub fn tabulate_circuit_bootstrap( glwe_params.polynomial_size(), ) .get_variance(); - let _ = noise_cost_modulus_switching.insert( - (internal_dim, glwe_poly_size), - (noise_modulus_switching, 0.), - ); + let _ = noise_cost_modulus_switching + .insert((internal_dim, glwe_poly_size), noise_modulus_switching); for &br_decomposition_parameter in BR_BL.iter() { let pbs_parameters = PbsParameters { @@ -206,7 +204,6 @@ pub fn tabulate_circuit_bootstrap( } std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq)); - // let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter; for &pp_ks_decomposition_parameter in BR_BL.iter() { let ppks_parameter = PbsParameters { internal_lwe_dimension: LweDimension( @@ -274,15 +271,15 @@ pub fn optimise_one_with_memo( glwe_log_polynomial_sizes: &[u64], glwe_dimensions: &[u64], internal_lwe_dimensions: &[u64], - n_functions_log: u64, // Many functions at the same time, stay at 1 for start + n_functions: u64, // Many functions at the same time, stay at 1 for start memo: &Tab, n_inputs: u64, // Tau (nb blocks) ) -> OptimizationState { - assert!(n_functions_log == 0); // update complexity scaling assert!(0.0 < maximum_acceptable_error_probability); assert!(maximum_acceptable_error_probability < 1.0); let ciphertext_modulus_log = W::BITS as u64; + let global_precision = n_inputs * precision; // Circuit BS bound // 1 bit of message only here =) @@ -316,8 +313,10 @@ pub fn optimise_one_with_memo( glwe_dimension: glwe_dim, }; + let input_lwe_dimension = glwe_params.lwe_dimension(); + for &internal_dim in internal_lwe_dimensions { - let &(noise_modulus_switching, _) = memo + let &noise_modulus_switching = memo .modulus_switching .get(&(internal_dim, glwe_poly_size)) .expect(&format!( @@ -325,6 +324,10 @@ pub fn optimise_one_with_memo( internal_dim, glwe_poly_size )); + if noise_modulus_switching > variance_max { + continue; + } + let macro_key = (glwe_dim, glwe_poly_size, internal_dim); // BlindRotate dans Circuit BS @@ -362,10 +365,16 @@ pub fn optimise_one_with_memo( )) .unwrap(); + let complexity_bit_extract_wo_ks = + (n_inputs * (precision - 1)) as f64 * complexity_bit_extract_pbs; + + if complexity_bit_extract_wo_ks > best_complexity { + continue; + } + // private packing keyswitch, <=> FP-KS (Circuit Boostrap) let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter; - // for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS // Circuit Boostrap let &(base_noise_private_packing_ks, complexity_ppks) = memo @@ -394,52 +403,63 @@ pub fn optimise_one_with_memo( output_glwe_params: pbs_parameters.output_glwe_params, }; - // Hybrid packing (rename) - let complexity_cmux_for_cb = DEFAULT_COMPLEXITY.pbs.complexity( + // Hybrid packing + let complexity_1_cmux_hp = DEFAULT_COMPLEXITY.pbs.complexity( cmux_tree_blind_rotate_parameters, ciphertext_modulus_log, ); // TODO: missing fft transform // Hybrid packing (Do we have 1 or 2 groups) - #[allow(clippy::precedence)] - let complexity_cmux_tree = if precision * n_inputs as u64 // sum of precisions - > pbs_parameters.output_glwe_params.log2_polynomial_size - { - // 2 groups - complexity_cmux_for_cb - * (1 << (precision * n_inputs - - pbs_parameters.output_glwe_params.log2_polynomial_size) - - 1) as f64 - // * (f64::exp2( - // (precision * n_inputs - // - pbs_parameters - // .output_glwe_params - // .log2_polynomial_size) - // as f64, - // ) - 1.) + let log2_polynomial_size = + pbs_parameters.output_glwe_params.log2_polynomial_size; + // Size of cmux_group, can be zero + let cmux_group_count = if global_precision > log2_polynomial_size { + 2f64.powi((global_precision - log2_polynomial_size - 1) as i32) } else { - // 1 group, no cmux tree - 0. + 0.0 }; + let complexity_cmux_tree = + cmux_group_count as f64 * complexity_1_cmux_hp; // Hybrid packing blind rotate - let complexity_g_br = complexity_cmux_for_cb - * f64::min( - (pbs_parameters.output_glwe_params.log2_polynomial_size) as f64, - (precision * n_inputs) as f64, - ); + let complexity_g_br = complexity_1_cmux_hp + * u64::min( + pbs_parameters.output_glwe_params.log2_polynomial_size, + global_precision, + ) as f64; - let noise_private_packing_ks = - base_noise_private_packing_ks + base_noise / 2.; + let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br; + let complexity_multi_hybrid_packing = + n_functions as f64 * complexity_hybrid_packing; - // Circuit Boostrap - if noise_private_packing_ks + noise_modulus_switching > variance_max - || (precision - 1) as f64 * complexity_pbs + complexity_ppks - > best_complexity + // Circuit bs: fp-ks + let complexity_all_ppks = + ((pbs_parameters.output_glwe_params.glwe_dimension + 1) + * circuit_pbs_decomposition_parameter.level + * precision + * n_inputs) as f64 + * complexity_ppks; + + // Circuit bs: pbs + let complexity_all_pbs = + (n_inputs * precision * circuit_pbs_decomposition_parameter.level) + as f64 + * complexity_pbs; + + let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks; + + if complexity_bit_extract_wo_ks + complexity_circuit_bs + > best_complexity { continue; } - let noise_ggsw = noise_private_packing_ks; + let noise_ggsw = base_noise_private_packing_ks + base_noise / 2.; + + // Circuit Boostrap + let noise_hybrid_packing = noise_modulus_switching + noise_ggsw; + if noise_hybrid_packing > variance_max { + continue; + } let noise_one_external_product_for_cmux_tree = noise_atomic_pattern::variance_bootstrap::( @@ -449,13 +469,6 @@ pub fn optimise_one_with_memo( ) .get_variance(); - // all fp-ks - let complexity_all_ppks = - ((pbs_parameters.output_glwe_params.glwe_dimension + 1) - * circuit_pbs_decomposition_parameter.level - * precision) as f64 - * complexity_ppks; - // final out noise hybrid packing let noise_cmux_tree_blind_rotate = noise_one_external_product_for_cmux_tree @@ -464,94 +477,64 @@ pub fn optimise_one_with_memo( let noise_multisum = (2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights + let noise_all_multisum = + noise_multisum * (1 << (2 * (precision - 1))) as f64; + + let noise_ggsw_reencoding = + noise_modulus_switching + noise_all_multisum; + if noise_ggsw_reencoding > variance_max { + continue; + } + + let noise_max = noise_ggsw_reencoding.max(noise_hybrid_packing); + // Shared by all pbs (like brs) let key_switching_q = memo.key_switching.get(¯o_key).unwrap(); for &( ks_decomposition_parameter, (noise_keyswitch, complexity_keyswitch), - ) in key_switching_q.iter() + ) in key_switching_q { - let keyswitch_parameter = KeyswitchParameters { - input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim), - output_lwe_dimension: LweDimension(internal_dim), - ks_decomposition_parameter, - }; - let complexity_all_ks = precision as f64 * complexity_keyswitch; - if noise_private_packing_ks - + noise_modulus_switching - + noise_keyswitch - > variance_max - || (precision - 1) as f64 * complexity_pbs - + complexity_ppks - + precision as f64 * complexity_keyswitch - > best_complexity - { + let noise_max = noise_max + noise_keyswitch; + if noise_max > variance_max { continue; } - // noise_multisum = dot - let current_maximal_noise = noise_multisum - * (1 << (2 * (precision - 1))) as f64 - + noise_keyswitch - + noise_modulus_switching; + let complexity_all_ks = + (precision * n_inputs) as f64 * complexity_keyswitch; + let complexity_bit_extract = + complexity_bit_extract_wo_ks + complexity_all_ks; - let complexity_all_pbs = - (precision * circuit_pbs_decomposition_parameter.level) as f64 - * complexity_pbs - + (precision - 1) as f64 * complexity_bit_extract_pbs; + let complexity_ggsw_reencoding = + complexity_bit_extract + complexity_circuit_bs; - let complexity_bias = - (complexity_all_ppks + complexity_all_pbs + complexity_all_ks) - * n_inputs as f64; + let complexity = + complexity_ggsw_reencoding + complexity_multi_hybrid_packing; - let complexity_slope = complexity_cmux_tree + complexity_g_br; - - let current_complexity = complexity_slope - * (1 << n_functions_log) as f64 - + complexity_bias; - - if current_complexity > best_complexity { - // next level is more costly + if complexity > best_complexity { + // next ks.level will be even more costly break; } - if current_complexity < best_complexity - && current_maximal_noise < variance_max - { - best_complexity = current_complexity; + if complexity < best_complexity { + best_complexity = complexity; + let p_error = find_p_error(kappa, variance_max, noise_max); state.best_solution = Some(Solution { - input_lwe_dimension: pbs_parameters - .output_glwe_params - .glwe_dimension - * pbs_parameters.output_glwe_params.polynomial_size(), - internal_ks_output_lwe_dimension: keyswitch_parameter - .output_lwe_dimension - .0, - ks_decomposition_level_count: keyswitch_parameter - .ks_decomposition_parameter + input_lwe_dimension, + internal_ks_output_lwe_dimension: internal_dim, + ks_decomposition_level_count: ks_decomposition_parameter .level, - ks_decomposition_base_log: keyswitch_parameter - .ks_decomposition_parameter + ks_decomposition_base_log: ks_decomposition_parameter .log2_base, - glwe_polynomial_size: pbs_parameters - .output_glwe_params - .polynomial_size(), - glwe_dimension: pbs_parameters - .output_glwe_params - .glwe_dimension, - br_decomposition_level_count: pbs_parameters - .br_decomposition_parameter + glwe_polynomial_size: glwe_poly_size, + glwe_dimension: glwe_dim, + br_decomposition_level_count: br_decomposition_parameter .level, - br_decomposition_base_log: pbs_parameters - .br_decomposition_parameter + br_decomposition_base_log: br_decomposition_parameter .log2_base, - noise_max: current_maximal_noise, - complexity: current_complexity, - p_error: find_p_error( - kappa, - variance_max, - current_maximal_noise, - ), // consts.maximum_acceptable_error_probability, + noise_max, + complexity, + p_error, cb_decomposition_level_count: Some( circuit_pbs_decomposition_parameter.level, ), @@ -571,6 +554,30 @@ pub fn optimise_one_with_memo( state } +// Default heuristic to split in several word +pub fn default_partitionning(precision: u64) -> Vec { + #[allow(clippy::match_same_arms)] + match precision { + 1 => vec![1], + 2 => vec![2], + 3 => vec![2; 2], + 4 => vec![3; 2], + 5 => vec![3; 2], + 6 => vec![3; 3], + 7 => vec![3; 3], + 8 => vec![3; 3], + 9 => vec![4; 3], + 10 => vec![4; 3], + 11 => vec![4; 3], + 12 => vec![4; 4], + 13 => vec![4; 4], + 14 => vec![4; 4], + 15 => vec![4; 4], + 16 => vec![5; 4], + _ => vec![5; (precision / 5) as usize], + } +} + #[allow(clippy::too_many_lines)] pub fn optimize_one( _sum_size: u64, @@ -583,33 +590,11 @@ pub fn optimize_one( internal_lwe_dimensions: &[u64], memo_opt: &mut Option, ) -> atomic_pattern::OptimizationState { - // Basic heuristic to split in several word - let no_sol = atomic_pattern::OptimizationState { - best_solution: None, - count_domain: 0, - }; - #[allow(clippy::match_same_arms)] - let (nb_words, max_word_precision) = match precision { - 1 => (1, 1), - 2 => (1, 2), - 3 => (2, 2), - 4 => (2, 3), - 5 => (2, 3), - 6 => (3, 3), - 7 => (3, 3), - 8 => (3, 3), - 9 => (3, 4), - 10 => (3, 4), - 11 => (3, 4), - 12 => (4, 4), - 13 => (4, 4), - 14 => (4, 4), - 15 => (4, 4), - 16 => (4, 5), - _ => return no_sol, - }; + let partitionning = default_partitionning(precision); + let nb_words = partitionning.len() as u64; + let max_word_precision = *partitionning.iter().max().unwrap() as u64; let log_norm = noise_factor.log2(); - let n_functions_log = 0; + let n_functions = 1; let memo = memo_opt.get_or_insert_with(|| { tabulate_circuit_bootstrap::( security_level, @@ -627,7 +612,7 @@ pub fn optimize_one( glwe_log_polynomial_sizes, glwe_dimensions, internal_lwe_dimensions, - n_functions_log, + n_functions, memo, nb_words, // Tau ); diff --git a/concrete-optimizer/src/parameters.rs b/concrete-optimizer/src/parameters.rs index c36c79b29..b348b95d0 100644 --- a/concrete-optimizer/src/parameters.rs +++ b/concrete-optimizer/src/parameters.rs @@ -26,6 +26,9 @@ mod individual { pub fn polynomial_size(self) -> u64 { 1 << self.log2_polynomial_size } + pub fn lwe_dimension(self) -> u64 { + self.glwe_dimension << self.log2_polynomial_size + } } #[derive(Clone, Copy)]