From 48e43c5762d00cc6284c32fd0811f73075c95a63 Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 31 Aug 2022 17:54:26 +0200 Subject: [PATCH] chore: clarify no luts optimization and cuts --- .../src/optimization/dag/solo_key/analyze.rs | 18 +- .../src/optimization/dag/solo_key/optimize.rs | 238 ++++++++++-------- 2 files changed, 140 insertions(+), 116 deletions(-) diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index f29648958..fa75007de 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -655,11 +655,15 @@ impl OperationDag { true } - pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 { + pub fn complexity(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 { let luts_cost = one_lut_cost * (self.nb_luts as f64); let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension); luts_cost + levelled_cost } + + pub fn levelled_complexity(&self, input_lwe_dimension: u64) -> f64 { + self.levelled_complexity.cost(input_lwe_dimension) + } } #[cfg(test)] @@ -701,7 +705,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT); assert_eq!(analysis.out_shapes[input1.i], Shape::number()); @@ -724,7 +728,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT); assert!(analysis.out_shapes[lut1.i] == Shape::number()); @@ -750,7 +754,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); let expected_var = SymbolicVariance { input_coeff: norm2, @@ -781,7 +785,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert!(analysis.out_variances[dot.i].origin() == VO::Input); assert_eq!(analysis.out_precisions[dot.i], 3); @@ -810,7 +814,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); let expected_var_dot1 = SymbolicVariance { input_coeff: weights.square_norm2() as f64, @@ -858,7 +862,7 @@ mod tests { let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; - let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); + let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost; assert_f64_eq(expected_cost, complexity_cost); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index b4b068b83..49454e0d5 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -7,16 +7,12 @@ use crate::optimization::atomic_pattern::{ Caches, OptimizationDecompositionsConsts, OptimizationState, Solution, }; use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace}; -use crate::optimization::decomposition::{blind_rotate, cut_complexity_noise, keyswitch}; +use crate::optimization::decomposition::{blind_rotate, keyswitch}; use crate::parameters::GlweParameters; -use crate::security::glwe::minimal_variance; +use crate::security; use concrete_commons::dispersion::DispersionParameter; -const CUTS: bool = true; -const PARETO_CUTS: bool = true; -const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true; - #[allow(clippy::too_many_lines)] fn update_best_solution_with_best_decompositions( state: &mut OptimizationState, @@ -24,10 +20,11 @@ fn update_best_solution_with_best_decompositions( dag: &analyze::OperationDag, internal_dim: u64, glwe_params: GlweParameters, + input_noise_out: f64, noise_modulus_switching: f64, caches: &mut Caches, ) { - let safe_variance = consts.safe_variance; + assert!(dag.nb_luts > 0); let glwe_poly_size = glwe_params.polynomial_size(); let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size; @@ -35,134 +32,52 @@ fn update_best_solution_with_best_decompositions( let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max); let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error); - let input_noise_out = minimal_variance( - glwe_params, - consts.config.ciphertext_modulus_log, - consts.config.security_level, - ) - .get_variance(); - - let no_luts = dag.nb_luts == 0; - // if no_luts we disable cuts, any parameters is acceptable in luts - let (cut_noise, cut_complexity) = if no_luts && CUTS { - (f64::INFINITY, f64::INFINITY) - } else { - ( - safe_variance - noise_modulus_switching, - (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) - / (dag.nb_luts as f64), - ) - }; - - if input_noise_out > cut_noise { - // exact cut when has_only_luts_with_inputs, lower bound cut otherwise - return; - } - - // if only one layer of luts, no cut inside pareto_blind_rotate based on br noise, - // since it's never use inside the lut - let br_cut_noise = if dag.has_only_luts_with_inputs { - f64::INFINITY - } else { - cut_noise - }; - let br_cut_complexity = cut_complexity; - let br_pareto = caches .blind_rotate .pareto_quantities(glwe_params, internal_dim); - // if only one layer of luts, no cut inside pareto_blind_rotate based on br noise, - // since this noise is never used inside a lut - let br_pareto = if dag.has_only_luts_with_inputs { - br_pareto - } else { - cut_complexity_noise(br_cut_complexity, br_cut_noise, br_pareto) - }; - - if br_pareto.is_empty() { - return; - } - - let worst_input_ks_noise = if dag.has_only_luts_with_inputs { - input_noise_out - } else { - br_pareto.last().unwrap().noise - }; - let ks_cut_noise = cut_noise - worst_input_ks_noise; - let ks_cut_complexity = cut_complexity - br_pareto[0].complexity; let ks_pareto = caches .keyswitch .pareto_quantities(glwe_params, internal_dim); - let ks_pareto = cut_complexity_noise(ks_cut_complexity, ks_cut_noise, ks_pareto); - - if ks_pareto.is_empty() { - return; - } - - let i_max_ks = ks_pareto.len() - 1; - let mut i_current_max_ks = i_max_ks; - let input_noise_out = minimal_variance( - glwe_params, - consts.config.ciphertext_modulus_log, - consts.config.security_level, - ) - .get_variance(); + // by constructon br_pareto and ks_pareto are non-empty let mut best_br = br_pareto[0]; let mut best_ks = ks_pareto[0]; let mut update_best_solution = false; for &br_quantity in br_pareto { // increasing complexity, decreasing variance + let one_lut_cost = br_quantity.complexity; + let complexity = dag.complexity(input_lwe_dimension, one_lut_cost); + if complexity > best_complexity { + // Since br_pareto is scanned by increasing complexity, we can stop + break; + } let not_feasible = !dag.feasible( input_noise_out, br_quantity.noise, 0.0, noise_modulus_switching, ); - if not_feasible && CUTS { + if not_feasible { continue; } - let one_lut_cost = br_quantity.complexity; - let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost); - if complexity > best_complexity { - // As best can evolves it is complementary to blind_rotate_quantities cuts. - if PARETO_CUTS { - break; - } else if CUTS { + for &ks_quantity in ks_pareto.iter().rev() { + let one_lut_cost = ks_quantity.complexity + br_quantity.complexity; + let complexity = dag.complexity(input_lwe_dimension, one_lut_cost); + let worse_complexity = complexity > best_complexity; + if worse_complexity { continue; } - } - for i_ks_pareto in (0..=i_current_max_ks).rev() { - // increasing variance, decreasing complexity - let ks_quantity = ks_pareto[i_ks_pareto]; let not_feasible = !dag.feasible( input_noise_out, br_quantity.noise, ks_quantity.noise, noise_modulus_switching, ); - // let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching; if not_feasible { - if CROSS_PARETO_CUTS { - // the pareto of 2 added pareto is scanned linearly - // but with all cuts, pre-computing => no gain - i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks); - break; - // it's compatible with next i_br but with the worst complexity - } else if PARETO_CUTS { - // increasing variance => we can skip all remaining - break; - } - continue; - } - - let one_lut_cost = ks_quantity.complexity + br_quantity.complexity; - let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost); - let worse_complexity = complexity > best_complexity; - if worse_complexity { - continue; + // Since ks_pareto is scanned by increasing noise, we can stop + break; } let (peek_p_error, variance) = dag.peek_p_error( @@ -215,6 +130,103 @@ fn update_best_solution_with_best_decompositions( const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8; +fn update_no_luts_solution( + state: &mut OptimizationState, + consts: &OptimizationDecompositionsConsts, + dag: &analyze::OperationDag, + glwe_params: GlweParameters, + input_noise_out: f64, +) { + const CHECKED_IGNORED_NOISE: f64 = f64::MAX; + const UNDEFINED_PARAM: u64 = 0; + + let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension(); + + let best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity); + let best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error); + + let complexity = if dag.levelled_complexity == LevelledComplexity::ZERO { + // The compiler has given a 0 levelled complexity. + // There is no way to compare solutions. + // Assuming linear complexity. + input_lwe_dimension as f64 + } else { + dag.levelled_complexity(input_lwe_dimension) + }; + + if complexity > best_complexity { + return; + } + + let (p_error, variance) = dag.peek_p_error( + input_noise_out, + CHECKED_IGNORED_NOISE, + CHECKED_IGNORED_NOISE, + CHECKED_IGNORED_NOISE, + consts.kappa, + ); + + #[allow(clippy::float_cmp)] + let same_complexity_no_few_errors = complexity == best_complexity && p_error >= best_p_error; + if same_complexity_no_few_errors { + return; + } + // The complexity is either better or equivalent with less errors + state.best_solution = Some(Solution { + input_lwe_dimension, + internal_ks_output_lwe_dimension: UNDEFINED_PARAM, + ks_decomposition_level_count: UNDEFINED_PARAM, + ks_decomposition_base_log: UNDEFINED_PARAM, + glwe_polynomial_size: glwe_params.polynomial_size(), + glwe_dimension: glwe_params.glwe_dimension, + br_decomposition_level_count: UNDEFINED_PARAM, + br_decomposition_base_log: UNDEFINED_PARAM, + complexity, + p_error, + global_p_error: dag.global_p_error( + input_noise_out, + CHECKED_IGNORED_NOISE, + CHECKED_IGNORED_NOISE, + CHECKED_IGNORED_NOISE, + consts.kappa, + ), + noise_max: variance, + }); +} + +fn minimal_variance(config: &Config, glwe_params: GlweParameters) -> f64 { + security::glwe::minimal_variance( + glwe_params, + config.ciphertext_modulus_log, + config.security_level, + ) + .get_variance() +} + +fn optimize_no_luts( + mut state: OptimizationState, + consts: &OptimizationDecompositionsConsts, + dag: &analyze::OperationDag, + search_space: &SearchSpace, +) -> OptimizationState { + let not_feasible = |input_noise_out| !dag.feasible(input_noise_out, 0.0, 0.0, 0.0); + for &glwe_dim in &search_space.glwe_dimensions { + for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { + let glwe_params = GlweParameters { + log2_polynomial_size: glwe_log_poly_size, + glwe_dimension: glwe_dim, + }; + let input_noise_out = minimal_variance(&consts.config, glwe_params); + if not_feasible(input_noise_out) { + continue; + } + update_no_luts_solution(&mut state, consts, dag, glwe_params, input_noise_out); + } + } + state +} + +#[allow(clippy::too_many_lines)] pub fn optimize( dag: &unparametrized::OperationDag, config: Config, @@ -251,6 +263,10 @@ pub fn optimize( best_solution: None, }; + if dag.nb_luts == 0 { + return optimize_no_luts(state, &consts, &dag, search_space); + } + let mut caches = Caches { blind_rotate: blind_rotate::for_security(security_level).cache(), keyswitch: keyswitch::for_security(security_level).cache(), @@ -264,20 +280,23 @@ pub fn optimize( ) .get_variance() }; - let not_feasible = - |noise_modulus_switching| !dag.feasible(0.0, 0.0, 0.0, noise_modulus_switching); + + let not_feasible = |input_noise_out, noise_modulus_switching| { + !dag.feasible(input_noise_out, 0.0, 0.0, noise_modulus_switching) + }; for &glwe_dim in &search_space.glwe_dimensions { for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { - let glwe_poly_size = 1 << glwe_log_poly_size; let glwe_params = GlweParameters { log2_polynomial_size: glwe_log_poly_size, glwe_dimension: glwe_dim, }; + let input_noise_out = minimal_variance(&config, glwe_params); for &internal_dim in &search_space.internal_lwe_dimensions { + let glwe_poly_size = 1 << glwe_log_poly_size; let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim); - if CUTS && not_feasible(noise_modulus_switching) { - // assume this noise is increasing with internal_dim + if not_feasible(input_noise_out, noise_modulus_switching) { + // noise_modulus_switching is increasing with internal_dim break; } update_best_solution_with_best_decompositions( @@ -286,6 +305,7 @@ pub fn optimize( &dag, internal_dim, glwe_params, + input_noise_out, noise_modulus_switching, &mut caches, );