mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(compiler): select the best complexity solution
This commit is contained in:
@@ -13,6 +13,15 @@ pub enum Solution {
|
||||
WopSolution(WopSolution),
|
||||
}
|
||||
|
||||
impl Solution {
|
||||
fn complexity(&self) -> f64 {
|
||||
match self {
|
||||
Self::WpSolution(v) => v.complexity,
|
||||
Self::WopSolution(v) => v.complexity,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Encoding {
|
||||
Auto,
|
||||
@@ -24,15 +33,33 @@ fn max_precision(dag: &OperationDag) -> Precision {
|
||||
dag.out_precisions.iter().copied().max().unwrap_or(0)
|
||||
}
|
||||
|
||||
fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
|
||||
fn updated_global_p_error_and_comlexity(nb_luts: u64, sol: WopSolution) -> WopSolution {
|
||||
let global_p_error = repeat_p_error(sol.p_error, nb_luts);
|
||||
|
||||
let complexity = nb_luts as f64 * sol.complexity;
|
||||
WopSolution {
|
||||
complexity,
|
||||
global_p_error,
|
||||
..sol
|
||||
}
|
||||
}
|
||||
|
||||
fn best_complexity_solution(native: Option<Solution>, crt: Option<Solution>) -> Option<Solution> {
|
||||
match (&native, &crt) {
|
||||
(Some(s_native), Some(s_crt)) => {
|
||||
// crt has 0 complexity in no lut case
|
||||
// so we always select native in this case
|
||||
if s_native.complexity() <= s_crt.complexity() || s_crt.complexity() == 0.0 {
|
||||
native
|
||||
} else {
|
||||
crt
|
||||
}
|
||||
}
|
||||
(Some(_), None) => native,
|
||||
(None, Some(_)) => crt,
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn optimize_with_wop_pbs(
|
||||
dag: &OperationDag,
|
||||
config: Config,
|
||||
@@ -46,7 +73,7 @@ fn optimize_with_wop_pbs(
|
||||
let log_norm = default_log_norm2_woppbs.min(worst_log_norm);
|
||||
wop_optimize(max_precision as u64, config, log_norm, search_space, caches)
|
||||
.best_solution
|
||||
.map(|sol| updated_global_p_error(nb_luts, sol))
|
||||
.map(|sol| updated_global_p_error_and_comlexity(nb_luts, sol))
|
||||
}
|
||||
|
||||
pub fn optimize(
|
||||
@@ -67,7 +94,7 @@ pub fn optimize(
|
||||
.map(Solution::WopSolution)
|
||||
};
|
||||
match encoding {
|
||||
Encoding::Auto => native().or_else(crt),
|
||||
Encoding::Auto => best_complexity_solution(native(), crt()),
|
||||
Encoding::Native => native(),
|
||||
Encoding::Crt => crt(),
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user