chore(caches): regroup cache containers

This commit is contained in:
rudy
2022-11-08 09:56:40 +01:00
committed by rudy-6-4
parent 2e5e8a6cc3
commit b5f7715e5e
5 changed files with 68 additions and 52 deletions

View File

@@ -5,9 +5,8 @@ use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositi
use crate::utils::square;
use concrete_commons::dispersion::{DispersionParameter, Variance};
use super::decomposition;
use super::decomposition::{
blind_rotate, circuit_bootstrap, keyswitch, pp_switch, PersistDecompCache,
blind_rotate, circuit_bootstrap, keyswitch, pp_switch, DecompCaches, PersistDecompCaches,
};
// Ref time for v0 table 1 thread: 950ms
@@ -49,30 +48,13 @@ pub struct Caches {
pub cb_pbs: circuit_bootstrap::Cache,
}
impl Caches {
pub fn new(cache: &decomposition::PersistDecompCache) -> Self {
Self {
blind_rotate: cache.br.cache(),
keyswitch: cache.ks.cache(),
pp_switch: cache.pp.cache(),
cb_pbs: cache.cb.cache(),
}
}
pub fn backport_to(self, cache: &decomposition::PersistDecompCache) {
cache.ks.backport(self.keyswitch);
cache.br.backport(self.blind_rotate);
cache.pp.backport(self.pp_switch);
cache.cb.backport(self.cb_pbs);
}
}
#[allow(clippy::too_many_lines)]
fn update_state_with_best_decompositions(
state: &mut OptimizationState,
consts: &OptimizationDecompositionsConsts,
internal_dim: u64,
glwe_params: GlweParameters,
caches: &mut Caches,
caches: &mut DecompCaches,
) {
let glwe_poly_size = glwe_params.polynomial_size();
let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
@@ -171,7 +153,7 @@ pub fn optimize_one(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
persistent_caches: &PersistDecompCaches,
) -> OptimizationState {
assert!(0 < precision && precision <= 16);
assert!(1.0 <= noise_factor);
@@ -217,7 +199,7 @@ pub fn optimize_one(
> consts.safe_variance
};
let mut caches = Caches::new(cache);
let mut caches = persistent_caches.caches();
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
@@ -246,7 +228,7 @@ pub fn optimize_one(
}
}
caches.backport_to(cache);
persistent_caches.backport(caches);
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);

View File

