mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
chore(optimizer): make cache not static
This commit is contained in:
@@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity;
|
||||
use concrete_optimizer::global_parameters::DEFAUT_DOMAINS;
|
||||
use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern};
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::decomposition;
|
||||
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
|
||||
|
||||
pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
|
||||
@@ -40,6 +41,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let cache = decomposition::cache(security_level);
|
||||
|
||||
let solutions: Vec<_> = log_norm2s
|
||||
.clone()
|
||||
.filter_map(|log_norm2| {
|
||||
@@ -51,6 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
config,
|
||||
noise_scale,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
.best_solution
|
||||
.map(|a| (log_norm2, a.complexity))
|
||||
@@ -61,9 +65,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
.filter_map(|log_norm2| {
|
||||
let noise_scale = 2_f64.powi(log_norm2 as i32);
|
||||
|
||||
optimize_wop_atomic_pattern::optimize_one(precision, config, noise_scale, &search_space)
|
||||
.best_solution
|
||||
.map(|a| (log_norm2, a.complexity))
|
||||
optimize_wop_atomic_pattern::optimize_one(
|
||||
precision,
|
||||
config,
|
||||
noise_scale,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
.best_solution
|
||||
.map(|a| (log_norm2, a.complexity))
|
||||
})
|
||||
.collect();
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use concrete_optimizer::computing_cost::cpu::CpuComplexity;
|
||||
use concrete_optimizer::global_parameters::DEFAUT_DOMAINS;
|
||||
use concrete_optimizer::optimization::atomic_pattern::{self as optimize_atomic_pattern};
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::decomposition;
|
||||
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
|
||||
|
||||
pub const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
|
||||
@@ -40,6 +41,8 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let cache = decomposition::cache(security_level);
|
||||
|
||||
let solutions: Vec<_> = precisions
|
||||
.clone()
|
||||
.filter_map(|precision| {
|
||||
@@ -51,6 +54,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
config,
|
||||
noise_factor,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
.best_solution
|
||||
.map(|a| (precision, a.complexity))
|
||||
@@ -64,6 +68,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
config,
|
||||
log_norm2 as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
.best_solution
|
||||
.map(|a| (precision, a.complexity))
|
||||
|
||||
@@ -5,6 +5,7 @@ use concrete_optimizer::dag::operator::{
|
||||
use concrete_optimizer::dag::unparametrized;
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution;
|
||||
use concrete_optimizer::optimization::decomposition;
|
||||
|
||||
fn no_solution() -> ffi::Solution {
|
||||
ffi::Solution {
|
||||
@@ -43,6 +44,7 @@ fn optimize_bootstrap(
|
||||
config,
|
||||
noise_factor,
|
||||
&search_space,
|
||||
&decomposition::cache(security_level),
|
||||
);
|
||||
result
|
||||
.best_solution
|
||||
@@ -219,6 +221,7 @@ impl OperationDag {
|
||||
&self.0,
|
||||
config,
|
||||
&search_space,
|
||||
&decomposition::cache(security_level),
|
||||
);
|
||||
result
|
||||
.best_solution
|
||||
@@ -239,12 +242,14 @@ impl OperationDag {
|
||||
};
|
||||
|
||||
let search_space = SearchSpace::default();
|
||||
let cache = decomposition::cache(security_level);
|
||||
|
||||
let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize(
|
||||
&self.0,
|
||||
config,
|
||||
&search_space,
|
||||
default_log_norm2_woppbs,
|
||||
&cache,
|
||||
);
|
||||
result.map_or_else(no_dag_solution, |solution| solution.into())
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ static_init = "1.0.3"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
rmp-serde = "1.1.0"
|
||||
statrs = "0.15.0"
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
[dev-dependencies]
|
||||
approx = "0.5"
|
||||
|
||||
@@ -5,7 +5,8 @@ use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositi
|
||||
use crate::utils::square;
|
||||
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
||||
|
||||
use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch};
|
||||
use super::decomposition;
|
||||
use super::decomposition::{blind_rotate, cut_complexity_noise, keyswitch, PersistDecompCache};
|
||||
|
||||
// Ref time for v0 table 1 thread: 950ms
|
||||
const CUTS: bool = true; // 80ms
|
||||
@@ -46,6 +47,19 @@ pub struct Caches {
|
||||
pub keyswitch: keyswitch::Cache,
|
||||
}
|
||||
|
||||
impl Caches {
|
||||
pub fn new(cache: &decomposition::PersistDecompCache) -> Self {
|
||||
Self {
|
||||
blind_rotate: cache.br.cache(),
|
||||
keyswitch: cache.ks.cache(),
|
||||
}
|
||||
}
|
||||
pub fn backport_to(self, cache: &decomposition::PersistDecompCache) {
|
||||
cache.ks.backport(self.keyswitch);
|
||||
cache.br.backport(self.blind_rotate);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn update_state_with_best_decompositions(
|
||||
state: &mut OptimizationState,
|
||||
@@ -182,6 +196,7 @@ pub fn optimize_one(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
assert!(0 < precision && precision <= 16);
|
||||
assert!(1.0 <= noise_factor);
|
||||
@@ -193,7 +208,6 @@ pub fn optimize_one(
|
||||
// the blind rotate decomposition
|
||||
|
||||
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
||||
let security_level = config.security_level;
|
||||
let safe_variance = error::safe_variance_bound_2padbits(
|
||||
precision,
|
||||
ciphertext_modulus_log,
|
||||
@@ -228,10 +242,7 @@ pub fn optimize_one(
|
||||
> consts.safe_variance
|
||||
};
|
||||
|
||||
let mut caches = Caches {
|
||||
blind_rotate: blind_rotate::for_security(security_level).cache(),
|
||||
keyswitch: keyswitch::for_security(security_level).cache(),
|
||||
};
|
||||
let mut caches = Caches::new(cache);
|
||||
|
||||
for &glwe_dim in &search_space.glwe_dimensions {
|
||||
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
|
||||
@@ -260,8 +271,7 @@ pub fn optimize_one(
|
||||
}
|
||||
}
|
||||
|
||||
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
|
||||
keyswitch::for_security(security_level).backport(caches.keyswitch);
|
||||
caches.backport_to(cache);
|
||||
|
||||
if let Some(sol) = state.best_solution {
|
||||
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
|
||||
|
||||
@@ -7,7 +7,7 @@ use crate::optimization::atomic_pattern::{
|
||||
Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
|
||||
};
|
||||
use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
|
||||
use crate::optimization::decomposition::{blind_rotate, keyswitch};
|
||||
use crate::optimization::decomposition::PersistDecompCache;
|
||||
use crate::parameters::GlweParameters;
|
||||
use crate::security;
|
||||
|
||||
@@ -231,6 +231,7 @@ pub fn optimize(
|
||||
dag: &unparametrized::OperationDag,
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
||||
let security_level = config.security_level;
|
||||
@@ -266,11 +267,7 @@ pub fn optimize(
|
||||
if dag.nb_luts == 0 {
|
||||
return optimize_no_luts(state, &consts, &dag, search_space);
|
||||
}
|
||||
|
||||
let mut caches = Caches {
|
||||
blind_rotate: blind_rotate::for_security(security_level).cache(),
|
||||
keyswitch: keyswitch::for_security(security_level).cache(),
|
||||
};
|
||||
let mut caches = Caches::new(cache);
|
||||
|
||||
let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
|
||||
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key(
|
||||
@@ -316,8 +313,7 @@ pub fn optimize(
|
||||
}
|
||||
}
|
||||
|
||||
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
|
||||
keyswitch::for_security(security_level).backport(caches.keyswitch);
|
||||
caches.backport_to(cache);
|
||||
|
||||
if let Some(sol) = state.best_solution {
|
||||
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
|
||||
@@ -335,6 +331,7 @@ pub fn optimize_v0(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
use crate::dag::operator::{FunctionTable, Shape};
|
||||
let same_scale_manp = 0.0;
|
||||
@@ -349,7 +346,7 @@ pub fn optimize_v0(
|
||||
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
|
||||
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
|
||||
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
|
||||
let mut state = optimize(&dag, config, search_space);
|
||||
let mut state = optimize(&dag, config, search_space, cache);
|
||||
if let Some(sol) = &mut state.best_solution {
|
||||
sol.complexity /= 2.0;
|
||||
}
|
||||
@@ -364,9 +361,9 @@ mod tests {
|
||||
use crate::computing_cost::cpu::CpuComplexity;
|
||||
use crate::dag::operator::{FunctionTable, Shape, Weights};
|
||||
use crate::noise_estimator::p_error::repeat_p_error;
|
||||
use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::config::SearchSpace;
|
||||
use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
|
||||
use crate::optimization::{atomic_pattern, decomposition};
|
||||
use crate::utils::square;
|
||||
|
||||
fn small_relative_diff(v1: f64, v2: f64) -> bool {
|
||||
@@ -390,7 +387,10 @@ mod tests {
|
||||
|
||||
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
|
||||
|
||||
fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
|
||||
fn optimize(
|
||||
dag: &unparametrized::OperationDag,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
let config = Config {
|
||||
security_level: 128,
|
||||
maximum_acceptable_error_probability: _4_SIGMA,
|
||||
@@ -400,7 +400,7 @@ mod tests {
|
||||
|
||||
let search_space = SearchSpace::default();
|
||||
|
||||
super::optimize(dag, config, &search_space)
|
||||
super::optimize(dag, config, &search_space, cache)
|
||||
}
|
||||
|
||||
struct Times {
|
||||
@@ -439,16 +439,38 @@ mod tests {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let _ = optimize_v0(sum_size, precision, config, weight as f64, &search_space);
|
||||
let cache = decomposition::cache(config.security_level);
|
||||
|
||||
let _ = optimize_v0(
|
||||
sum_size,
|
||||
precision,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
);
|
||||
// ensure cache is filled
|
||||
|
||||
let chrono = Instant::now();
|
||||
let state = optimize_v0(sum_size, precision, config, weight as f64, &search_space);
|
||||
let state = optimize_v0(
|
||||
sum_size,
|
||||
precision,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
);
|
||||
|
||||
times.dag_time += chrono.elapsed().as_nanos();
|
||||
let chrono = Instant::now();
|
||||
let state_ref =
|
||||
atomic_pattern::optimize_one(sum_size, precision, config, weight as f64, &search_space);
|
||||
let state_ref = atomic_pattern::optimize_one(
|
||||
sum_size,
|
||||
precision,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
);
|
||||
times.worst_time += chrono.elapsed().as_nanos();
|
||||
assert_eq!(
|
||||
state.best_solution.is_some(),
|
||||
@@ -476,6 +498,10 @@ mod tests {
|
||||
}
|
||||
|
||||
fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) {
|
||||
let security_level = 128;
|
||||
|
||||
let cache = decomposition::cache(security_level);
|
||||
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
{
|
||||
let input1 = dag.add_input(precision, Shape::number());
|
||||
@@ -488,7 +514,7 @@ mod tests {
|
||||
let dag2 = analyze::analyze(
|
||||
&dag,
|
||||
&NoiseBoundConfig {
|
||||
security_level: 128,
|
||||
security_level,
|
||||
maximum_acceptable_error_probability: _4_SIGMA,
|
||||
ciphertext_modulus_log: 64,
|
||||
},
|
||||
@@ -506,15 +532,21 @@ mod tests {
|
||||
let search_space = SearchSpace::default();
|
||||
|
||||
let config = Config {
|
||||
security_level: 128,
|
||||
security_level,
|
||||
maximum_acceptable_error_probability: _4_SIGMA,
|
||||
ciphertext_modulus_log: 64,
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let state = optimize(&dag);
|
||||
let state_ref =
|
||||
atomic_pattern::optimize_one(1, precision as u64, config, weight as f64, &search_space);
|
||||
let state = optimize(&dag, &cache);
|
||||
let state_ref = atomic_pattern::optimize_one(
|
||||
1,
|
||||
precision as u64,
|
||||
config,
|
||||
weight as f64,
|
||||
&search_space,
|
||||
&cache,
|
||||
);
|
||||
assert_eq!(
|
||||
state.best_solution.is_some(),
|
||||
state_ref.best_solution.is_some()
|
||||
@@ -531,7 +563,7 @@ mod tests {
|
||||
assert!(sol.global_p_error <= 1.0);
|
||||
}
|
||||
|
||||
fn no_lut_vs_lut(precision: Precision) {
|
||||
fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) {
|
||||
let mut dag_lut = unparametrized::OperationDag::new();
|
||||
let input1 = dag_lut.add_input(precision as u8, Shape::number());
|
||||
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision);
|
||||
@@ -539,8 +571,8 @@ mod tests {
|
||||
let mut dag_no_lut = unparametrized::OperationDag::new();
|
||||
let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());
|
||||
|
||||
let state_no_lut = optimize(&dag_no_lut);
|
||||
let state_lut = optimize(&dag_lut);
|
||||
let state_no_lut = optimize(&dag_no_lut, cache);
|
||||
let state_lut = optimize(&dag_lut, cache);
|
||||
assert_eq!(
|
||||
state_no_lut.best_solution.is_some(),
|
||||
state_lut.best_solution.is_some()
|
||||
@@ -556,14 +588,16 @@ mod tests {
|
||||
}
|
||||
#[test]
|
||||
fn test_lut_vs_no_lut() {
|
||||
let cache = decomposition::cache(128);
|
||||
for precision in 1..=8 {
|
||||
no_lut_vs_lut(precision);
|
||||
no_lut_vs_lut(precision, &cache);
|
||||
}
|
||||
}
|
||||
|
||||
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
|
||||
precision: Precision,
|
||||
weight: u64,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
let weight = &Weights::number(weight);
|
||||
|
||||
@@ -583,8 +617,8 @@ mod tests {
|
||||
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
|
||||
let state_1 = optimize(&dag_1);
|
||||
let state_2 = optimize(&dag_2);
|
||||
let state_1 = optimize(&dag_1, cache);
|
||||
let state_2 = optimize(&dag_2, cache);
|
||||
|
||||
if state_1.best_solution.is_none() {
|
||||
assert!(state_2.best_solution.is_none());
|
||||
@@ -597,15 +631,18 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
|
||||
let cache = decomposition::cache(128);
|
||||
for log_weight in 1..=16 {
|
||||
let weight = 1 << log_weight;
|
||||
for precision in 5..=9 {
|
||||
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
|
||||
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(
|
||||
precision, weight, &cache,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn lut_1_layer_has_better_complexity(precision: Precision) {
|
||||
fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) {
|
||||
let dag_1_layer = {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision as u8, Shape::number());
|
||||
@@ -621,17 +658,18 @@ mod tests {
|
||||
dag
|
||||
};
|
||||
|
||||
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
|
||||
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
|
||||
let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap();
|
||||
let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap();
|
||||
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lut_1_layer_is_better() {
|
||||
let cache = decomposition::cache(128);
|
||||
// for some reason on 4, 5, 6, the complexity is already minimal
|
||||
// this could be due to pre-defined pareto set
|
||||
for precision in [1, 2, 3, 7, 8] {
|
||||
lut_1_layer_has_better_complexity(precision);
|
||||
lut_1_layer_has_better_complexity(precision, &cache);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -643,7 +681,10 @@ mod tests {
|
||||
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
|
||||
fn assert_multi_precision_dominate_single(weight: u64) -> Option<bool> {
|
||||
fn assert_multi_precision_dominate_single(
|
||||
weight: u64,
|
||||
cache: &PersistDecompCache,
|
||||
) -> Option<bool> {
|
||||
let low_precision = 4u8;
|
||||
let high_precision = 5u8;
|
||||
let mut dag_low = unparametrized::OperationDag::new();
|
||||
@@ -656,12 +697,12 @@ mod tests {
|
||||
circuit(&mut dag_multi, low_precision, weight);
|
||||
circuit(&mut dag_multi, high_precision, 1);
|
||||
}
|
||||
let state_multi = optimize(&dag_multi);
|
||||
let state_multi = optimize(&dag_multi, cache);
|
||||
|
||||
let mut sol_multi = state_multi.best_solution?;
|
||||
|
||||
let state_low = optimize(&dag_low);
|
||||
let state_high = optimize(&dag_high);
|
||||
let state_low = optimize(&dag_low, cache);
|
||||
let state_high = optimize(&dag_high, cache);
|
||||
|
||||
let sol_low = state_low.best_solution.unwrap();
|
||||
let sol_high = state_high.best_solution.unwrap();
|
||||
@@ -680,10 +721,11 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_multi_precision_dominate_single() {
|
||||
let cache = decomposition::cache(128);
|
||||
let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false
|
||||
for log2_weight in 0..29 {
|
||||
let weight = 1 << log2_weight;
|
||||
let current = assert_multi_precision_dominate_single(weight);
|
||||
let current = assert_multi_precision_dominate_single(weight, &cache);
|
||||
#[allow(clippy::match_like_matches_macro)] // less readable
|
||||
let authorized = match (prev, current) {
|
||||
(Some(false), Some(true)) => false,
|
||||
@@ -713,22 +755,28 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_input() {
|
||||
let cache = decomposition::cache(128);
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for dim in [1, 2, 16, 32] {
|
||||
let _ = check_global_p_error_input(dim, weight, precision);
|
||||
let _ = check_global_p_error_input(dim, weight, precision, &cache);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 {
|
||||
fn check_global_p_error_input(
|
||||
dim: u64,
|
||||
weight: u64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) -> f64 {
|
||||
let shape = Shape::vector(dim);
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision as u8, shape);
|
||||
let _dot1 = dag.add_dot([input1], weights); // this is just several multiply
|
||||
let state = optimize(&dag);
|
||||
let state = optimize(&dag, cache);
|
||||
let sol = state.best_solution.unwrap();
|
||||
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
|
||||
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
|
||||
@@ -737,16 +785,22 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_lut() {
|
||||
let cache = decomposition::cache(128);
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for depth in [2, 16, 32] {
|
||||
check_global_p_error_lut(depth, weight, precision);
|
||||
check_global_p_error_lut(depth, weight, precision, &cache);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) {
|
||||
fn check_global_p_error_lut(
|
||||
depth: u64,
|
||||
weight: u64,
|
||||
precision: u8,
|
||||
cache: &PersistDecompCache,
|
||||
) {
|
||||
let shape = Shape::number();
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
@@ -755,7 +809,7 @@ mod tests {
|
||||
let dot = dag.add_dot([last_val], &weights);
|
||||
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
let state = optimize(&dag);
|
||||
let state = optimize(&dag, cache);
|
||||
let sol = state.best_solution.unwrap();
|
||||
// the first lut on input has reduced impact on error probability
|
||||
let lower_nb_dominating_lut = depth - 1;
|
||||
@@ -792,6 +846,7 @@ mod tests {
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_dominating_lut() {
|
||||
let cache = decomposition::cache(128);
|
||||
let depth = 128;
|
||||
let weights_low = 1;
|
||||
let weights_high = 1;
|
||||
@@ -804,7 +859,7 @@ mod tests {
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
let sol = optimize(&dag, &cache).best_solution.unwrap();
|
||||
// the 2 first luts and low precision/weight luts have little impact on error probability
|
||||
let nb_dominating_lut = depth - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
@@ -819,6 +874,7 @@ mod tests {
|
||||
#[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const
|
||||
#[test]
|
||||
fn test_global_p_error_non_dominating_lut() {
|
||||
let cache = decomposition::cache(128);
|
||||
let depth = 128;
|
||||
let weights_low = 1024 * 1024 * 3;
|
||||
let weights_high = 1;
|
||||
@@ -831,7 +887,7 @@ mod tests {
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
let sol = optimize(&dag, &cache).best_solution.unwrap();
|
||||
// all internal luts have an almost equal impact on error probability
|
||||
let nb_dominating_lut = (2 * depth) - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::noise_estimator::p_error::repeat_p_error;
|
||||
use crate::optimization::atomic_pattern::Solution as WpSolution;
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
use crate::optimization::dag::solo_key::{analyze, optimize};
|
||||
use crate::optimization::decomposition::PersistDecompCache;
|
||||
use crate::optimization::wop_atomic_pattern::optimize::optimize_one as wop_optimize;
|
||||
use crate::optimization::wop_atomic_pattern::Solution as WopSolution;
|
||||
use std::ops::RangeInclusive;
|
||||
@@ -48,6 +49,7 @@ pub fn optimize(
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
default_log_norm2_woppbs: f64,
|
||||
cache: &PersistDecompCache,
|
||||
) -> Option<Solution> {
|
||||
let max_precision = max_precision(dag);
|
||||
let nb_luts = analyze::lut_count_from_dag(dag);
|
||||
@@ -58,11 +60,17 @@ pub fn optimize(
|
||||
let default_log_norm = default_log_norm2_woppbs;
|
||||
let worst_log_norm = analyze::worst_log_norm(dag);
|
||||
let log_norm = default_log_norm.min(worst_log_norm);
|
||||
let opt_sol =
|
||||
wop_optimize(fallback_16b_precision, config, log_norm, search_space).best_solution;
|
||||
let opt_sol = wop_optimize(
|
||||
fallback_16b_precision,
|
||||
config,
|
||||
log_norm,
|
||||
search_space,
|
||||
cache,
|
||||
)
|
||||
.best_solution;
|
||||
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
|
||||
} else {
|
||||
let opt_sol = optimize::optimize(dag, config, search_space).best_solution;
|
||||
let opt_sol = optimize::optimize(dag, config, search_space, cache).best_solution;
|
||||
opt_sol.map(Solution::WpSolution)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use concrete_commons::dispersion::DispersionParameter;
|
||||
@@ -8,8 +6,6 @@ use crate::computing_cost::operators::pbs::PbsComplexity;
|
||||
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::parameters::{BrDecompositionParameters, GlweParameters, LweDimension, PbsParameters};
|
||||
use crate::security;
|
||||
|
||||
use crate::security::security_weights::SECURITY_WEIGHTS_TABLE;
|
||||
use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache};
|
||||
use crate::utils::cache::persistent::PersistentCacheHashMap;
|
||||
|
||||
@@ -120,42 +116,22 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<BrComplexityNoise>>;
|
||||
type MultiSecPersistDecompCache = HashMap<u64, PersistDecompCache>; // just to attach a finaly
|
||||
pub type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<BrComplexityNoise>>;
|
||||
|
||||
#[static_init::dynamic]
|
||||
pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE
|
||||
.keys()
|
||||
.map(|&security_level| {
|
||||
let ciphertext_modulus_log = 64;
|
||||
let tmp: String = std::env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Invalid tmp dir")
|
||||
.into();
|
||||
let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}");
|
||||
let function = move |(glwe_params, internal_dim): MacroParam| {
|
||||
pareto_quantities(
|
||||
ciphertext_modulus_log,
|
||||
security_level,
|
||||
internal_dim,
|
||||
glwe_params,
|
||||
)
|
||||
};
|
||||
(
|
||||
pub fn cache(security_level: u64) -> PersistDecompCache {
|
||||
let ciphertext_modulus_log = 64;
|
||||
let tmp: String = std::env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Invalid tmp dir")
|
||||
.into();
|
||||
let path = format!("{tmp}/optimizer/cache/br-decomp-cpu-64-{security_level}");
|
||||
let function = move |(glwe_params, internal_dim): MacroParam| {
|
||||
pareto_quantities(
|
||||
ciphertext_modulus_log,
|
||||
security_level,
|
||||
PersistentCacheHashMap::new(&path, "v0", function),
|
||||
internal_dim,
|
||||
glwe_params,
|
||||
)
|
||||
})
|
||||
.collect::<MultiSecPersistDecompCache>();
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
#[static_init::destructor(10)]
|
||||
extern "C" fn finaly() {
|
||||
for v in SHARED_CACHE.values() {
|
||||
v.sync_to_disk();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_security(security_level: u64) -> &'static PersistDecompCache {
|
||||
SHARED_CACHE.get(&security_level).unwrap()
|
||||
};
|
||||
PersistentCacheHashMap::new(&path, "v0", function)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use concrete_commons::dispersion::DispersionParameter;
|
||||
@@ -9,7 +7,6 @@ use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::parameters::{
|
||||
GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension,
|
||||
};
|
||||
use crate::security::security_weights::SECURITY_WEIGHTS_TABLE;
|
||||
use crate::utils::cache::ephemeral::{CacheHashMap, EphemeralCache};
|
||||
use crate::utils::cache::persistent::PersistentCacheHashMap;
|
||||
|
||||
@@ -121,42 +118,22 @@ impl Cache {
|
||||
}
|
||||
}
|
||||
|
||||
type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<KsComplexityNoise>>;
|
||||
type MultiSecPersistDecompCache = HashMap<u64, PersistDecompCache>;
|
||||
pub type PersistDecompCache = PersistentCacheHashMap<MacroParam, Vec<KsComplexityNoise>>;
|
||||
|
||||
#[static_init::dynamic]
|
||||
pub static SHARED_CACHE: MultiSecPersistDecompCache = SECURITY_WEIGHTS_TABLE
|
||||
.keys()
|
||||
.map(|&security_level| {
|
||||
let ciphertext_modulus_log = 64;
|
||||
let tmp: String = std::env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Invalid tmp dir")
|
||||
.into();
|
||||
let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}");
|
||||
let function = move |(glwe_params, internal_dim): MacroParam| {
|
||||
pareto_quantities(
|
||||
ciphertext_modulus_log,
|
||||
security_level,
|
||||
internal_dim,
|
||||
glwe_params,
|
||||
)
|
||||
};
|
||||
(
|
||||
pub fn cache(security_level: u64) -> PersistDecompCache {
|
||||
let ciphertext_modulus_log = 64;
|
||||
let tmp: String = std::env::temp_dir()
|
||||
.to_str()
|
||||
.expect("Invalid tmp dir")
|
||||
.into();
|
||||
let path = format!("{tmp}/optimizer/cache/ks-decomp-cpu-64-{security_level}");
|
||||
let function = move |(glwe_params, internal_dim): MacroParam| {
|
||||
pareto_quantities(
|
||||
ciphertext_modulus_log,
|
||||
security_level,
|
||||
PersistentCacheHashMap::new(&path, "v0", function),
|
||||
internal_dim,
|
||||
glwe_params,
|
||||
)
|
||||
})
|
||||
.collect::<MultiSecPersistDecompCache>();
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
#[static_init::destructor(10)]
|
||||
extern "C" fn finaly() {
|
||||
for v in SHARED_CACHE.values() {
|
||||
v.sync_to_disk();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn for_security(security_level: u64) -> &'static PersistDecompCache {
|
||||
SHARED_CACHE.get(&security_level).unwrap()
|
||||
};
|
||||
PersistentCacheHashMap::new(&path, "v0", function)
|
||||
}
|
||||
|
||||
@@ -5,3 +5,15 @@ pub mod keyswitch;
|
||||
|
||||
pub use common::MacroParam;
|
||||
pub use cut::cut_complexity_noise;
|
||||
|
||||
/// Bundle of the two persistent decomposition caches, one per operation kind.
///
/// It is built by `cache(security_level)` and passed by reference to the
/// optimizer entry points, which take it as `cache: &PersistDecompCache`.
pub struct PersistDecompCache {
|
||||
/// Persistent cache of keyswitch decompositions.
pub ks: keyswitch::PersistDecompCache,
|
||||
/// Persistent cache of blind-rotate decompositions.
pub br: blind_rotate::PersistDecompCache,
|
||||
}
|
||||
|
||||
pub fn cache(security_level: u64) -> PersistDecompCache {
|
||||
PersistDecompCache {
|
||||
ks: keyswitch::cache(security_level),
|
||||
br: blind_rotate::cache(security_level),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pub mod atomic_pattern;
|
||||
pub mod config;
|
||||
pub mod dag;
|
||||
pub(crate) mod decomposition;
|
||||
pub mod decomposition;
|
||||
pub mod wop_atomic_pattern;
|
||||
|
||||
@@ -9,15 +9,12 @@ use crate::noise_estimator::error::{
|
||||
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
use crate::noise_estimator::operators::wop_atomic_pattern::estimate_packing_private_keyswitch;
|
||||
use crate::optimization::atomic_pattern;
|
||||
use crate::optimization::atomic_pattern::Caches;
|
||||
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
|
||||
use crate::optimization::atomic_pattern::{Caches, OptimizationDecompositionsConsts};
|
||||
|
||||
use crate::optimization::config::{Config, SearchSpace};
|
||||
use crate::optimization::decomposition::blind_rotate;
|
||||
use crate::optimization::decomposition::blind_rotate::BrComplexityNoise;
|
||||
use crate::optimization::decomposition::cut_complexity_noise;
|
||||
use crate::optimization::decomposition::keyswitch;
|
||||
use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
|
||||
use crate::optimization::decomposition::{cut_complexity_noise, PersistDecompCache};
|
||||
use crate::optimization::wop_atomic_pattern::pareto::BR_CIRCUIT_BOOTSTRAP_PARETO_DECOMP;
|
||||
use crate::parameters::{
|
||||
GlweParameters, KeyswitchParameters, KsDecompositionParameters, LweDimension, PbsParameters,
|
||||
@@ -457,12 +454,12 @@ fn optimize_raw(
|
||||
search_space: &SearchSpace,
|
||||
n_functions: u64, // Many functions at the same time, stay at 1 for start
|
||||
partitionning: &[u64],
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
assert!(0.0 < config.maximum_acceptable_error_probability);
|
||||
assert!(config.maximum_acceptable_error_probability < 1.0);
|
||||
|
||||
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
||||
let security_level = config.security_level;
|
||||
|
||||
// Circuit BS bound
|
||||
// 1 bit of message only here =)
|
||||
@@ -485,10 +482,7 @@ fn optimize_raw(
|
||||
safe_variance: safe_variance_bound,
|
||||
};
|
||||
|
||||
let mut caches = Caches {
|
||||
blind_rotate: blind_rotate::for_security(security_level).cache(),
|
||||
keyswitch: keyswitch::for_security(security_level).cache(),
|
||||
};
|
||||
let mut caches = Caches::new(cache);
|
||||
|
||||
for &glwe_dim in &search_space.glwe_dimensions {
|
||||
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
|
||||
@@ -516,9 +510,7 @@ fn optimize_raw(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
blind_rotate::for_security(security_level).backport(caches.blind_rotate);
|
||||
keyswitch::for_security(security_level).backport(caches.keyswitch);
|
||||
caches.backport_to(cache);
|
||||
|
||||
state
|
||||
}
|
||||
@@ -528,11 +520,19 @@ pub fn optimize_one(
|
||||
config: Config,
|
||||
log_norm: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
) -> OptimizationState {
|
||||
let coprimes = crt_decomposition::default_coprimes(precision as Precision);
|
||||
let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes);
|
||||
let n_functions = 1;
|
||||
let mut state = optimize_raw(log_norm, config, search_space, n_functions, &partitionning);
|
||||
let mut state = optimize_raw(
|
||||
log_norm,
|
||||
config,
|
||||
search_space,
|
||||
n_functions,
|
||||
&partitionning,
|
||||
cache,
|
||||
);
|
||||
state.best_solution = state.best_solution.map(|mut sol| -> Solution {
|
||||
sol.crt_decomposition = coprimes;
|
||||
sol
|
||||
@@ -546,9 +546,10 @@ pub fn optimize_one_compat(
|
||||
config: Config,
|
||||
noise_factor: f64,
|
||||
search_space: &SearchSpace,
|
||||
cache: &PersistDecompCache,
|
||||
) -> atomic_pattern::OptimizationState {
|
||||
let log_norm = noise_factor.log2();
|
||||
let result = optimize_one(precision, config, log_norm, search_space);
|
||||
let result = optimize_one(precision, config, log_norm, search_space, cache);
|
||||
atomic_pattern::OptimizationState {
|
||||
best_solution: result.best_solution.map(Solution::into),
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use concrete_optimizer::optimization::atomic_pattern::{
|
||||
};
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::dag::solo_key::optimize::{self as optimize_dag};
|
||||
use concrete_optimizer::optimization::decomposition;
|
||||
use concrete_optimizer::optimization::wop_atomic_pattern::optimize as optimize_wop_atomic_pattern;
|
||||
use rayon_cond::CondIterator;
|
||||
use std::io::Write;
|
||||
@@ -107,6 +108,8 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
|
||||
complexity_model: &CpuComplexity::default(),
|
||||
};
|
||||
|
||||
let cache = decomposition::cache(config.security_level);
|
||||
|
||||
precisions_iter
|
||||
.map(|precision| {
|
||||
let mut last_solution = None;
|
||||
@@ -121,6 +124,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
|
||||
config,
|
||||
noise_scale,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
} else if args.simulate_dag {
|
||||
optimize_dag::optimize_v0(
|
||||
@@ -129,6 +133,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
|
||||
config,
|
||||
noise_scale,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
} else {
|
||||
optimize_atomic_pattern::optimize_one(
|
||||
@@ -137,6 +142,7 @@ pub fn all_results(args: &Args) -> Vec<Vec<OptimizationState>> {
|
||||
config,
|
||||
noise_scale,
|
||||
&search_space,
|
||||
&cache,
|
||||
)
|
||||
};
|
||||
last_solution = result.best_solution;
|
||||
|
||||
Reference in New Issue
Block a user