chore(optimizer): make cache not static

This commit is contained in:
Mayeul@Zama
2022-09-19 18:13:03 +02:00
committed by mayeul-zama
parent 5a2ddccc6f
commit 48962811b9
13 changed files with 218 additions and 153 deletions

View File

@@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity;
use concrete_optimizer::global_parameters::DEFAUT_DOMAINS;
use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern};
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::decomposition;
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
@@ -40,6 +41,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
complexity_model: &CpuComplexity::default(),
};
let cache = decomposition::cache(security_level);
let solutions: Vec<_> = log_norm2s
.clone()
.filter_map(|log_norm2| {
@@ -51,6 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
config,
noise_scale,
&search_space,
&cache,
)
.best_solution
.map(|a| (log_norm2, a.complexity))
@@ -61,9 +65,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
.filter_map(|log_norm2| {
let noise_scale = 2_f64.powi(log_norm2 as i32);
optimize_wop_atomic_pattern::optimize_one(precision, config, noise_scale, &search_space)
.best_solution
.map(|a| (log_norm2, a.complexity))
optimize_wop_atomic_pattern::optimize_one(
precision,
config,
noise_scale,
&search_space,
&cache,
)
.best_solution
.map(|a| (log_norm2, a.complexity))
})
.collect();

View File

@@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity;
use concrete_optimizer::global_parameters::DEFAUT_DOMAINS;
use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern};
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::decomposition;
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
@@ -40,6 +41,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
complexity_model: &CpuComplexity::default(),
};
let cache = decomposition::cache(security_level);
let solutions: Vec<_> = precisions
.clone()
.filter_map(|precision| {
@@ -51,6 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
config,
noise_factor,
&search_space,
&cache,
)
.best_solution
.map(|a| (precision, a.complexity))
@@ -64,6 +68,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
config,
log_norm2 as f64,
&search_space,
&cache,
)
.best_solution
.map(|a| (precision, a.complexity))

View File

