mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
chore(caches): regroup cache containers
This commit is contained in:
@@ -5,9 +5,8 @@ use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositi
|
||||
use crate::utils::square;
|
||||
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
||||
|
||||
use super::decomposition;
|
||||
use super::decomposition::{
|
||||
blind_rotate, circuit_bootstrap, keyswitch, pp_switch, PersistDecompCache,
|
||||
blind_rotate, circuit_bootstrap, keyswitch, pp_switch, DecompCaches, PersistDecompCaches,
|
||||
};
|
||||
|
||||
// Ref time for v0 table 1 thread: 950ms
|
||||
@@ -49,30 +48,13 @@ pub struct Caches {
|
||||
pub cb_pbs: circuit_bootstrap::Cache,
|
||||
}
|
||||
|
||||
impl Caches {
|
||||
pub fn new(cache: &decomposition::PersistDecompCache) -> Self {
|
||||
Self {
|
||||
blind_rotate: cache.br.cache(),
|
||||
keyswitch: cache.ks.cache(),
|
||||
pp_switch: cache.pp.cache(),
|
||||
cb_pbs: cache.cb.cache(),
|
||||
}
|
||||
}
|
||||
pub fn backport_to(self, cache: &decomposition::PersistDecompCache) {
|
||||
cache.ks.backport(self.keyswitch);
|
||||
cache.br.backport(self.blind_rotate);
|
||||
cache.pp.backport(self.pp_switch);
|
||||
cache.cb.backport(self.cb_pbs);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn update_state_with_best_decompositions(
|
||||
state: &mut OptimizationState,
|
||||
consts: &OptimizationDecompositionsConsts,
|
||||
internal_dim: u64,
|
||||
glwe_params: GlweParameters,
|
||||
caches: &mut Caches,
|
||||
caches: &mut DecompCaches,
|
||||
) {
|
||||
let glwe_poly_size = glwe_params.polynomial_size();
|
||||
let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
|
||||
@@ -171,7 +153,7 @@ pub fn optimize_one(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
persistent_caches: &PersistDecompCaches,
|
||||
) -> OptimizationState {
|
||||
assert!(0 < precision && precision <= 16);
|
||||
assert!(1.0 <= noise_factor);
|
||||
@@ -217,7 +199,7 @@ pub fn optimize_one(
|
||||
> consts.safe_variance
|
||||
};
|
||||
|
||||
let mut caches = Caches::new(cache);
|
||||
let mut caches = persistent_caches.caches();
|
||||
|
||||
for &glwe_dim in &search_space.glwe_dimensions {
|
||||
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
|
||||
@@ -246,7 +228,7 @@ pub fn optimize_one(
|
||||
}
|
||||
}
|
||||
|
||||
caches.backport_to(cache);
|
||||
persistent_caches.backport(caches);
|
||||
|
||||
if let Some(sol) = state.best_solution {
|
||||
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
|
||||
|
||||
@@ -4,10 +4,10 @@ use crate::dag::unparametrized;
|
||||
use crate::noise_estimator::error;
|
||||
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::optimization::atomic_pattern::{
|
||||
Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
|
||||
OptimizationDecompositionsConsts, OptimizationState, Solution,
|
||||
};
|
||||
use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
|
||||
use crate::optimization::decomposition::PersistDecompCache;
|
||||
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
|
||||
use crate::parameters::GlweParameters;
|
||||
use crate::security;
|
||||
|
||||
@@ -22,7 +22,7 @@ fn update_best_solution_with_best_decompositions(
|
||||
glwe_params: GlweParameters,
|
||||
input_noise_out: f64,
|
||||
noise_modulus_switching: f64,
|
||||
caches: &mut Caches,
|
||||
caches: &mut DecompCaches,
|
||||
) {
|
||||
assert!(dag.nb_luts > 0);
|
||||
let glwe_poly_size = glwe_params.polynomial_size();
|
||||
@@ -231,7 +231,7 @@ pub fn optimize(
|
||||
dag: &unparametrized::OperationDag,
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
persistent_caches: &PersistDecompCaches,
|
||||
) -> OptimizationState {
|
||||
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
||||
let security_level = config.security_level;
|
||||
@@ -267,7 +267,7 @@ pub fn optimize(
|
||||
if dag.nb_luts == 0 {
|
||||
return optimize_no_luts(state, &consts, &dag, search_space);
|
||||
}
|
||||
let mut caches = Caches::new(cache);
|
||||
let mut caches = persistent_caches.caches();
|
||||
|
||||
let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
|
||||
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key(
|
||||
@@ -313,7 +313,7 @@ pub fn optimize(
|
||||
}
|
||||
}
|
||||
|
||||
caches.backport_to(cache);
|
||||
persistent_caches.backport(caches);
|
||||
|
||||
if let Some(sol) = state.best_solution {
|
||||
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
|
||||
@@ -331,7 +331,7 @@ pub fn optimize_v0(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
cache: &PersistDecompCaches,
|
||||
) -> OptimizationState {
|
||||
use crate::dag::operator::{FunctionTable, Shape};
|
||||
let same_scale_manp = 0.0;
|
||||
@@ -390,7 +390,7 @@ mod tests {
|
||||
|
||||
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
|
||||
|
||||
static SHARED_CACHES: Lazy<PersistDecompCache> = Lazy::new(|| {
|
||||
static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
|
||||
let processing_unit = config::ProcessingUnit::Cpu;
|
||||
decomposition::cache(128, processing_unit, None)
|
||||
});
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::noise_estimator::p_error::repeat_p_error;
|
||||
use crate::optimization::atomic_pattern::Solution as WpSolution;
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
use crate::optimization::dag::solo_key::{analyze, optimize};
|
||||
use crate::optimization::decomposition::PersistDecompCache;
|
||||
use crate::optimization::decomposition::PersistDecompCaches;
|
||||
use crate::optimization::wop_atomic_pattern::optimize::optimize_one as wop_optimize;
|
||||
use crate::optimization::wop_atomic_pattern::Solution as WopSolution;
|
||||
use std::ops::RangeInclusive;
|
||||
@@ -36,7 +36,7 @@ pub fn optimize(
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
default_log_norm2_woppbs: f64,
|
||||
cache: &PersistDecompCache,
|
||||
caches: &PersistDecompCaches,
|
||||
) -> Option<Solution> {
|
||||
let max_precision = max_precision(dag);
|
||||
let nb_luts = analyze::lut_count_from_dag(dag);
|
||||
@@ -52,12 +52,12 @@ pub fn optimize(
|
||||
config,
|
||||
log_norm,
|
||||
search_space,
|
||||
cache,
|
||||
caches,
|
||||
)
|
||||
.best_solution;
|
||||
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
|
||||
} else {
|
||||
let opt_sol = optimize::optimize(dag, config, search_space, cache).best_solution;
|
||||
let opt_sol = optimize::optimize(dag, config, search_space, caches).best_solution;
|
||||
opt_sol.map(Solution::WpSolution)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,23 +11,57 @@ use crate::config;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PersistDecompCache {
|
||||
pub struct PersistDecompCaches {
|
||||
pub ks: keyswitch::PersistDecompCache,
|
||||
pub br: blind_rotate::PersistDecompCache,
|
||||
pub pp: pp_switch::PersistDecompCache,
|
||||
pub cb: circuit_bootstrap::PersistDecompCache,
|
||||
}
|
||||
|
||||
pub struct DecompCaches {
|
||||
pub blind_rotate: blind_rotate::Cache,
|
||||
pub keyswitch: keyswitch::Cache,
|
||||
pub pp_switch: pp_switch::Cache,
|
||||
pub cb_pbs: circuit_bootstrap::Cache,
|
||||
}
|
||||
|
||||
pub fn cache(
|
||||
security_level: u64,
|
||||
processing_unit: config::ProcessingUnit,
|
||||
complexity_model: Option<Arc<dyn ComplexityModel>>,
|
||||
) -> PersistDecompCache {
|
||||
let complexity_model = complexity_model.unwrap_or_else(|| processing_unit.complexity_model());
|
||||
PersistDecompCache {
|
||||
ks: keyswitch::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
br: blind_rotate::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
pp: pp_switch::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
cb: circuit_bootstrap::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
) -> PersistDecompCaches {
|
||||
PersistDecompCaches::new(security_level, processing_unit, complexity_model)
|
||||
}
|
||||
|
||||
impl PersistDecompCaches {
|
||||
pub fn new(
|
||||
security_level: u64,
|
||||
processing_unit: config::ProcessingUnit,
|
||||
complexity_model: Option<Arc<dyn ComplexityModel>>,
|
||||
) -> Self {
|
||||
let complexity_model =
|
||||
complexity_model.unwrap_or_else(|| processing_unit.complexity_model());
|
||||
Self {
|
||||
ks: keyswitch::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
br: blind_rotate::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
pp: pp_switch::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
cb: circuit_bootstrap::cache(security_level, processing_unit, complexity_model.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn backport(&self, cache: DecompCaches) {
|
||||
self.ks.backport(cache.keyswitch);
|
||||
self.br.backport(cache.blind_rotate);
|
||||
self.pp.backport(cache.pp_switch);
|
||||
self.cb.backport(cache.cb_pbs);
|
||||
}
|
||||
|
||||
pub fn caches(&self) -> DecompCaches {
|
||||
DecompCaches {
|
||||
blind_rotate: self.br.cache(),
|
||||
keyswitch: self.ks.cache(),
|
||||
pp_switch: self.pp.cache(),
|
||||
cb_pbs: self.cb.cache(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,10 +6,10 @@ use crate::noise_estimator::error::{
|
||||
};
|
||||
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::atomic_pattern::{Caches, OptimizationDecompositionsConsts};
|
||||
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
|
||||
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
use crate::optimization::decomposition::PersistDecompCache;
|
||||
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
|
||||
use crate::parameters::{GlweParameters, LweDimension, PbsParameters};
|
||||
use crate::utils::square;
|
||||
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
||||
@@ -102,7 +102,7 @@ fn update_state_with_best_decompositions(
|
||||
internal_dim: u64,
|
||||
n_functions: u64,
|
||||
partitionning: &[u64],
|
||||
caches: &mut Caches,
|
||||
caches: &mut DecompCaches,
|
||||
) {
|
||||
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
|
||||
let precisions_sum = partitionning.iter().copied().sum();
|
||||
@@ -326,7 +326,7 @@ fn optimize_raw(
|
||||
search_space: &SearchSpace,
|
||||
n_functions: u64, // Many functions at the same time, stay at 1 for start
|
||||
partitionning: &[u64],
|
||||
cache: &PersistDecompCache,
|
||||
persistent_caches: &PersistDecompCaches,
|
||||
) -> OptimizationState {
|
||||
assert!(0.0 < config.maximum_acceptable_error_probability);
|
||||
assert!(config.maximum_acceptable_error_probability < 1.0);
|
||||
@@ -354,7 +354,7 @@ fn optimize_raw(
|
||||
safe_variance: safe_variance_bound,
|
||||
};
|
||||
|
||||
let mut caches = Caches::new(cache);
|
||||
let mut caches = persistent_caches.caches();
|
||||
|
||||
for &glwe_dim in &search_space.glwe_dimensions {
|
||||
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
|
||||
@@ -382,7 +382,7 @@ fn optimize_raw(
|
||||
}
|
||||
}
|
||||
}
|
||||
caches.backport_to(cache);
|
||||
persistent_caches.backport(caches);
|
||||
|
||||
state
|
||||
}
|
||||
@@ -392,7 +392,7 @@ pub fn optimize_one(
|
||||
config: Config,
|
||||
log_norm: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
caches: &PersistDecompCaches,
|
||||
) -> OptimizationState {
|
||||
let coprimes = crt_decomposition::default_coprimes(precision as Precision);
|
||||
let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes);
|
||||
@@ -403,7 +403,7 @@ pub fn optimize_one(
|
||||
search_space,
|
||||
n_functions,
|
||||
&partitionning,
|
||||
cache,
|
||||
caches,
|
||||
);
|
||||
state.best_solution = state.best_solution.map(|mut sol| -> Solution {
|
||||
sol.crt_decomposition = coprimes;
|
||||
@@ -418,7 +418,7 @@ pub fn optimize_one_compat(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
cache: &PersistDecompCaches,
|
||||
) -> atomic_pattern::OptimizationState {
|
||||
let log_norm = noise_factor.log2();
|
||||
let result = optimize_one(precision, config, log_norm, search_space, cache);
|
||||
|
||||
Reference in New Issue
Block a user