mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
539 lines
18 KiB
Rust
539 lines
18 KiB
Rust
use super::crt_decomposition;
|
|
use crate::dag::operator::Precision;
|
|
use crate::noise_estimator::error::{
|
|
error_probability_of_sigma_scale, safe_variance_bound_product_1padbit,
|
|
sigma_scale_of_error_probability,
|
|
};
|
|
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
|
use crate::optimization::atomic_pattern;
|
|
use crate::optimization::atomic_pattern::OptimizationDecompositionsConsts;
|
|
|
|
use crate::optimization::config::{Config, SearchSpace};
|
|
use crate::optimization::decomposition::circuit_bootstrap::CbComplexityNoise;
|
|
use crate::optimization::decomposition::{DecompCaches, PersistDecompCaches};
|
|
use crate::parameters::{BrDecompositionParameters, GlweParameters};
|
|
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
|
|
|
pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {
|
|
let sigma = Variance(variance_bound).get_standard_dev() * kappa;
|
|
let sigma_scale = sigma / Variance(current_maximum_noise).get_standard_dev();
|
|
error_probability_of_sigma_scale(sigma_scale)
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct OptimizationState {
|
|
pub best_solution: Option<Solution>,
|
|
}
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct Solution {
|
|
pub input_lwe_dimension: u64,
|
|
//n_big
|
|
pub internal_ks_output_lwe_dimension: u64,
|
|
//n_small
|
|
pub ks_decomposition_level_count: u64,
|
|
//l(KS)
|
|
pub ks_decomposition_base_log: u64,
|
|
//b(KS)
|
|
pub glwe_polynomial_size: u64,
|
|
//N
|
|
pub glwe_dimension: u64,
|
|
//k
|
|
pub br_decomposition_level_count: u64,
|
|
//l(BR)
|
|
pub br_decomposition_base_log: u64,
|
|
//b(BR)
|
|
pub complexity: f64,
|
|
pub noise_max: f64,
|
|
pub p_error: f64,
|
|
pub global_p_error: f64,
|
|
// error probability
|
|
pub cb_decomposition_level_count: u64,
|
|
pub cb_decomposition_base_log: u64,
|
|
pub crt_decomposition: Vec<u64>,
|
|
}
|
|
|
|
impl Solution {
|
|
pub fn init() -> Self {
|
|
Self {
|
|
input_lwe_dimension: 0,
|
|
internal_ks_output_lwe_dimension: 0,
|
|
ks_decomposition_level_count: 0,
|
|
ks_decomposition_base_log: 0,
|
|
glwe_polynomial_size: 0,
|
|
glwe_dimension: 0,
|
|
br_decomposition_level_count: 0,
|
|
br_decomposition_base_log: 0,
|
|
complexity: 0.,
|
|
noise_max: 0.0,
|
|
p_error: 0.0,
|
|
global_p_error: 0.0,
|
|
cb_decomposition_level_count: 0,
|
|
cb_decomposition_base_log: 0,
|
|
crt_decomposition: vec![],
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<Solution> for atomic_pattern::Solution {
|
|
fn from(sol: Solution) -> Self {
|
|
Self {
|
|
input_lwe_dimension: sol.input_lwe_dimension,
|
|
internal_ks_output_lwe_dimension: sol.internal_ks_output_lwe_dimension,
|
|
ks_decomposition_level_count: sol.ks_decomposition_level_count,
|
|
ks_decomposition_base_log: sol.ks_decomposition_base_log,
|
|
glwe_polynomial_size: sol.glwe_polynomial_size,
|
|
glwe_dimension: sol.glwe_dimension,
|
|
br_decomposition_level_count: sol.br_decomposition_level_count,
|
|
br_decomposition_base_log: sol.br_decomposition_base_log,
|
|
complexity: sol.complexity,
|
|
noise_max: sol.noise_max,
|
|
p_error: sol.p_error,
|
|
global_p_error: sol.global_p_error,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn estimate_variance(
|
|
br_variance: f64,
|
|
pp_variance: f64,
|
|
cb_decomp: &CbComplexityNoise,
|
|
ks_variance: f64,
|
|
variance_modulus_switching: f64,
|
|
log_norm: f64,
|
|
precisions_sum: u64,
|
|
max_precision: u64,
|
|
) -> f64 {
|
|
assert!(max_precision <= precisions_sum);
|
|
let variance_ggsw = pp_variance + br_variance / 2.;
|
|
let variance_coeff_1_cmux_tree =
|
|
2_f64.powf(2. * log_norm) // variance_coeff for the multisum
|
|
* (precisions_sum // for hybrid packing
|
|
<< (2 * (max_precision - 1))) as f64 // for left shift
|
|
;
|
|
let variance_one_external_product_for_cmux_tree = cb_decomp.variance_from_ggsw(variance_ggsw);
|
|
variance_modulus_switching
|
|
+ variance_coeff_1_cmux_tree * variance_one_external_product_for_cmux_tree
|
|
+ ks_variance
|
|
}
|
|
|
|
fn estimate_complexity(
|
|
glwe_params: &GlweParameters,
|
|
br_cost: f64,
|
|
pp_cost: f64,
|
|
cb_decomp: &CbComplexityNoise,
|
|
ks_cost: f64,
|
|
precisions_sum: u64,
|
|
nb_blocks: u64,
|
|
n_functions: u64,
|
|
) -> f64 {
|
|
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
|
// Hybrid packing
|
|
let cb_level = cb_decomp.decomp.level;
|
|
let complexity_1_cmux_hp = cb_decomp.complexity_one_cmux_hp;
|
|
let complexity_1_ggsw_to_fft = cb_decomp.complexity_one_ggsw_to_fft;
|
|
// BitExtract use br
|
|
let complexity_bit_extract_1_pbs = br_cost;
|
|
let complexity_bit_extract_wo_ks =
|
|
(precisions_sum - nb_blocks) as f64 * complexity_bit_extract_1_pbs;
|
|
|
|
// Hybrid packing
|
|
// Circuit bs: fp-ks
|
|
let complexity_ppks = pp_cost;
|
|
let complexity_all_ppks =
|
|
((glwe_params.glwe_dimension + 1) * cb_level * precisions_sum) as f64 * complexity_ppks;
|
|
|
|
// Circuit bs: pbs
|
|
let complexity_all_pbs = (precisions_sum * cb_level) as f64 * br_cost;
|
|
|
|
let complexity_circuit_bs = complexity_all_pbs + complexity_all_ppks;
|
|
|
|
// Hybrid packing (Do we have 1 or 2 groups)
|
|
let log2_polynomial_size = glwe_params.log2_polynomial_size;
|
|
// Size of cmux_group, can be zero
|
|
let cmux_group_count = if precisions_sum > log2_polynomial_size {
|
|
2f64.powi((precisions_sum - log2_polynomial_size - 1) as i32)
|
|
} else {
|
|
0.0
|
|
};
|
|
let complexity_cmux_tree = cmux_group_count * complexity_1_cmux_hp;
|
|
|
|
let complexity_all_ggsw_to_fft = precisions_sum as f64 * complexity_1_ggsw_to_fft;
|
|
|
|
// Hybrid packing blind rotate
|
|
let complexity_g_br =
|
|
complexity_1_cmux_hp * u64::min(glwe_params.log2_polynomial_size, precisions_sum) as f64;
|
|
|
|
let complexity_hybrid_packing = complexity_cmux_tree + complexity_g_br;
|
|
|
|
let complexity_multi_hybrid_packing =
|
|
n_functions as f64 * complexity_hybrid_packing + complexity_all_ggsw_to_fft;
|
|
|
|
let complexity_all_ks = precisions_sum as f64 * ks_cost;
|
|
|
|
complexity_bit_extract_wo_ks
|
|
+ complexity_circuit_bs
|
|
+ complexity_multi_hybrid_packing
|
|
+ complexity_all_ks
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
fn update_state_with_best_decompositions(
|
|
state: &mut OptimizationState,
|
|
consts: &OptimizationDecompositionsConsts,
|
|
glwe_params: GlweParameters,
|
|
internal_dim: u64,
|
|
n_functions: u64,
|
|
partitionning: &[u64],
|
|
caches: &mut DecompCaches,
|
|
) {
|
|
let ciphertext_modulus_log = consts.config.ciphertext_modulus_log;
|
|
let precisions_sum = partitionning.iter().copied().sum();
|
|
let max_precision = partitionning.iter().copied().max().unwrap();
|
|
let nb_blocks = partitionning.len() as u64;
|
|
|
|
let safe_variance_bound = consts.safe_variance;
|
|
let log_norm = consts.noise_factor.log2();
|
|
|
|
let variance_modulus_switching =
|
|
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key(
|
|
internal_dim,
|
|
glwe_params.polynomial_size(),
|
|
ciphertext_modulus_log,
|
|
)
|
|
.get_variance();
|
|
|
|
if variance_modulus_switching > consts.safe_variance {
|
|
return;
|
|
}
|
|
|
|
let mut best_complexity = state
|
|
.best_solution
|
|
.as_ref()
|
|
.map_or(f64::INFINITY, |s| s.complexity);
|
|
let mut best_variance = state
|
|
.best_solution
|
|
.as_ref()
|
|
.map_or(f64::INFINITY, |s| s.noise_max);
|
|
|
|
let pareto_blind_rotate = caches
|
|
.blind_rotate
|
|
.pareto_quantities(glwe_params, internal_dim);
|
|
|
|
let pareto_keyswitch = caches
|
|
.keyswitch
|
|
.pareto_quantities(glwe_params, internal_dim);
|
|
|
|
let pp_switch = caches
|
|
.pp_switch
|
|
.pareto_quantities(glwe_params, internal_dim);
|
|
|
|
let pareto_cb = caches.cb_pbs.pareto_quantities(glwe_params);
|
|
|
|
let lower_bound_variance_blind_rotate = pareto_blind_rotate.last().unwrap().noise;
|
|
let lower_bound_variance_keyswitch = pareto_keyswitch.last().unwrap().noise;
|
|
let lower_bound_variance_private_packing = pp_switch.last().unwrap().noise;
|
|
let lower_pareto_cb_bias = pareto_cb
|
|
.iter()
|
|
.map(|cb| cb.variance_bias)
|
|
.reduce(f64::min)
|
|
.unwrap();
|
|
let lower_pareto_cb_slope = pareto_cb
|
|
.iter()
|
|
.map(|cb| cb.variance_ggsw_factor)
|
|
.reduce(f64::min)
|
|
.unwrap();
|
|
let lower_bound_cost_blind_rotate = pareto_blind_rotate[0].complexity;
|
|
let lower_bound_cost_keyswitch = pareto_keyswitch[0].complexity;
|
|
let lower_bound_cost_pp = pp_switch[0].complexity;
|
|
let lower_bound_cost_cb_complexity_1_cmux_hp = pareto_cb
|
|
.iter()
|
|
.map(|cb| cb.complexity_one_cmux_hp)
|
|
.reduce(f64::min)
|
|
.unwrap_or(0.0);
|
|
let lower_bound_cost_cb_complexity_1_ggsw_to_fft = pareto_cb
|
|
.iter()
|
|
.map(|cb| cb.complexity_one_ggsw_to_fft)
|
|
.reduce(f64::min)
|
|
.unwrap_or(0.0);
|
|
let lower_bound_cb = CbComplexityNoise {
|
|
decomp: BrDecompositionParameters {
|
|
level: 1,
|
|
log2_base: 1,
|
|
},
|
|
complexity_one_cmux_hp: lower_bound_cost_cb_complexity_1_cmux_hp,
|
|
complexity_one_ggsw_to_fft: lower_bound_cost_cb_complexity_1_ggsw_to_fft,
|
|
variance_bias: lower_pareto_cb_bias,
|
|
variance_ggsw_factor: lower_pareto_cb_slope,
|
|
};
|
|
|
|
let variance = |br_variance: Option<_>,
|
|
pp_variance: Option<_>,
|
|
cb_decomp: Option<&CbComplexityNoise>,
|
|
ks_variance: Option<_>| {
|
|
let br_variance = br_variance.unwrap_or(lower_bound_variance_blind_rotate);
|
|
let pp_variance = pp_variance.unwrap_or(lower_bound_variance_private_packing);
|
|
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
|
|
let ks_variance = ks_variance.unwrap_or(lower_bound_variance_keyswitch);
|
|
estimate_variance(
|
|
br_variance,
|
|
pp_variance,
|
|
cb_decomp,
|
|
ks_variance,
|
|
variance_modulus_switching,
|
|
log_norm,
|
|
precisions_sum,
|
|
max_precision,
|
|
)
|
|
};
|
|
|
|
let lower_bound_variance = variance(None, None, None, None);
|
|
if lower_bound_variance > consts.safe_variance {
|
|
// saves 20%
|
|
return;
|
|
}
|
|
|
|
let complexity = |br_cost: Option<_>,
|
|
pp_cost: Option<_>,
|
|
cb_decomp: Option<&CbComplexityNoise>,
|
|
ks_cost: Option<_>| {
|
|
// Pbs dans BitExtract et Circuit BS et FP-KS (partagés)
|
|
let br_cost = br_cost.unwrap_or(lower_bound_cost_blind_rotate);
|
|
let ks_cost = ks_cost.unwrap_or(lower_bound_cost_keyswitch);
|
|
let pp_cost = pp_cost.unwrap_or(lower_bound_cost_pp);
|
|
let cb_decomp = cb_decomp.unwrap_or(&lower_bound_cb);
|
|
estimate_complexity(
|
|
&glwe_params,
|
|
br_cost,
|
|
pp_cost,
|
|
cb_decomp,
|
|
ks_cost,
|
|
precisions_sum,
|
|
nb_blocks,
|
|
n_functions,
|
|
)
|
|
};
|
|
|
|
// BlindRotate dans Circuit BS
|
|
for (br_i, shared_br_decomp) in pareto_blind_rotate.iter().enumerate() {
|
|
let lower_bound_variance = variance(Some(shared_br_decomp.noise), None, None, None);
|
|
|
|
if lower_bound_variance > consts.safe_variance {
|
|
// saves 20%
|
|
continue;
|
|
}
|
|
|
|
// Circuit Boostrap
|
|
// private packing keyswitch, <=> FP-KS
|
|
let pp_switching_index = br_i;
|
|
let pp_switching = pp_switch[pp_switching_index];
|
|
|
|
let lower_bound_variance = variance(
|
|
Some(shared_br_decomp.noise),
|
|
Some(pp_switching.noise),
|
|
None,
|
|
None,
|
|
);
|
|
if lower_bound_variance > safe_variance_bound {
|
|
continue;
|
|
}
|
|
let lower_bound_complexity = complexity(
|
|
Some(shared_br_decomp.complexity),
|
|
Some(pp_switching.complexity),
|
|
None,
|
|
None,
|
|
);
|
|
if lower_bound_complexity > best_complexity {
|
|
// saves ?? TODO
|
|
// next br_decomp are at least as costly
|
|
break;
|
|
}
|
|
|
|
// CircuitBootstrap: new parameters l,b
|
|
// for &circuit_pbs_decomposition in pareto_circuit_pbs {
|
|
for cb_decomp in pareto_cb {
|
|
let lower_bound_variance = variance(
|
|
Some(shared_br_decomp.noise),
|
|
Some(pp_switching.noise),
|
|
Some(cb_decomp),
|
|
None,
|
|
);
|
|
if lower_bound_variance > safe_variance_bound {
|
|
continue;
|
|
}
|
|
let lower_bound_complexity = complexity(
|
|
Some(shared_br_decomp.complexity),
|
|
Some(pp_switching.complexity),
|
|
Some(cb_decomp),
|
|
None,
|
|
);
|
|
if lower_bound_complexity > best_complexity {
|
|
// saves 50%
|
|
// next circuit_pbs_decomposition_parameter are at least as costly
|
|
break;
|
|
}
|
|
// Shared by all pbs (like brs)
|
|
for ks_decomp in pareto_keyswitch.iter().rev() {
|
|
let variance_max = variance(
|
|
Some(shared_br_decomp.noise),
|
|
Some(pp_switching.noise),
|
|
Some(cb_decomp),
|
|
Some(ks_decomp.noise),
|
|
);
|
|
if variance_max > safe_variance_bound {
|
|
// saves 40%
|
|
break;
|
|
}
|
|
|
|
let complexity = complexity(
|
|
Some(shared_br_decomp.complexity),
|
|
Some(pp_switching.complexity),
|
|
Some(cb_decomp),
|
|
Some(ks_decomp.complexity),
|
|
);
|
|
|
|
if complexity > best_complexity {
|
|
continue;
|
|
}
|
|
|
|
#[allow(clippy::float_cmp)]
|
|
if complexity == best_complexity && variance_max > best_variance {
|
|
continue;
|
|
}
|
|
|
|
let kappa = consts.kappa;
|
|
best_complexity = complexity;
|
|
best_variance = variance_max;
|
|
let p_error = find_p_error(kappa, safe_variance_bound, variance_max);
|
|
state.best_solution = Some(Solution {
|
|
input_lwe_dimension: glwe_params.sample_extract_lwe_dimension(),
|
|
internal_ks_output_lwe_dimension: internal_dim,
|
|
ks_decomposition_level_count: ks_decomp.decomp.level,
|
|
ks_decomposition_base_log: ks_decomp.decomp.log2_base,
|
|
glwe_polynomial_size: glwe_params.polynomial_size(),
|
|
glwe_dimension: glwe_params.glwe_dimension,
|
|
br_decomposition_level_count: shared_br_decomp.decomp.level,
|
|
br_decomposition_base_log: shared_br_decomp.decomp.log2_base,
|
|
noise_max: variance_max,
|
|
complexity,
|
|
p_error,
|
|
global_p_error: f64::NAN,
|
|
cb_decomposition_level_count: cb_decomp.decomp.level,
|
|
cb_decomposition_base_log: cb_decomp.decomp.log2_base,
|
|
crt_decomposition: vec![],
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
fn optimize_raw(
|
|
log_norm: f64, // ?? norm2 of noise multisum, complexity of multisum is neglected
|
|
config: Config,
|
|
search_space: &SearchSpace,
|
|
n_functions: u64, // Many functions at the same time, stay at 1 for start
|
|
partitionning: &[u64],
|
|
persistent_caches: &PersistDecompCaches,
|
|
) -> OptimizationState {
|
|
assert!(0.0 < config.maximum_acceptable_error_probability);
|
|
assert!(config.maximum_acceptable_error_probability < 1.0);
|
|
assert!(!partitionning.is_empty());
|
|
|
|
let ciphertext_modulus_log = config.ciphertext_modulus_log;
|
|
|
|
// Circuit BS bound
|
|
// 1 bit of message only here =)
|
|
// Bound for first bit extract in BitExtract (dominate others)
|
|
let max_block_precision = *partitionning.iter().max().unwrap();
|
|
let safe_variance_bound = safe_variance_bound_product_1padbit(
|
|
max_block_precision,
|
|
ciphertext_modulus_log,
|
|
config.maximum_acceptable_error_probability,
|
|
);
|
|
let kappa: f64 = sigma_scale_of_error_probability(config.maximum_acceptable_error_probability);
|
|
|
|
let mut state = OptimizationState {
|
|
best_solution: None,
|
|
};
|
|
|
|
let consts = OptimizationDecompositionsConsts {
|
|
config,
|
|
kappa,
|
|
sum_size: 1, // Ignored
|
|
noise_factor: log_norm.exp2(),
|
|
safe_variance: safe_variance_bound,
|
|
};
|
|
|
|
let mut caches = persistent_caches.caches();
|
|
|
|
for &glwe_dim in &search_space.glwe_dimensions {
|
|
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
|
|
let input_lwe_dimension = glwe_dim << glwe_log_poly_size;
|
|
// Manual experimental CUT
|
|
if input_lwe_dimension > 1 << 13 {
|
|
continue;
|
|
}
|
|
|
|
let glwe_params = GlweParameters {
|
|
log2_polynomial_size: glwe_log_poly_size,
|
|
glwe_dimension: glwe_dim,
|
|
};
|
|
|
|
for &internal_dim in &search_space.internal_lwe_dimensions {
|
|
update_state_with_best_decompositions(
|
|
&mut state,
|
|
&consts,
|
|
glwe_params,
|
|
internal_dim,
|
|
n_functions,
|
|
partitionning,
|
|
&mut caches,
|
|
);
|
|
}
|
|
}
|
|
}
|
|
persistent_caches.backport(caches);
|
|
|
|
state
|
|
}
|
|
|
|
pub fn optimize_one(
|
|
precision: u64,
|
|
config: Config,
|
|
log_norm: f64,
|
|
search_space: &SearchSpace,
|
|
caches: &PersistDecompCaches,
|
|
) -> OptimizationState {
|
|
let coprimes = crt_decomposition::default_coprimes(precision as Precision);
|
|
let partitionning = crt_decomposition::precisions_from_coprimes(&coprimes);
|
|
let n_functions = 1;
|
|
let mut state = optimize_raw(
|
|
log_norm,
|
|
config,
|
|
search_space,
|
|
n_functions,
|
|
&partitionning,
|
|
caches,
|
|
);
|
|
state.best_solution = state.best_solution.map(|mut sol| -> Solution {
|
|
sol.crt_decomposition = coprimes;
|
|
sol
|
|
});
|
|
state
|
|
}
|
|
|
|
pub fn optimize_one_compat(
|
|
_sum_size: u64,
|
|
precision: u64,
|
|
config: Config,
|
|
noise_factor: f64,
|
|
search_space: &SearchSpace,
|
|
cache: &PersistDecompCaches,
|
|
) -> atomic_pattern::OptimizationState {
|
|
let log_norm = noise_factor.log2();
|
|
let result = optimize_one(precision, config, log_norm, search_space, cache);
|
|
atomic_pattern::OptimizationState {
|
|
best_solution: result.best_solution.map(Solution::into),
|
|
}
|
|
}
|