diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index 36355f0fb..4a5e418b1 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -87,7 +87,7 @@ impl Solution { pub struct Tab { pbs: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>, modulus_switching: HashMap<(u64, u64), (f64, f64)>, - key_switching: HashMap<(u64, u64, u64, u64, u64), (f64, Complexity)>, + key_switching: HashMap<(u64, u64, u64), Vec<(KsDecompositionParameters, (f64, Complexity))>>, // NEW VALUE MEMOIZED pp_switching: HashMap<(u64, u64, u64, u64), (f64, Complexity)>, } @@ -130,6 +130,8 @@ pub fn tabulate_circuit_bootstrap( for &internal_dim in internal_lwe_dimensions { assert!(256 < internal_dim); + let macro_key = (glwe_dim, glwe_poly_size, internal_dim); + let variance_ksk = noise_atomic_pattern::variance_ksk( internal_dim, ciphertext_modulus_log, @@ -180,6 +182,7 @@ pub fn tabulate_circuit_bootstrap( ); } + let mut ks_seq = Vec::with_capacity(KS_BL_FOR_CB.len()); for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() { let keyswitch_parameter = KeyswitchParameters { input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim), @@ -196,17 +199,12 @@ pub fn tabulate_circuit_bootstrap( variance_ksk, ) .get_variance(); - let _ = noise_cost_key_switching.insert( - ( - glwe_dim, - glwe_poly_size, - internal_dim, - ks_decomposition_parameter.log2_base, - ks_decomposition_parameter.level, - ), + ks_seq.push(( + ks_decomposition_parameter, (noise_keyswitch, complexity_keyswitch), - ); + )); } + std::mem::drop(noise_cost_key_switching.insert(macro_key, ks_seq)); // let pp_ks_decomposition_parameter = pbs_parameters.br_decomposition_parameter; for &pp_ks_decomposition_parameter in BR_BL.iter() { @@ -327,6 +325,8 @@ pub fn optimise_one_with_memo( internal_dim, glwe_poly_size )); + let macro_key = (glwe_dim, glwe_poly_size, internal_dim); + // BlindRotate dans Circuit BS for &br_decomposition_parameter in BR_BL_FOR_CB.iter() { // Pbs dans BitExtract et Circuit BS et FP-KS (partagés) @@ -348,6 +348,42 @@ pub fn optimise_one_with_memo( )) .unwrap(); + // new pbs key for the bit extract pbs, shared + let bit_extract_decomposition_parameter = br_decomposition_parameter; + + let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo + .pbs + .get(&( + glwe_dim, + glwe_poly_size, + internal_dim, + bit_extract_decomposition_parameter.log2_base, + bit_extract_decomposition_parameter.level, + )) + .unwrap(); + + // private packing keyswitch, <=> FP-KS (Circuit Boostrap) + let pp_ks_decomposition_parameter = + pbs_parameters.br_decomposition_parameter; + // for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS + + // Circuit Boostrap + let &(base_noise_private_packing_ks, complexity_ppks) = memo + .pp_switching + .get(&( + glwe_dim, + glwe_poly_size, + pp_ks_decomposition_parameter.log2_base, + pp_ks_decomposition_parameter.level, + )) + .expect(&format!( + "{}, {}, {}, {}", + glwe_dim, + glwe_poly_size, + pp_ks_decomposition_parameter.log2_base, + pp_ks_decomposition_parameter.level, + )); + // CircuitBootstrap: new parameters l,b for &circuit_pbs_decomposition_parameter in CB_V1_BL.iter() { // Hybrid packing @@ -392,27 +428,8 @@ pub fn optimise_one_with_memo( (precision * n_inputs) as f64, ); - // private packing keyswitch, <=> FP-KS (Circuit Boostrap) - let pp_ks_decomposition_parameter = - pbs_parameters.br_decomposition_parameter; - // for &pp_ks_decomposition_parameter in PP_KS_BL.iter() { // independant params for FP-KS - // Circuit Boostrap - let &(mut noise_private_packing_ks, complexity_ppks) = memo - .pp_switching - .get(&( - glwe_dim, - glwe_poly_size, - pp_ks_decomposition_parameter.log2_base, - pp_ks_decomposition_parameter.level, - )) - .expect(&format!( - "{}, {}, {}, {}", - glwe_dim, - glwe_poly_size, - pp_ks_decomposition_parameter.log2_base, - pp_ks_decomposition_parameter.level, - )); - noise_private_packing_ks += base_noise / 2.; + let noise_private_packing_ks = + base_noise_private_packing_ks + base_noise / 2.; // Circuit Boostrap if noise_private_packing_ks + noise_modulus_switching > variance_max @@ -448,23 +465,17 @@ pub fn optimise_one_with_memo( (2_f64.powf(2. * log_norm as f64)) * noise_cmux_tree_blind_rotate; // out noise * weights // Shared by all pbs (like brs) - for &ks_decomposition_parameter in KS_BL_FOR_CB.iter() { + let key_switching_q = memo.key_switching.get(¯o_key).unwrap(); + for &( + ks_decomposition_parameter, + (noise_keyswitch, complexity_keyswitch), + ) in key_switching_q.iter() + { let keyswitch_parameter = KeyswitchParameters { input_lwe_dimension: LweDimension(glwe_poly_size * glwe_dim), output_lwe_dimension: LweDimension(internal_dim), ks_decomposition_parameter, }; - - let &(noise_keyswitch, complexity_keyswitch) = memo - .key_switching - .get(&( - glwe_dim, - glwe_poly_size, - internal_dim, - ks_decomposition_parameter.log2_base, - ks_decomposition_parameter.level, - )) - .unwrap(); let complexity_all_ks = precision as f64 * complexity_keyswitch; if noise_private_packing_ks + noise_modulus_switching @@ -478,21 +489,6 @@ pub fn optimise_one_with_memo( continue; } - // new pbs key for the bit extract pbs, shared - let bit_extract_decomposition_parameter = - br_decomposition_parameter; - - let &(_bit_extract_base_noise, complexity_bit_extract_pbs) = memo - .pbs - .get(&( - glwe_dim, - glwe_poly_size, - internal_dim, - bit_extract_decomposition_parameter.log2_base, - bit_extract_decomposition_parameter.level, - )) - .unwrap(); - // noise_multisum = dot let current_maximal_noise = noise_multisum * (1 << (2 * (precision - 1))) as f64 @@ -514,6 +510,11 @@ pub fn optimise_one_with_memo( * (1 << n_functions_log) as f64 + complexity_bias; + if current_complexity > best_complexity { + // next level is more costly + break; + } + if current_complexity < best_complexity && current_maximal_noise < variance_max { @@ -558,9 +559,6 @@ pub fn optimise_one_with_memo( circuit_pbs_decomposition_parameter.log2_base, ), }); - } else if current_complexity > best_complexity { - // next level is more costly - break; } } }