@@ -4,10 +4,10 @@ use crate::dag::unparametrized;
use crate::noise_estimator::error;
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::optimization::atomic_pattern::{
Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
OptimizationDecompositionsConsts, OptimizationState, Solution,
};
use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
use crate::optimization::decomposition::PersistDecompCache;
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
use crate::parameters::GlweParameters;
use crate::security;
@@ -22,7 +22,7 @@ fn update_best_solution_with_best_decompositions(
glwe_params: GlweParameters,
input_noise_out: f64,
noise_modulus_switching: f64,
caches: &mut Caches,
caches: &mut DecompCaches,
) {
assert!(dag.nb_luts > 0);
let glwe_poly_size = glwe_params.polynomial_size();
@@ -231,7 +231,7 @@ pub fn optimize(
dag: &unparametrized::OperationDag,
config: Config,
search_space: &SearchSpace,
cache: &PersistDecompCache,
persistent_caches: &PersistDecompCaches,
) -> OptimizationState {
let ciphertext_modulus_log = config.ciphertext_modulus_log;
let security_level = config.security_level;
@@ -267,7 +267,7 @@ pub fn optimize(
if dag.nb_luts == 0 {
return optimize_no_luts(state, &consts, &dag, search_space);
}
let mut caches = Caches::new(cache);
let mut caches = persistent_caches.caches();
let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key(
@@ -313,7 +313,7 @@ pub fn optimize(
}
}
caches.backport_to(cache);
persistent_caches.backport(caches);
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
@@ -331,7 +331,7 @@ pub fn optimize_v0(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
cache: &PersistDecompCaches,
) -> OptimizationState {
use crate::dag::operator::{FunctionTable, Shape};
let same_scale_manp = 0.0;
@@ -390,7 +390,7 @@ mod tests {
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
static SHARED_CACHES: Lazy<PersistDecompCache> = Lazy::new(|| {
static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
let processing_unit = config::ProcessingUnit::Cpu;
decomposition::cache(128, processing_unit, None)
});

View File

@@ -4,7 +4,7 @@ use crate::noise_estimator::p_error::repeat_p_error;
use crate::optimization::atomic_pattern::Solution as WpSolution;
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::dag::solo_key::{analyze, optimize};
use crate::optimization::decomposition::PersistDecompCache;
use crate::optimization::decomposition::PersistDecompCaches;
use crate::optimization::wop_atomic_pattern::optimize::optimize_one as wop_optimize;
use crate::optimization::wop_atomic_pattern::Solution as WopSolution;
use std::ops::RangeInclusive;
@@ -36,7 +36,7 @@ pub fn optimize(
config: Config,
search_space: &SearchSpace,
default_log_norm2_woppbs: f64,
cache: &PersistDecompCache,
caches: &PersistDecompCaches,
) -> Option<Solution> {
let max_precision = max_precision(dag);
let nb_luts = analyze::lut_count_from_dag(dag);
@@ -52,12 +52,12 @@ pub fn optimize(
config,
log_norm,
search_space,
cache,
caches,
)
.best_solution;
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
} else {
let opt_sol = optimize::optimize(dag, config, search_space, cache).best_solution;
let opt_sol = optimize::optimize(dag, config, search_space, caches).best_solution;
opt_sol.map(Solution::WpSolution)
}
}

View File

@@ -11,23 +11,57 @@ use crate::config;
use std::sync::Arc;
pub struct PersistDecompCache {
pub struct PersistDecompCaches {
pub ks: keyswitch::PersistDecompCache,
pub br: blind_rotate::PersistDecompCache,
pub pp: pp_switch::PersistDecompCache,
pub cb: circuit_bootstrap::PersistDecompCache,
}
pub struct DecompCaches {
pub blind_rotate: blind_rotate::Cache,
pub keyswitch: keyswitch::Cache,
pub pp_switch: pp_switch::Cache,
pub cb_pbs: circuit_bootstrap::Cache,
}
pub fn cache(
security_level: u64,
processing_unit: config::ProcessingUnit,
complexity_model: Option<Arc<dyn ComplexityModel>>,
) -> PersistDecompCache {
let complexity_model = complexity_model.unwrap_or_else(|| processing_unit.complexity_model());
PersistDecompCache {
ks: keyswitch::cache(security_level, processing_unit, complexity_model.clone()),
br: blind_rotate::cache(security_level, processing_unit, complexity_model.clone()),
pp: pp_switch::cache(security_level, processing_unit, complexity_model.clone()),
cb: circuit_bootstrap::cache(security_level, processing_unit, complexity_model.clone()),
) -> PersistDecompCaches {
PersistDecompCaches::new(security_level, processing_unit, complexity_model)
}
impl PersistDecompCaches {
pub fn new(
security_level: u64,
processing_unit: config::ProcessingUnit,
complexity_model: Option<Arc<dyn ComplexityModel>>,
) -> Self {
let complexity_model =
complexity_model.unwrap_or_else(|| processing_unit.complexity_model());
Self {
ks: keyswitch::cache(security_level, processing_unit, complexity_model.clone()),
br: blind_rotate::cache(security_level, processing_unit, complexity_model.clone()),
pp: pp_switch::cache(security_level, processing_unit, complexity_model.clone()),
cb: circuit_bootstrap::cache(security_level, processing_unit, complexity_model.clone()),
}
}
pub fn backport(&self, cache: DecompCaches) {
self.ks.backport(cache.keyswitch);
self.br.backport(cache.blind_rotate);
self.pp.backport(cache.pp_switch);
self.cb.backport(cache.cb_pbs);
}
pub fn caches(&self) -> DecompCaches {
DecompCaches {
blind_rotate: self.br.cache(),
keyswitch: self.ks.cache(),
pp_switch: self.pp.cache(),
cb_pbs: self.cb.cache(),
}
}
}

View File

@@ -6,10 +6,10 @@ use crate::noise_estimator::error::{
};
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::optimization::atomic_pattern;
use crate::optimization::atomic_pattern::{Caches, OptimizationDecompositionsConsts};
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
use crate::optimization::config::{Config, SearchSpace};
use crate::optimization::decomposition::PersistDecompCache;
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
use crate::parameters::{GlweParameters, LweDimension, PbsParameters};
use crate::utils::square;
use concrete_commons::dispersion::{DispersionParameter, Variance};
@@ -102,7 +102,7 @@ fn update_state_with_best_decompositions(
internal_dim: u64,
n_functions: u64,
partitionning: &[u64],
caches: &mut Caches,
caches: &mut DecompCaches,
) {
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
let precisions_sum = partitionning.iter().copied().sum();
@@ -326,7 +326,7 @@ fn optimize_raw(
search_space: &SearchSpace,
n_functions: u64, // Many functions at the same time, stay at 1 for start
partitionning: &[u64],
cache: &PersistDecompCache,
persistent_caches: &PersistDecompCaches,
) -> OptimizationState {
assert!(0.0 < config.maximum_acceptable_error_probability);
assert!(config.maximum_acceptable_error_probability < 1.0);
@@ -354,7 +354,7 @@ fn optimize_raw(
safe_variance: safe_variance_bound,
};
let mut caches = Caches::new(cache);
let mut caches = persistent_caches.caches();
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
@@ -382,7 +382,7 @@ fn optimize_raw(
}
}
}
caches.backport_to(cache);
persistent_caches.backport(caches);
state
}
@@ -392,7 +392,7 @@ pub fn optimize_one(
config: Config,
log_norm: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
caches: &PersistDecompCaches,
) -> OptimizationState {
let coprimes = crt_decomposition::default_coprimes(precision as Precision);
let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes);
@@ -403,7 +403,7 @@ pub fn optimize_one(
search_space,
n_functions,
&partitionning,
cache,
caches,
);
state.best_solution = state.best_solution.map(|mut sol| -> Solution {
sol.crt_decomposition = coprimes;
@@ -418,7 +418,7 @@ pub fn optimize_one_compat(
config: Config,
noise_factor: f64,
search_space: &SearchSpace,
cache: &PersistDecompCache,
cache: &PersistDecompCaches,
) -> atomic_pattern::OptimizationState {
let log_norm = noise_factor.log2();
let result = optimize_one(precision, config, log_norm, search_space, cache);