From 20833cefac32b14ab409a6588593573e42839bc1 Mon Sep 17 00:00:00 2001 From: rudy Date: Mon, 5 Dec 2022 11:14:17 +0100 Subject: [PATCH] feat(compiler): select the best complexity solution --- .../dag/solo_key/optimize_generic.rs | 35 ++++++++++++++++--- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index ad81e4c04..dd2a75c10 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -13,6 +13,15 @@ pub enum Solution { WopSolution(WopSolution), } +impl Solution { + fn complexity(&self) -> f64 { + match self { + Self::WpSolution(v) => v.complexity, + Self::WopSolution(v) => v.complexity, + } + } +} + #[derive(Clone, Copy)] pub enum Encoding { Auto, @@ -24,15 +33,33 @@ fn max_precision(dag: &OperationDag) -> Precision { dag.out_precisions.iter().copied().max().unwrap_or(0) } -fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution { +fn updated_global_p_error_and_comlexity(nb_luts: u64, sol: WopSolution) -> WopSolution { let global_p_error = repeat_p_error(sol.p_error, nb_luts); - + let complexity = nb_luts as f64 * sol.complexity; WopSolution { + complexity, global_p_error, ..sol } } +fn best_complexity_solution(native: Option, crt: Option) -> Option { + match (&native, &crt) { + (Some(s_native), Some(s_crt)) => { + // crt has 0 complexity in no lut case + // so we always select native in this case + if s_native.complexity() <= s_crt.complexity() || s_crt.complexity() == 0.0 { + native + } else { + crt + } + } + (Some(_), None) => native, + (None, Some(_)) => crt, + (None, None) => None, + } +} + fn optimize_with_wop_pbs( dag: &OperationDag, config: Config, @@ -46,7 +73,7 @@ fn optimize_with_wop_pbs( let log_norm = default_log_norm2_woppbs.min(worst_log_norm); wop_optimize(max_precision as u64, config, log_norm, search_space, caches) .best_solution - .map(|sol| updated_global_p_error(nb_luts, sol)) + .map(|sol| updated_global_p_error_and_comlexity(nb_luts, sol)) } pub fn optimize( @@ -67,7 +94,7 @@ pub fn optimize( .map(Solution::WopSolution) }; match encoding { - Encoding::Auto => native().or_else(crt), + Encoding::Auto => best_complexity_solution(native(), crt()), Encoding::Native => native(), Encoding::Crt => crt(), }