diff --git a/charts/src/bin/norm2_complexity.rs b/charts/src/bin/norm2_complexity.rs index db7df7be1..cd87141ac 100644 --- a/charts/src/bin/norm2_complexity.rs +++ b/charts/src/bin/norm2_complexity.rs @@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity; use concrete_optimizer::global_parameters::DEFAUT_DOMAINS; use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern}; use concrete_optimizer::optimization::config::{Config, SearchSpace}; +use concrete_optimizer::optimization::decomposition; use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern; pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; @@ -40,6 +41,8 @@ fn main() -> Result<(), Box> { complexity_model: &CpuComplexity::default(), }; + let cache = decomposition::cache(security_level); + let solutions: Vec<_> = log_norm2s .clone() .filter_map(|log_norm2| { @@ -51,6 +54,7 @@ fn main() -> Result<(), Box> { config, noise_scale, &search_space, + &cache, ) .best_solution .map(|a| (log_norm2, a.complexity)) @@ -61,9 +65,15 @@ fn main() -> Result<(), Box> { .filter_map(|log_norm2| { let noise_scale = 2_f64.powi(log_norm2 as i32); - optimize_wop_atomic_pattern::optimize_one(precision, config, noise_scale, &search_space) - .best_solution - .map(|a| (log_norm2, a.complexity)) + optimize_wop_atomic_pattern::optimize_one( + precision, + config, + noise_scale, + &search_space, + &cache, + ) + .best_solution + .map(|a| (log_norm2, a.complexity)) }) .collect(); diff --git a/charts/src/bin/precision_complexity.rs b/charts/src/bin/precision_complexity.rs index 2a82663bc..b53e72190 100644 --- a/charts/src/bin/precision_complexity.rs +++ b/charts/src/bin/precision_complexity.rs @@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity; use concrete_optimizer::global_parameters::DEFAUT_DOMAINS; use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern}; use concrete_optimizer::optimization::config::{Config, SearchSpace}; +use concrete_optimizer::optimization::decomposition; use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern; pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; @@ -40,6 +41,8 @@ fn main() -> Result<(), Box> { complexity_model: &CpuComplexity::default(), }; + let cache = decomposition::cache(security_level); + let solutions: Vec<_> = precisions .clone() .filter_map(|precision| { @@ -51,6 +54,7 @@ fn main() -> Result<(), Box> { config, noise_factor, &search_space, + &cache, ) .best_solution .map(|a| (precision, a.complexity)) @@ -64,6 +68,7 @@ fn main() -> Result<(), Box> { config, log_norm2 as f64, &search_space, + &cache, ) .best_solution .map(|a| (precision, a.complexity)) diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs index 75a925bdd..440ee901c 100644 --- a/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -5,6 +5,7 @@ use concrete_optimizer::dag::operator::{ use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution; +use concrete_optimizer::optimization::decomposition; fn no_solution() -> ffi::Solution { ffi::Solution { @@ -43,6 +44,7 @@ fn optimize_bootstrap( config, noise_factor, &search_space, + &decomposition::cache(security_level), ); result .best_solution @@ -219,6 +221,7 @@ impl OperationDag { &self.0, config, &search_space, + &decomposition::cache(security_level), ); result .best_solution @@ -239,12 +242,14 @@ impl OperationDag { }; let search_space = SearchSpace::default(); + let cache = decomposition::cache(security_level); let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize( &self.0, config, &search_space, default_log_norm2_woppbs, + &cache, ); result.map_or_else(no_dag_solution, |solution| solution.into()) } diff --git a/concrete-optimizer/Cargo.toml b/concrete-optimizer/Cargo.toml index bf62cd3bc..8487c0a5b 100644 --- a/concrete-optimizer/Cargo.toml +++ b/concrete-optimizer/Cargo.toml @@ -15,7 +15,6 @@ static_init = "1.0.3" serde = { version = "1.0", features = ["derive"] } rmp-serde = "1.1.0" statrs = "0.15.0" -lazy_static = "1.4.0" [dev-dependencies] approx = "0.5" diff --git a/concrete-optimizer/src/optimization/atomic_pattern.rs b/concrete-optimizer/src/optimization/atomic_pattern.rs index 503219427..1730e7d4c 100644 --- a/concrete-optimizer/src/optimization/atomic_pattern.rs +++ b/concrete-optimizer/src/optimization/atomic_pattern.rs @@ -5,7 +5,8 @@ use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositi use crate::utils::square; use concrete_commons::dispersion::{DispersionParameter, Variance}; -use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch}; +use super::decomposition; +use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch, PersistDecompCache}; // Ref time for v0 table 1 thread: 950ms const CUTS: bool = true; // 80ms @@ -46,6 +47,19 @@ pub struct Caches { pub keyswitch: keyswitch::Cache, } +impl Caches { + pub fn new(cache: &decomposition::PersistDecompCache) -> Self { + Self { + blind_rotate: cache.br.cache(), + keyswitch: cache.ks.cache(), + } + } + pub fn backport_to(self, cache: &decomposition::PersistDecompCache) { + cache.ks.backport(self.keyswitch); + cache.br.backport(self.blind_rotate); + } +} + #[allow(clippy::too_many_lines)] fn update_state_with_best_decompositions( state: &mut OptimizationState, @@ -182,6 +196,7 @@ pub fn optimize_one( config: Config, noise_factor: f64, search_space: &SearchSpace, + cache: &PersistDecompCache, ) -> OptimizationState { assert!(0 < precision && precision <= 16); assert!(1.0 <= noise_factor); @@ -193,7 +208,6 @@ pub fn optimize_one( // the blind rotate decomposition let ciphertext_modulus_log = config.ciphertext_modulus_log; - let security_level = config.security_level; let safe_variance = error::safe_variance_bound_2padbits( precision, ciphertext_modulus_log, @@ -228,10 +242,7 @@ pub fn optimize_one( > consts.safe_variance }; - let mut caches = Caches { - blind_rotate: blind_rotate::for_security(security_level).cache(), - keyswitch: keyswitch::for_security(security_level).cache(), - }; + let mut caches = Caches::new(cache); for &glwe_dim in &search_space.glwe_dimensions { for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { @@ -260,8 +271,7 @@ pub fn optimize_one( } } - blind_rotate::for_security(security_level).backport(caches.blind_rotate); - keyswitch::for_security(security_level).backport(caches.keyswitch); + caches.backport_to(cache); if let Some(sol) = state.best_solution { assert!(0.0 <= sol.p_error && sol.p_error <= 1.0); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 2abaa539f..ad38e61e0 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -7,7 +7,7 @@ use crate::optimization::atomic_pattern::{ Caches, OptimizationDecompositionsConsts, OptimizationState, Solution, }; use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace}; -use crate::optimization::decomposition::{blind_rotate, keyswitch}; +use crate::optimization::decomposition::PersistDecompCache; use crate::parameters::GlweParameters; use crate::security; @@ -231,6 +231,7 @@ pub fn optimize( dag: &unparametrized::OperationDag, config: Config, search_space: &SearchSpace, + cache: &PersistDecompCache, ) -> OptimizationState { let ciphertext_modulus_log = config.ciphertext_modulus_log; let security_level = config.security_level; @@ -266,11 +267,7 @@ pub fn optimize( if dag.nb_luts == 0 { return optimize_no_luts(state, &consts, &dag, search_space); } - - let mut caches = Caches { - blind_rotate: blind_rotate::for_security(security_level).cache(), - keyswitch: keyswitch::for_security(security_level).cache(), - }; + let mut caches = Caches::new(cache); let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| { noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key( @@ -316,8 +313,7 @@ pub fn optimize( } } - blind_rotate::for_security(security_level).backport(caches.blind_rotate); - keyswitch::for_security(security_level).backport(caches.keyswitch); + caches.backport_to(cache); if let Some(sol) = state.best_solution { assert!(0.0 <= sol.p_error && sol.p_error <= 1.0); @@ -335,6 +331,7 @@ pub fn optimize_v0( config: Config, noise_factor: f64, search_space: &SearchSpace, + cache: &PersistDecompCache, ) -> OptimizationState { use crate::dag::operator::{FunctionTable, Shape}; let same_scale_manp = 0.0; @@ -349,7 +346,7 @@ pub fn optimize_v0( let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment); let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); - let mut state = optimize(&dag, config, search_space); + let mut state = optimize(&dag, config, search_space, cache); if let Some(sol) = &mut state.best_solution { sol.complexity /= 2.0; } @@ -364,9 +361,9 @@ mod tests { use crate::computing_cost::cpu::CpuComplexity; use crate::dag::operator::{FunctionTable, Shape, Weights}; use crate::noise_estimator::p_error::repeat_p_error; - use crate::optimization::atomic_pattern; use crate::optimization::config::SearchSpace; use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin; + use crate::optimization::{atomic_pattern, decomposition}; use crate::utils::square; fn small_relative_diff(v1: f64, v2: f64) -> bool { @@ -390,7 +387,10 @@ mod tests { const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; - fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState { + fn optimize( + dag: &unparametrized::OperationDag, + cache: &PersistDecompCache, + ) -> OptimizationState { let config = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -400,7 +400,7 @@ mod tests { let search_space = SearchSpace::default(); - super::optimize(dag, config, &search_space) + super::optimize(dag, config, &search_space, cache) } struct Times { @@ -439,16 +439,38 @@ mod tests { complexity_model: &CpuComplexity::default(), }; - let _ = optimize_v0(sum_size, precision, config, weight as f64, &search_space); + let cache = decomposition::cache(config.security_level); + + let _ = optimize_v0( + sum_size, + precision, + config, + weight as f64, + &search_space, + &cache, + ); // ensure cache is filled let chrono = Instant::now(); - let state = optimize_v0(sum_size, precision, config, weight as f64, &search_space); + let state = optimize_v0( + sum_size, + precision, + config, + weight as f64, + &search_space, + &cache, + ); times.dag_time += chrono.elapsed().as_nanos(); let chrono = Instant::now(); - let state_ref = - atomic_pattern::optimize_one(sum_size, precision, config, weight as f64, &search_space); + let state_ref = atomic_pattern::optimize_one( + sum_size, + precision, + config, + weight as f64, + &search_space, + &cache, + ); times.worst_time += chrono.elapsed().as_nanos(); assert_eq!( state.best_solution.is_some(), @@ -476,6 +498,10 @@ mod tests { } fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) { + let security_level = 128; + + let cache = decomposition::cache(security_level); + let mut dag = unparametrized::OperationDag::new(); { let input1 = dag.add_input(precision, Shape::number()); @@ -488,7 +514,7 @@ mod tests { let dag2 = analyze::analyze( &dag, &NoiseBoundConfig { - security_level: 128, + security_level, maximum_acceptable_error_probability: _4_SIGMA, ciphertext_modulus_log: 64, }, @@ -506,15 +532,21 @@ mod tests { let search_space = SearchSpace::default(); let config = Config { - security_level: 128, + security_level, maximum_acceptable_error_probability: _4_SIGMA, ciphertext_modulus_log: 64, complexity_model: &CpuComplexity::default(), }; - let state = optimize(&dag); - let state_ref = - atomic_pattern::optimize_one(1, precision as u64, config, weight as f64, &search_space); + let state = optimize(&dag, &cache); + let state_ref = atomic_pattern::optimize_one( + 1, + precision as u64, + config, + weight as f64, + &search_space, + &cache, + ); assert_eq!( state.best_solution.is_some(), state_ref.best_solution.is_some() @@ -531,7 +563,7 @@ mod tests { assert!(sol.global_p_error <= 1.0); } - fn no_lut_vs_lut(precision: Precision) { + fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) { let mut dag_lut = unparametrized::OperationDag::new(); let input1 = dag_lut.add_input(precision as u8, Shape::number()); let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision); @@ -539,8 +571,8 @@ mod tests { let mut dag_no_lut = unparametrized::OperationDag::new(); let _input2 = dag_no_lut.add_input(precision as u8, Shape::number()); - let state_no_lut = optimize(&dag_no_lut); - let state_lut = optimize(&dag_lut); + let state_no_lut = optimize(&dag_no_lut, cache); + let state_lut = optimize(&dag_lut, cache); assert_eq!( state_no_lut.best_solution.is_some(), state_lut.best_solution.is_some() @@ -556,14 +588,16 @@ mod tests { } #[test] fn test_lut_vs_no_lut() { + let cache = decomposition::cache(128); for precision in 1..=8 { - no_lut_vs_lut(precision); + no_lut_vs_lut(precision, &cache); } } fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise( precision: Precision, weight: u64, + cache: &PersistDecompCache, ) { let weight = &Weights::number(weight); @@ -583,8 +617,8 @@ mod tests { let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision); } - let state_1 = optimize(&dag_1); - let state_2 = optimize(&dag_2); + let state_1 = optimize(&dag_1, cache); + let state_2 = optimize(&dag_2, cache); if state_1.best_solution.is_none() { assert!(state_2.best_solution.is_none()); @@ -597,15 +631,18 @@ mod tests { #[test] fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() { + let cache = decomposition::cache(128); for log_weight in 1..=16 { let weight = 1 << log_weight; for precision in 5..=9 { - lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight); + lut_with_input_base_noise_better_than_lut_with_lut_base_noise( + precision, weight, &cache, + ); } } } - fn lut_1_layer_has_better_complexity(precision: Precision) { + fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) { let dag_1_layer = { let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(precision as u8, Shape::number()); @@ -621,17 +658,18 @@ mod tests { dag }; - let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap(); - let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap(); + let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap(); + let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap(); assert!(sol_1_layer.complexity < sol_2_layer.complexity); } #[test] fn test_lut_1_layer_is_better() { + let cache = decomposition::cache(128); // for some reason on 4, 5, 6, the complexity is already minimal // this could be due to pre-defined pareto set for precision in [1, 2, 3, 7, 8] { - lut_1_layer_has_better_complexity(precision); + lut_1_layer_has_better_complexity(precision, &cache); } } @@ -643,7 +681,10 @@ mod tests { let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } - fn assert_multi_precision_dominate_single(weight: u64) -> Option { + fn assert_multi_precision_dominate_single( + weight: u64, + cache: &PersistDecompCache, + ) -> Option { let low_precision = 4u8; let high_precision = 5u8; let mut dag_low = unparametrized::OperationDag::new(); @@ -656,12 +697,12 @@ mod tests { circuit(&mut dag_multi, low_precision, weight); circuit(&mut dag_multi, high_precision, 1); } - let state_multi = optimize(&dag_multi); + let state_multi = optimize(&dag_multi, cache); let mut sol_multi = state_multi.best_solution?; - let state_low = optimize(&dag_low); - let state_high = optimize(&dag_high); + let state_low = optimize(&dag_low, cache); + let state_high = optimize(&dag_high, cache); let sol_low = state_low.best_solution.unwrap(); let sol_high = state_high.best_solution.unwrap(); @@ -680,10 +721,11 @@ mod tests { #[test] fn test_multi_precision_dominate_single() { + let cache = decomposition::cache(128); let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false for log2_weight in 0..29 { let weight = 1 << log2_weight; - let current = assert_multi_precision_dominate_single(weight); + let current = assert_multi_precision_dominate_single(weight, &cache); #[allow(clippy::match_like_matches_macro)] // less readable let authorized = match (prev, current) { (Some(false), Some(true)) => false, @@ -713,22 +755,28 @@ mod tests { #[test] fn test_global_p_error_input() { + let cache = decomposition::cache(128); for precision in [4_u8, 8] { for weight in [1, 3, 27, 243, 729] { for dim in [1, 2, 16, 32] { - let _ = check_global_p_error_input(dim, weight, precision); + let _ = check_global_p_error_input(dim, weight, precision, &cache); } } } } - fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 { + fn check_global_p_error_input( + dim: u64, + weight: u64, + precision: u8, + cache: &PersistDecompCache, + ) -> f64 { let shape = Shape::vector(dim); let weights = Weights::number(weight); let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(precision as u8, shape); let _dot1 = dag.add_dot([input1], weights); // this is just several multiply - let state = optimize(&dag); + let state = optimize(&dag, cache); let sol = state.best_solution.unwrap(); let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim); approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim); @@ -737,16 +785,22 @@ mod tests { #[test] fn test_global_p_error_lut() { + let cache = decomposition::cache(128); for precision in [4_u8, 8] { for weight in [1, 3, 27, 243, 729] { for depth in [2, 16, 32] { - check_global_p_error_lut(depth, weight, precision); + check_global_p_error_lut(depth, weight, precision, &cache); } } } } - fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) { + fn check_global_p_error_lut( + depth: u64, + weight: u64, + precision: u8, + cache: &PersistDecompCache, + ) { let shape = Shape::number(); let weights = Weights::number(weight); let mut dag = unparametrized::OperationDag::new(); @@ -755,7 +809,7 @@ mod tests { let dot = dag.add_dot([last_val], &weights); last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision); } - let state = optimize(&dag); + let state = optimize(&dag, cache); let sol = state.best_solution.unwrap(); // the first lut on input has reduced impact on error probability let lower_nb_dominating_lut = depth - 1; @@ -792,6 +846,7 @@ mod tests { #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const #[test] fn test_global_p_error_dominating_lut() { + let cache = decomposition::cache(128); let depth = 128; let weights_low = 1; let weights_high = 1; @@ -804,7 +859,7 @@ mod tests { weights_low, weights_high, ); - let sol = optimize(&dag).best_solution.unwrap(); + let sol = optimize(&dag, &cache).best_solution.unwrap(); // the 2 first luts and low precision/weight luts have little impact on error probability let nb_dominating_lut = depth - 1; let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); @@ -819,6 +874,7 @@ mod tests { #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const #[test] fn test_global_p_error_non_dominating_lut() { + let cache = decomposition::cache(128); let depth = 128; let weights_low = 1024 * 1024 * 3; let weights_high = 1; @@ -831,7 +887,7 @@ mod tests { weights_low, weights_high, ); - let sol = optimize(&dag).best_solution.unwrap(); + let sol = optimize(&dag, &cache).best_solution.unwrap(); // all intern luts have an impact on error probability almost equaly let nb_dominating_lut = (2 * depth) - 1; let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index d1300ef21..593b8341d 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -4,6 +4,7 @@ use crate::noise_estimator::p_error::repeat_p_error; use crate::optimization::atomic_pattern::Solution as WpSolution; use crate::optimization::config::{Config, SearchSpace}; use crate::optimization::dag::solo_key::{analyze, optimize}; +use crate::optimization::decomposition::PersistDecompCache; use crate::optimization::wop_atomic_pattern::optimize::optimize_one as wop_optimize; use crate::optimization::wop_atomic_pattern::Solution as WopSolution; use std::ops::RangeInclusive; @@ -48,6 +49,7 @@ pub fn optimize( config: Config, search_space: &SearchSpace, default_log_norm2_woppbs: f64, + cache: &PersistDecompCache, ) -> Option { let max_precision = max_precision(dag); let nb_luts = analyze::lut_count_from_dag(dag); @@ -58,11 +60,17 @@ pub fn optimize( let default_log_norm = default_log_norm2_woppbs; let worst_log_norm = analyze::worst_log_norm(dag); let log_norm = default_log_norm.min(worst_log_norm); - let opt_sol = - wop_optimize(fallback_16b_precision, config, log_norm, search_space).best_solution; + let opt_sol = wop_optimize( + fallback_16b_precision, + config, + log_norm, + search_space, + cache, + ) + .best_solution; opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol))) } else { - let opt_sol = optimize::optimize(dag, config, search_space).best_solution; + let opt_sol = optimize::optimize(dag, config, search_space, cache).best_solution; opt_sol.map(Solution::WpSolution) } } diff --git a/concrete-optimizer/src/optimization/decomposition/blind_rotate.rs b/concrete-optimizer/src/optimization/decomposition/blind_rotate.rs index 1a0e4b6a9..cb1f1248a 100644 --- a/concrete-optimizer/src/optimization/decomposition/blind_rotate.rs +++ b/concrete-optimizer/src/optimization/decomposition/blind_rotate.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; use concrete_commons::dispersion::DispersionParameter; @@ -8,8 +6,6 @@ use crate::computing_cost::operators::pbs::PbsComplexity; use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::parameters::{BrDecompositionParameters, GlweParameters, LweDimension, PbsParameters}; use crate::security; - -use crate::security::security_weights::SECURITY_WEIGHTS_TABLE; use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache}; use crate::utils::cache::persistent::PersistentCacheHashMap; @@ -120,42 +116,22 @@ impl Cache { } } -type PersistDecompCache = PersistentCacheHashMap>; -type MultiSecPersistDecompCache = HashMap; // just to attach a finaly +pub type PersistDecompCache = PersistentCacheHashMap>; -#[static_init::dynamic] -pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE - .keys() - .map(|&security_level| { - let ciphertext_modulus_log = 64; - let tmp: String = std::env::temp_dir() - .to_str() - .expect("Invalid tmp dir") - .into(); - let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}"); - let function = move |(glwe_params, internal_dim): MacroParam| { - pareto_quantities( - ciphertext_modulus_log, - security_level, - internal_dim, - glwe_params, - ) - }; - ( +pub fn cache(security_level: u64) -> PersistDecompCache { + let ciphertext_modulus_log = 64; + let tmp: String = std::env::temp_dir() + .to_str() + .expect("Invalid tmp dir") + .into(); + let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}"); + let function = move |(glwe_params, internal_dim): MacroParam| { + pareto_quantities( + ciphertext_modulus_log, security_level, - PersistentCacheHashMap::new(&path, "v0", function), + internal_dim, + glwe_params, ) - }) - .collect::(); - -#[cfg(not(target_os = "macos"))] -#[static_init::destructor(10)] -extern "C" fn finaly() { - for v in SHARED_CACHE.values() { - v.sync_to_disk(); - } -} - -pub fn for_security(security_level: u64) -> &'static PersistDecompCache { - SHARED_CACHE.get(&security_level).unwrap() + }; + PersistentCacheHashMap::new(&path, "v0", function) } diff --git a/concrete-optimizer/src/optimization/decomposition/keyswitch.rs b/concrete-optimizer/src/optimization/decomposition/keyswitch.rs index 21d84ff2d..21e9ed020 100644 --- a/concrete-optimizer/src/optimization/decomposition/keyswitch.rs +++ b/concrete-optimizer/src/optimization/decomposition/keyswitch.rs @@ -1,5 +1,3 @@ -use std::collections::HashMap; - use serde::{Deserialize, Serialize}; use concrete_commons::dispersion::DispersionParameter; @@ -9,7 +7,6 @@ use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::parameters::{ GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension, }; -use crate::security::security_weights::SECURITY_WEIGHTS_TABLE; use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache}; use crate::utils::cache::persistent::PersistentCacheHashMap; @@ -121,42 +118,22 @@ impl Cache { } } -type PersistDecompCache = PersistentCacheHashMap>; -type MultiSecPersistDecompCache = HashMap; +pub type PersistDecompCache = PersistentCacheHashMap>; -#[static_init::dynamic] -pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE - .keys() - .map(|&security_level| { - let ciphertext_modulus_log = 64; - let tmp: String = std::env::temp_dir() - .to_str() - .expect("Invalid tmp dir") - .into(); - let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}"); - let function = move |(glwe_params, internal_dim): MacroParam| { - pareto_quantities( - ciphertext_modulus_log, - security_level, - internal_dim, - glwe_params, - ) - }; - ( +pub fn cache(security_level: u64) -> PersistDecompCache { + let ciphertext_modulus_log = 64; + let tmp: String = std::env::temp_dir() + .to_str() + .expect("Invalid tmp dir") + .into(); + let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}"); + let function = move |(glwe_params, internal_dim): MacroParam| { + pareto_quantities( + ciphertext_modulus_log, security_level, - PersistentCacheHashMap::new(&path, "v0", function), + internal_dim, + glwe_params, ) - }) - .collect::(); - -#[cfg(not(target_os = "macos"))] -#[static_init::destructor(10)] -extern "C" fn finaly() { - for v in SHARED_CACHE.values() { - v.sync_to_disk(); - } -} - -pub fn for_security(security_level: u64) -> &'static PersistDecompCache { - SHARED_CACHE.get(&security_level).unwrap() + }; + PersistentCacheHashMap::new(&path, "v0", function) } diff --git a/concrete-optimizer/src/optimization/decomposition/mod.rs b/concrete-optimizer/src/optimization/decomposition/mod.rs index ff574d4e7..185047fff 100644 --- a/concrete-optimizer/src/optimization/decomposition/mod.rs +++ b/concrete-optimizer/src/optimization/decomposition/mod.rs @@ -5,3 +5,15 @@ pub mod keyswitch; pub use common::MacroParam; pub use cut::cut_complexity_noise; + +pub struct PersistDecompCache { + pub ks: keyswitch::PersistDecompCache, + pub br: blind_rotate::PersistDecompCache, +} + +pub fn cache(security_level: u64) -> PersistDecompCache { + PersistDecompCache { + ks: keyswitch::cache(security_level), + br: blind_rotate::cache(security_level), + } +} diff --git a/concrete-optimizer/src/optimization/mod.rs b/concrete-optimizer/src/optimization/mod.rs index a72365804..e66098c05 100644 --- a/concrete-optimizer/src/optimization/mod.rs +++ b/concrete-optimizer/src/optimization/mod.rs @@ -1,5 +1,5 @@ pub mod atomic_pattern; pub mod config; pub mod dag; -pub(crate) mod decomposition; +pub mod decomposition; pub mod wop_atomic_pattern; diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index f765ce61a..3b74b3d49 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -9,15 +9,12 @@ use crate::noise_estimator::error::{ use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::noise_estimator::operators::wop_atomic_pattern::estimate_packing_private_keyswitch; use crate::optimization::atomic_pattern; -use crate::optimization::atomic_pattern::Caches; -use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts; +use crate::optimization::atomic_pattern::{Caches, OptimizationDecompositionsConsts}; use crate::optimization::config::{Config, SearchSpace}; -use crate::optimization::decomposition::blind_rotate; use crate::optimization::decomposition::blind_rotate::BrComplexityNoise; -use crate::optimization::decomposition::cut_complexity_noise; -use crate::optimization::decomposition::keyswitch; use crate::optimization::decomposition::keyswitch::KsComplexityNoise; +use crate::optimization::decomposition::{cut_complexity_noise, PersistDecompCache}; use crate::optimization::wop_atomic_pattern::pareto::BR_CIRCUIT_BOOTSTRAP_PARETO_DECOMP; use crate::parameters::{ GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension, PbsParameters, @@ -457,12 +454,12 @@ fn optimize_raw( search_space: &SearchSpace, n_functions: u64, // Many functions at the same time, stay at 1 for start partitionning: &[u64], + cache: &PersistDecompCache, ) -> OptimizationState { assert!(0.0 < config.maximum_acceptable_error_probability); assert!(config.maximum_acceptable_error_probability < 1.0); let ciphertext_modulus_log = config.ciphertext_modulus_log; - let security_level = config.security_level; // Circuit BS bound // 1 bit of message only here =) @@ -485,10 +482,7 @@ fn optimize_raw( safe_variance: safe_variance_bound, }; - let mut caches = Caches { - blind_rotate: blind_rotate::for_security(security_level).cache(), - keyswitch: keyswitch::for_security(security_level).cache(), - }; + let mut caches = Caches::new(cache); for &glwe_dim in &search_space.glwe_dimensions { for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes { @@ -516,9 +510,7 @@ fn optimize_raw( } } } - - blind_rotate::for_security(security_level).backport(caches.blind_rotate); - keyswitch::for_security(security_level).backport(caches.keyswitch); + caches.backport_to(cache); state } @@ -528,11 +520,19 @@ pub fn optimize_one( config: Config, log_norm: f64, search_space: &SearchSpace, + cache: &PersistDecompCache, ) -> OptimizationState { let coprimes = crt_decomposition::default_coprimes(precision as Precision); let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes); let n_functions = 1; - let mut state = optimize_raw(log_norm, config, search_space, n_functions, &partitionning); + let mut state = optimize_raw( + log_norm, + config, + search_space, + n_functions, + &partitionning, + cache, + ); state.best_solution = state.best_solution.map(|mut sol| -> Solution { sol.crt_decomposition = coprimes; sol @@ -546,9 +546,10 @@ pub fn optimize_one_compat( config: Config, noise_factor: f64, search_space: &SearchSpace, + cache: &PersistDecompCache, ) -> atomic_pattern::OptimizationState { let log_norm = noise_factor.log2(); - let result = optimize_one(precision, config, log_norm, search_space); + let result = optimize_one(precision, config, log_norm, search_space, cache); atomic_pattern::OptimizationState { best_solution: result.best_solution.map(Solution::into), } diff --git a/v0-parameters/src/lib.rs b/v0-parameters/src/lib.rs index 12fa31d8e..b7df36553 100644 --- a/v0-parameters/src/lib.rs +++ b/v0-parameters/src/lib.rs @@ -16,6 +16,7 @@ use concrete_optimizer::optimization::atomic_pattern::{ }; use concrete_optimizer::optimization::config::{Config, SearchSpace}; use concrete_optimizer::optimization::dag::solo_key::optimize::{self as optimize_dag}; +use concrete_optimizer::optimization::decomposition; use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern; use rayon_cond::CondIterator; use std::io::Write; @@ -107,6 +108,8 @@ pub fn all_results(args: &Args) -> Vec> { complexity_model: &CpuComplexity::default(), }; + let cache = decomposition::cache(config.security_level); + precisions_iter .map(|precision| { let mut last_solution = None; @@ -121,6 +124,7 @@ pub fn all_results(args: &Args) -> Vec> { config, noise_scale, &search_space, + &cache, ) } else if args.simulate_dag { optimize_dag::optimize_v0( @@ -129,6 +133,7 @@ pub fn all_results(args: &Args) -> Vec> { config, noise_scale, &search_space, + &cache, ) } else { optimize_atomic_pattern::optimize_one( @@ -137,6 +142,7 @@ pub fn all_results(args: &Args) -> Vec> { config, noise_scale, &search_space, + &cache, ) }; last_solution = result.best_solution;