@@ -5,6 +5,7 @@ use concrete_optimizer::dag::operator::{
use concrete_optimizer::dag::unparametrized;
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution;
use concrete_optimizer::optimization::decomposition;
fn no_solution() -> ffi::Solution {
ffi::Solution {
@@ -43,6 +44,7 @@ fn optimize_bootstrap(
config,
noise_factor,
&search_space,
&decomposition::cache(security_level),
);
result
.best_solution
@@ -219,6 +221,7 @@ impl OperationDag {
&self.0,
config,
&search_space,
&decomposition::cache(security_level),
);
result
.best_solution
@@ -239,12 +242,14 @@ impl OperationDag {
};
let search_space = SearchSpace::default();
let cache = decomposition::cache(security_level);
let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize(
&self.0,
config,
&search_space,
default_log_norm2_woppbs,
&cache,
);
result.map_or_else(no_dag_solution, |solution| solution.into())
}

View File

@@ -15,7 +15,6 @@ static_init = "1.0.3"
serde = { version = "1.0", features = ["derive"] }
rmp-serde = "1.1.0"
statrs = "0.15.0"
lazy_static = "1.4.0"
[dev-dependencies]
approx = "0.5"

View File

@@ -5,7 +5,8 @@ use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositi
use crate::utils::square;
use concrete_commons::dispersion::{DispersionParameter, Variance};
use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch};
use super::decomposition;
use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch, PersistDecompCache};
// Ref time for v0 table 1 thread: 950ms
const CUTS: bool = true; // 80ms
@@ -46,6 +47,19 @@ pub struct Caches {
pub keyswitch: keyswitch::Cache,
}
impl Caches {
pub fn new(cache: &decomposition::PersistDecompCache) -> Self {
Self {
blind_rotate: cache.br.cache(),
keyswitch: cache.ks.cache(),
}
}
pub fn backport_to(self, cache: &decomposition::PersistDecompCache) {
cache.ks.backport(self.keyswitch);
cache.br.backport(self.blind_rotate);
}
}
#[allow(clippy::too_many_lines)]
fn update_state_with_best_decompositions(
state: &mut OptimizationState,
@@ -182,6 +196,7 @@ pub fn optimize_one(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
) -> OptimizationState {
assert!(0 < precision && precision <= 16);
assert!(1.0 <= noise_factor);
@@ -193,7 +208,6 @@ pub fn optimize_one(
// the blind rotate decomposition
let ciphertext_modulus_log = config.ciphertext_modulus_log;
let security_level = config.security_level;
let safe_variance = error::safe_variance_bound_2padbits(
precision,
ciphertext_modulus_log,
@@ -228,10 +242,7 @@ pub fn optimize_one(
> consts.safe_variance
};
let mut caches = Caches {
blind_rotate: blind_rotate::for_security(security_level).cache(),
keyswitch: keyswitch::for_security(security_level).cache(),
};
let mut caches = Caches::new(cache);
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
@@ -260,8 +271,7 @@ pub fn optimize_one(
}
}
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
keyswitch::for_security(security_level).backport(caches.keyswitch);
caches.backport_to(cache);
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);

View File

@@ -7,7 +7,7 @@ use crate::optimization::atomic_pattern::{
Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
};
use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
use crate::optimization::decomposition::{blind_rotate, keyswitch};
use crate::optimization::decomposition::PersistDecompCache;
use crate::parameters::GlweParameters;
use crate::security;
@@ -231,6 +231,7 @@ pub fn optimize(
dag: &unparametrized::OperationDag,
config: Config,
search_space: &SearchSpace,
cache: &PersistDecompCache,
) -> OptimizationState {
let ciphertext_modulus_log = config.ciphertext_modulus_log;
let security_level = config.security_level;
@@ -266,11 +267,7 @@ pub fn optimize(
if dag.nb_luts == 0 {
return optimize_no_luts(state, &consts, &dag, search_space);
}
let mut caches = Caches {
blind_rotate: blind_rotate::for_security(security_level).cache(),
keyswitch: keyswitch::for_security(security_level).cache(),
};
let mut caches = Caches::new(cache);
let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key(
@@ -316,8 +313,7 @@ pub fn optimize(
}
}
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
keyswitch::for_security(security_level).backport(caches.keyswitch);
caches.backport_to(cache);
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
@@ -335,6 +331,7 @@ pub fn optimize_v0(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
) -> OptimizationState {
use crate::dag::operator::{FunctionTable, Shape};
let same_scale_manp = 0.0;
@@ -349,7 +346,7 @@ pub fn optimize_v0(
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
let mut state = optimize(&dag, config, search_space);
let mut state = optimize(&dag, config, search_space, cache);
if let Some(sol) = &mut state.best_solution {
sol.complexity /= 2.0;
}
@@ -364,9 +361,9 @@ mod tests {
use crate::computing_cost::cpu::CpuComplexity;
use crate::dag::operator::{FunctionTable, Shape, Weights};
use crate::noise_estimator::p_error::repeat_p_error;
use crate::optimization::atomic_pattern;
use crate::optimization::config::SearchSpace;
use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
use crate::optimization::{atomic_pattern, decomposition};
use crate::utils::square;
fn small_relative_diff(v1: f64, v2: f64) -> bool {
@@ -390,7 +387,10 @@ mod tests {
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
fn optimize(
dag: &unparametrized::OperationDag,
cache: &PersistDecompCache,
) -> OptimizationState {
let config = Config {
security_level: 128,
maximum_acceptable_error_probability: _4_SIGMA,
@@ -400,7 +400,7 @@ mod tests {
let search_space = SearchSpace::default();
super::optimize(dag, config, &search_space)
super::optimize(dag, config, &search_space, cache)
}
struct Times {
@@ -439,16 +439,38 @@ mod tests {
complexity_model: &CpuComplexity::default(),
};
let _ = optimize_v0(sum_size, precision, config, weight as f64, &search_space);
let cache = decomposition::cache(config.security_level);
let _ = optimize_v0(
sum_size,
precision,
config,
weight as f64,
&search_space,
&cache,
);
// ensure cache is filled
let chrono = Instant::now();
let state = optimize_v0(sum_size, precision, config, weight as f64, &search_space);
let state = optimize_v0(
sum_size,
precision,
config,
weight as f64,
&search_space,
&cache,
);
times.dag_time += chrono.elapsed().as_nanos();
let chrono = Instant::now();
let state_ref =
atomic_pattern::optimize_one(sum_size, precision, config, weight as f64, &search_space);
let state_ref = atomic_pattern::optimize_one(
sum_size,
precision,
config,
weight as f64,
&search_space,
&cache,
);
times.worst_time += chrono.elapsed().as_nanos();
assert_eq!(
state.best_solution.is_some(),
@@ -476,6 +498,10 @@ mod tests {
}
fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) {
let security_level = 128;
let cache = decomposition::cache(security_level);
let mut dag = unparametrized::OperationDag::new();
{
let input1 = dag.add_input(precision, Shape::number());
@@ -488,7 +514,7 @@ mod tests {
let dag2 = analyze::analyze(
&dag,
&NoiseBoundConfig {
security_level: 128,
security_level,
maximum_acceptable_error_probability: _4_SIGMA,
ciphertext_modulus_log: 64,
},
@@ -506,15 +532,21 @@ mod tests {
let search_space = SearchSpace::default();
let config = Config {
security_level: 128,
security_level,
maximum_acceptable_error_probability: _4_SIGMA,
ciphertext_modulus_log: 64,
complexity_model: &CpuComplexity::default(),
};
let state = optimize(&dag);
let state_ref =
atomic_pattern::optimize_one(1, precision as u64, config, weight as f64, &search_space);
let state = optimize(&dag, &cache);
let state_ref = atomic_pattern::optimize_one(
1,
precision as u64,
config,
weight as f64,
&search_space,
&cache,
);
assert_eq!(
state.best_solution.is_some(),
state_ref.best_solution.is_some()
@@ -531,7 +563,7 @@ mod tests {
assert!(sol.global_p_error <= 1.0);
}
fn no_lut_vs_lut(precision: Precision) {
fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) {
let mut dag_lut = unparametrized::OperationDag::new();
let input1 = dag_lut.add_input(precision as u8, Shape::number());
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision);
@@ -539,8 +571,8 @@ mod tests {
let mut dag_no_lut = unparametrized::OperationDag::new();
let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());
let state_no_lut = optimize(&dag_no_lut);
let state_lut = optimize(&dag_lut);
let state_no_lut = optimize(&dag_no_lut, cache);
let state_lut = optimize(&dag_lut, cache);
assert_eq!(
state_no_lut.best_solution.is_some(),
state_lut.best_solution.is_some()
@@ -556,14 +588,16 @@ mod tests {
}
#[test]
fn test_lut_vs_no_lut() {
let cache = decomposition::cache(128);
for precision in 1..=8 {
no_lut_vs_lut(precision);
no_lut_vs_lut(precision, &cache);
}
}
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision: Precision,
weight: u64,
cache: &PersistDecompCache,
) {
let weight = &Weights::number(weight);
@@ -583,8 +617,8 @@ mod tests {
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision);
}
let state_1 = optimize(&dag_1);
let state_2 = optimize(&dag_2);
let state_1 = optimize(&dag_1, cache);
let state_2 = optimize(&dag_2, cache);
if state_1.best_solution.is_none() {
assert!(state_2.best_solution.is_none());
@@ -597,15 +631,18 @@ mod tests {
#[test]
fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
let cache = decomposition::cache(128);
for log_weight in 1..=16 {
let weight = 1 << log_weight;
for precision in 5..=9 {
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
precision, weight, &cache,
);
}
}
}
fn lut_1_layer_has_better_complexity(precision: Precision) {
fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) {
let dag_1_layer = {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, Shape::number());
@@ -621,17 +658,18 @@ mod tests {
dag
};
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap();
let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap();
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
}
#[test]
fn test_lut_1_layer_is_better() {
let cache = decomposition::cache(128);
// for some reason on 4, 5, 6, the complexity is already minimal
// this could be due to pre-defined pareto set
for precision in [1, 2, 3, 7, 8] {
lut_1_layer_has_better_complexity(precision);
lut_1_layer_has_better_complexity(precision, &cache);
}
}
@@ -643,7 +681,10 @@ mod tests {
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}
fn assert_multi_precision_dominate_single(weight: u64) -> Option<bool> {
fn assert_multi_precision_dominate_single(
weight: u64,
cache: &PersistDecompCache,
) -> Option<bool> {
let low_precision = 4u8;
let high_precision = 5u8;
let mut dag_low = unparametrized::OperationDag::new();
@@ -656,12 +697,12 @@ mod tests {
circuit(&mut dag_multi, low_precision, weight);
circuit(&mut dag_multi, high_precision, 1);
}
let state_multi = optimize(&dag_multi);
let state_multi = optimize(&dag_multi, cache);
let mut sol_multi = state_multi.best_solution?;
let state_low = optimize(&dag_low);
let state_high = optimize(&dag_high);
let state_low = optimize(&dag_low, cache);
let state_high = optimize(&dag_high, cache);
let sol_low = state_low.best_solution.unwrap();
let sol_high = state_high.best_solution.unwrap();
@@ -680,10 +721,11 @@ mod tests {
#[test]
fn test_multi_precision_dominate_single() {
let cache = decomposition::cache(128);
let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false
for log2_weight in 0..29 {
let weight = 1 << log2_weight;
let current = assert_multi_precision_dominate_single(weight);
let current = assert_multi_precision_dominate_single(weight, &cache);
#[allow(clippy::match_like_matches_macro)] // less readable
let authorized = match (prev, current) {
(Some(false), Some(true)) => false,
@@ -713,22 +755,28 @@ mod tests {
#[test]
fn test_global_p_error_input() {
let cache = decomposition::cache(128);
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for dim in [1, 2, 16, 32] {
let _ = check_global_p_error_input(dim, weight, precision);
let _ = check_global_p_error_input(dim, weight, precision, &cache);
}
}
}
}
fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 {
fn check_global_p_error_input(
dim: u64,
weight: u64,
precision: u8,
cache: &PersistDecompCache,
) -> f64 {
let shape = Shape::vector(dim);
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, shape);
let _dot1 = dag.add_dot([input1], weights); // this is just several multiply
let state = optimize(&dag);
let state = optimize(&dag, cache);
let sol = state.best_solution.unwrap();
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
@@ -737,16 +785,22 @@ mod tests {
#[test]
fn test_global_p_error_lut() {
let cache = decomposition::cache(128);
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for depth in [2, 16, 32] {
check_global_p_error_lut(depth, weight, precision);
check_global_p_error_lut(depth, weight, precision, &cache);
}
}
}
}
fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) {
fn check_global_p_error_lut(
depth: u64,
weight: u64,
precision: u8,
cache: &PersistDecompCache,
) {
let shape = Shape::number();
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
@@ -755,7 +809,7 @@ mod tests {
let dot = dag.add_dot([last_val], &weights);
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
}
let state = optimize(&dag);
let state = optimize(&dag, cache);
let sol = state.best_solution.unwrap();
// the first lut on input has reduced impact on error probability
let lower_nb_dominating_lut = depth - 1;
@@ -792,6 +846,7 @@ mod tests {
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_dominating_lut() {
let cache = decomposition::cache(128);
let depth = 128;
let weights_low = 1;
let weights_high = 1;
@@ -804,7 +859,7 @@ mod tests {
weights_low,
weights_high,
);
let sol = optimize(&dag).best_solution.unwrap();
let sol = optimize(&dag, &cache).best_solution.unwrap();
// the 2 first luts and low precision/weight luts have little impact on error probability
let nb_dominating_lut = depth - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
@@ -819,6 +874,7 @@ mod tests {
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
#[test]
fn test_global_p_error_non_dominating_lut() {
let cache = decomposition::cache(128);
let depth = 128;
let weights_low = 1024 * 1024 * 3;
let weights_high = 1;
@@ -831,7 +887,7 @@ mod tests {
weights_low,
weights_high,
);
let sol = optimize(&dag).best_solution.unwrap();
let sol = optimize(&dag, &cache).best_solution.unwrap();
// all intern luts have an impact on error probability almost equaly
let nb_dominating_lut = (2 * depth) - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);

View File

@@ -4,6 +4,7 @@ use crate::noise_estimator::p_error::repeat_p_error;
use crate::optimization::atomic_pattern::Solution as WpSolution;
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::dag::solo_key::{analyze, optimize};
use crate::optimization::decomposition::PersistDecompCache;
use crate::optimization::wop_atomic_pattern::optimize::optimize_one as wop_optimize;
use crate::optimization::wop_atomic_pattern::Solution as WopSolution;
use std::ops::RangeInclusive;
@@ -48,6 +49,7 @@ pub fn optimize(
config: Config,
search_space: &SearchSpace,
default_log_norm2_woppbs: f64,
cache: &PersistDecompCache,
) -> Option<Solution> {
let max_precision = max_precision(dag);
let nb_luts = analyze::lut_count_from_dag(dag);
@@ -58,11 +60,17 @@ pub fn optimize(
let default_log_norm = default_log_norm2_woppbs;
let worst_log_norm = analyze::worst_log_norm(dag);
let log_norm = default_log_norm.min(worst_log_norm);
let opt_sol =
wop_optimize(fallback_16b_precision, config, log_norm, search_space).best_solution;
let opt_sol = wop_optimize(
fallback_16b_precision,
config,
log_norm,
search_space,
cache,
)
.best_solution;
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
} else {
let opt_sol = optimize::optimize(dag, config, search_space).best_solution;
let opt_sol = optimize::optimize(dag, config, search_space, cache).best_solution;
opt_sol.map(Solution::WpSolution)
}
}

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use concrete_commons::dispersion::DispersionParameter;
@@ -8,8 +6,6 @@ use crate::computing_cost::operators::pbs::PbsComplexity;
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::parameters::{BrDecompositionParameters, GlweParameters, LweDimension, PbsParameters};
use crate::security;
use crate::security::security_weights::SECURITY_WEIGHTS_TABLE;
use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache};
use crate::utils::cache::persistent::PersistentCacheHashMap;
@@ -120,42 +116,22 @@ impl Cache {
}
}
type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<BrComplexityNoise>>;
type MultiSecPersistDecompCache = HashMap<u64, PersistDecompCache>; // just to attach a finaly
pub type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<BrComplexityNoise>>;
#[static_init::dynamic]
pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE
.keys()
.map(|&security_level| {
let ciphertext_modulus_log = 64;
let tmp: String = std::env::temp_dir()
.to_str()
.expect("Invalid tmp dir")
.into();
let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}");
let function = move |(glwe_params, internal_dim): MacroParam| {
pareto_quantities(
ciphertext_modulus_log,
security_level,
internal_dim,
glwe_params,
)
};
(
pub fn cache(security_level: u64) -> PersistDecompCache {
let ciphertext_modulus_log = 64;
let tmp: String = std::env::temp_dir()
.to_str()
.expect("Invalid tmp dir")
.into();
let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}");
let function = move |(glwe_params, internal_dim): MacroParam| {
pareto_quantities(
ciphertext_modulus_log,
security_level,
PersistentCacheHashMap::new(&path, "v0", function),
internal_dim,
glwe_params,
)
})
.collect::<MultiSecPersistDecompCache>();
#[cfg(not(target_os = "macos"))]
#[static_init::destructor(10)]
extern "C" fn finaly() {
for v in SHARED_CACHE.values() {
v.sync_to_disk();
}
}
pub fn for_security(security_level: u64) -> &'static PersistDecompCache {
SHARED_CACHE.get(&security_level).unwrap()
};
PersistentCacheHashMap::new(&path, "v0", function)
}

View File

@@ -1,5 +1,3 @@
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use concrete_commons::dispersion::DispersionParameter;
@@ -9,7 +7,6 @@ use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::parameters::{
GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension,
};
use crate::security::security_weights::SECURITY_WEIGHTS_TABLE;
use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache};
use crate::utils::cache::persistent::PersistentCacheHashMap;
@@ -121,42 +118,22 @@ impl Cache {
}
}
type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<KsComplexityNoise>>;
type MultiSecPersistDecompCache = HashMap<u64, PersistDecompCache>;
pub type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<KsComplexityNoise>>;
#[static_init::dynamic]
pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE
.keys()
.map(|&security_level| {
let ciphertext_modulus_log = 64;
let tmp: String = std::env::temp_dir()
.to_str()
.expect("Invalid tmp dir")
.into();
let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}");
let function = move |(glwe_params, internal_dim): MacroParam| {
pareto_quantities(
ciphertext_modulus_log,
security_level,
internal_dim,
glwe_params,
)
};
(
pub fn cache(security_level: u64) -> PersistDecompCache {
let ciphertext_modulus_log = 64;
let tmp: String = std::env::temp_dir()
.to_str()
.expect("Invalid tmp dir")
.into();
let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}");
let function = move |(glwe_params, internal_dim): MacroParam| {
pareto_quantities(
ciphertext_modulus_log,
security_level,
PersistentCacheHashMap::new(&path, "v0", function),
internal_dim,
glwe_params,
)
})
.collect::<MultiSecPersistDecompCache>();
#[cfg(not(target_os = "macos"))]
#[static_init::destructor(10)]
extern "C" fn finaly() {
for v in SHARED_CACHE.values() {
v.sync_to_disk();
}
}
pub fn for_security(security_level: u64) -> &'static PersistDecompCache {
SHARED_CACHE.get(&security_level).unwrap()
};
PersistentCacheHashMap::new(&path, "v0", function)
}

View File

@@ -5,3 +5,15 @@ pub mod keyswitch;
pub use common::MacroParam;
pub use cut::cut_complexity_noise;
pub struct PersistDecompCache {
pub ks: keyswitch::PersistDecompCache,
pub br: blind_rotate::PersistDecompCache,
}
pub fn cache(security_level: u64) -> PersistDecompCache {
PersistDecompCache {
ks: keyswitch::cache(security_level),
br: blind_rotate::cache(security_level),
}
}

View File

@@ -1,5 +1,5 @@
pub mod atomic_pattern;
pub mod config;
pub mod dag;
pub(crate) mod decomposition;
pub mod decomposition;
pub mod wop_atomic_pattern;

View File

@@ -9,15 +9,12 @@ use crate::noise_estimator::error::{
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::noise_estimator::operators::wop_atomic_pattern::estimate_packing_private_keyswitch;
use crate::optimization::atomic_pattern;
use crate::optimization::atomic_pattern::Caches;
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
use crate::optimization::atomic_pattern::{Caches, OptimizationDecompositionsConsts};
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::decomposition::blind_rotate;
use crate::optimization::decomposition::blind_rotate::BrComplexityNoise;
use crate::optimization::decomposition::cut_complexity_noise;
use crate::optimization::decomposition::keyswitch;
use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::optimization::decomposition::{cut_complexity_noise, PersistDecompCache};
use crate::optimization::wop_atomic_pattern::pareto::BR_CIRCUIT_BOOTSTRAP_PARETO_DECOMP;
use crate::parameters::{
GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension, PbsParameters,
@@ -457,12 +454,12 @@ fn optimize_raw(
search_space: &SearchSpace,
n_functions: u64, // Many functions at the same time, stay at 1 for start
partitionning: &[u64],
cache: &PersistDecompCache,
) -> OptimizationState {
assert!(0.0 < config.maximum_acceptable_error_probability);
assert!(config.maximum_acceptable_error_probability < 1.0);
let ciphertext_modulus_log = config.ciphertext_modulus_log;
let security_level = config.security_level;
// Circuit BS bound
// 1 bit of message only here =)
@@ -485,10 +482,7 @@ fn optimize_raw(
safe_variance: safe_variance_bound,
};
let mut caches = Caches {
blind_rotate: blind_rotate::for_security(security_level).cache(),
keyswitch: keyswitch::for_security(security_level).cache(),
};
let mut caches = Caches::new(cache);
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
@@ -516,9 +510,7 @@ fn optimize_raw(
}
}
}
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
keyswitch::for_security(security_level).backport(caches.keyswitch);
caches.backport_to(cache);
state
}
@@ -528,11 +520,19 @@ pub fn optimize_one(
config: Config,
log_norm: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
) -> OptimizationState {
let coprimes = crt_decomposition::default_coprimes(precision as Precision);
let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes);
let n_functions = 1;
let mut state = optimize_raw(log_norm, config, search_space, n_functions, &partitionning);
let mut state = optimize_raw(
log_norm,
config,
search_space,
n_functions,
&partitionning,
cache,
);
state.best_solution = state.best_solution.map(|mut sol| -> Solution {
sol.crt_decomposition = coprimes;
sol
@@ -546,9 +546,10 @@ pub fn optimize_one_compat(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
) -> atomic_pattern::OptimizationState {
let log_norm = noise_factor.log2();
let result = optimize_one(precision, config, log_norm, search_space);
let result = optimize_one(precision, config, log_norm, search_space, cache);
atomic_pattern::OptimizationState {
best_solution: result.best_solution.map(Solution::into),
}

View File

@@ -16,6 +16,7 @@ use concrete_optimizer::optimization::atomic_pattern::{
};
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::solo_key::optimize::{self as optimize_dag};
use concrete_optimizer::optimization::decomposition;
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
use rayon_cond::CondIterator;
use std::io::Write;
@@ -107,6 +108,8 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
complexity_model: &CpuComplexity::default(),
};
let cache = decomposition::cache(config.security_level);
precisions_iter
.map(|precision| {
let mut last_solution = None;
@@ -121,6 +124,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
config,
noise_scale,
&search_space,
&cache,
)
} else if args.simulate_dag {
optimize_dag::optimize_v0(
@@ -129,6 +133,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
config,
noise_scale,
&search_space,
&cache,
)
} else {
optimize_atomic_pattern::optimize_one(
@@ -137,6 +142,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
config,
noise_scale,
&search_space,
&cache,
)
};
last_solution = result.best_solution;