chore: clarify no luts optimization and cuts

This commit is contained in:
rudy
2022-08-31 17:54:26 +02:00
parent e2fa88aec2
commit 48e43c5762
2 changed files with 140 additions and 116 deletions

View File

@@ -655,11 +655,15 @@ impl OperationDag {
true
}
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
pub fn complexity(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
let luts_cost = one_lut_cost * (self.nb_luts as f64);
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
luts_cost + levelled_cost
}
pub fn levelled_complexity(&self, input_lwe_dimension: u64) -> f64 {
self.levelled_complexity.cost(input_lwe_dimension)
}
}
#[cfg(test)]
@@ -701,7 +705,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
assert_eq!(analysis.out_shapes[input1.i], Shape::number());
@@ -724,7 +728,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
assert!(analysis.out_shapes[lut1.i] == Shape::number());
@@ -750,7 +754,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
let expected_var = SymbolicVariance {
input_coeff: norm2,
@@ -781,7 +785,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[dot.i].origin() == VO::Input);
assert_eq!(analysis.out_precisions[dot.i], 3);
@@ -810,7 +814,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
let expected_var_dot1 = SymbolicVariance {
input_coeff: weights.square_norm2() as f64,
@@ -858,7 +862,7 @@ mod tests {
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost;
assert_f64_eq(expected_cost, complexity_cost);

View File

@@ -7,16 +7,12 @@ use crate::optimization::atomic_pattern::{
Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
};
use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
use crate::optimization::decomposition::{blind_rotate, cut_complexity_noise, keyswitch};
use crate::optimization::decomposition::{blind_rotate, keyswitch};
use crate::parameters::GlweParameters;
use crate::security::glwe::minimal_variance;
use crate::security;
use concrete_commons::dispersion::DispersionParameter;
const CUTS: bool = true;
const PARETO_CUTS: bool = true;
const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true;
#[allow(clippy::too_many_lines)]
fn update_best_solution_with_best_decompositions(
state: &mut OptimizationState,
@@ -24,10 +20,11 @@ fn update_best_solution_with_best_decompositions(
dag: &analyze::OperationDag,
internal_dim: u64,
glwe_params: GlweParameters,
input_noise_out: f64,
noise_modulus_switching: f64,
caches: &mut Caches,
) {
let safe_variance = consts.safe_variance;
assert!(dag.nb_luts > 0);
let glwe_poly_size = glwe_params.polynomial_size();
let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
@@ -35,134 +32,52 @@ fn update_best_solution_with_best_decompositions(
let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
let input_noise_out = minimal_variance(
glwe_params,
consts.config.ciphertext_modulus_log,
consts.config.security_level,
)
.get_variance();
let no_luts = dag.nb_luts == 0;
// if no_luts we disable cuts, any parameters are acceptable in luts
let (cut_noise, cut_complexity) = if no_luts && CUTS {
(f64::INFINITY, f64::INFINITY)
} else {
(
safe_variance - noise_modulus_switching,
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0))
/ (dag.nb_luts as f64),
)
};
if input_noise_out > cut_noise {
// exact cut when has_only_luts_with_inputs, lower bound cut otherwise
return;
}
// if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
// since it's never used inside the lut
let br_cut_noise = if dag.has_only_luts_with_inputs {
f64::INFINITY
} else {
cut_noise
};
let br_cut_complexity = cut_complexity;
let br_pareto = caches
.blind_rotate
.pareto_quantities(glwe_params, internal_dim);
// if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
// since this noise is never used inside a lut
let br_pareto = if dag.has_only_luts_with_inputs {
br_pareto
} else {
cut_complexity_noise(br_cut_complexity, br_cut_noise, br_pareto)
};
if br_pareto.is_empty() {
return;
}
let worst_input_ks_noise = if dag.has_only_luts_with_inputs {
input_noise_out
} else {
br_pareto.last().unwrap().noise
};
let ks_cut_noise = cut_noise - worst_input_ks_noise;
let ks_cut_complexity = cut_complexity - br_pareto[0].complexity;
let ks_pareto = caches
.keyswitch
.pareto_quantities(glwe_params, internal_dim);
let ks_pareto = cut_complexity_noise(ks_cut_complexity, ks_cut_noise, ks_pareto);
if ks_pareto.is_empty() {
return;
}
let i_max_ks = ks_pareto.len() - 1;
let mut i_current_max_ks = i_max_ks;
let input_noise_out = minimal_variance(
glwe_params,
consts.config.ciphertext_modulus_log,
consts.config.security_level,
)
.get_variance();
// by construction br_pareto and ks_pareto are non-empty
let mut best_br = br_pareto[0];
let mut best_ks = ks_pareto[0];
let mut update_best_solution = false;
for &br_quantity in br_pareto {
// increasing complexity, decreasing variance
let one_lut_cost = br_quantity.complexity;
let complexity = dag.complexity(input_lwe_dimension, one_lut_cost);
if complexity > best_complexity {
// Since br_pareto is scanned by increasing complexity, we can stop
break;
}
let not_feasible = !dag.feasible(
input_noise_out,
br_quantity.noise,
0.0,
noise_modulus_switching,
);
if not_feasible && CUTS {
if not_feasible {
continue;
}
let one_lut_cost = br_quantity.complexity;
let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
if complexity > best_complexity {
// As best can evolve, this is complementary to blind_rotate_quantities cuts.
if PARETO_CUTS {
break;
} else if CUTS {
for &ks_quantity in ks_pareto.iter().rev() {
let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
let complexity = dag.complexity(input_lwe_dimension, one_lut_cost);
let worse_complexity = complexity > best_complexity;
if worse_complexity {
continue;
}
}
for i_ks_pareto in (0..=i_current_max_ks).rev() {
// increasing variance, decreasing complexity
let ks_quantity = ks_pareto[i_ks_pareto];
let not_feasible = !dag.feasible(
input_noise_out,
br_quantity.noise,
ks_quantity.noise,
noise_modulus_switching,
);
// let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching;
if not_feasible {
if CROSS_PARETO_CUTS {
// the pareto of 2 added pareto is scanned linearly
// but with all cuts, pre-computing => no gain
i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks);
break;
// it's compatible with next i_br but with the worst complexity
} else if PARETO_CUTS {
// increasing variance => we can skip all remaining
break;
}
continue;
}
let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
let worse_complexity = complexity > best_complexity;
if worse_complexity {
continue;
// Since ks_pareto is scanned by increasing noise, we can stop
break;
}
let (peek_p_error, variance) = dag.peek_p_error(
@@ -215,6 +130,103 @@ fn update_best_solution_with_best_decompositions(
const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
fn update_no_luts_solution(
state: &mut OptimizationState,
consts: &OptimizationDecompositionsConsts,
dag: &analyze::OperationDag,
glwe_params: GlweParameters,
input_noise_out: f64,
) {
const CHECKED_IGNORED_NOISE: f64 = f64::MAX;
const UNDEFINED_PARAM: u64 = 0;
let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
let best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
let best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
let complexity = if dag.levelled_complexity == LevelledComplexity::ZERO {
// The compiler has given a 0 levelled complexity.
// There is no way to compare solutions.
// Assuming linear complexity.
input_lwe_dimension as f64
} else {
dag.levelled_complexity(input_lwe_dimension)
};
if complexity > best_complexity {
return;
}
let (p_error, variance) = dag.peek_p_error(
input_noise_out,
CHECKED_IGNORED_NOISE,
CHECKED_IGNORED_NOISE,
CHECKED_IGNORED_NOISE,
consts.kappa,
);
#[allow(clippy::float_cmp)]
let same_complexity_no_few_errors = complexity == best_complexity && p_error >= best_p_error;
if same_complexity_no_few_errors {
return;
}
// The complexity is either better or equivalent with less errors
state.best_solution = Some(Solution {
input_lwe_dimension,
internal_ks_output_lwe_dimension: UNDEFINED_PARAM,
ks_decomposition_level_count: UNDEFINED_PARAM,
ks_decomposition_base_log: UNDEFINED_PARAM,
glwe_polynomial_size: glwe_params.polynomial_size(),
glwe_dimension: glwe_params.glwe_dimension,
br_decomposition_level_count: UNDEFINED_PARAM,
br_decomposition_base_log: UNDEFINED_PARAM,
complexity,
p_error,
global_p_error: dag.global_p_error(
input_noise_out,
CHECKED_IGNORED_NOISE,
CHECKED_IGNORED_NOISE,
CHECKED_IGNORED_NOISE,
consts.kappa,
),
noise_max: variance,
});
}
fn minimal_variance(config: &Config, glwe_params: GlweParameters) -> f64 {
security::glwe::minimal_variance(
glwe_params,
config.ciphertext_modulus_log,
config.security_level,
)
.get_variance()
}
fn optimize_no_luts(
mut state: OptimizationState,
consts: &OptimizationDecompositionsConsts,
dag: &analyze::OperationDag,
search_space: &SearchSpace,
) -> OptimizationState {
let not_feasible = |input_noise_out| !dag.feasible(input_noise_out, 0.0, 0.0, 0.0);
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
let glwe_params = GlweParameters {
log2_polynomial_size: glwe_log_poly_size,
glwe_dimension: glwe_dim,
};
let input_noise_out = minimal_variance(&consts.config, glwe_params);
if not_feasible(input_noise_out) {
continue;
}
update_no_luts_solution(&mut state, consts, dag, glwe_params, input_noise_out);
}
}
state
}
#[allow(clippy::too_many_lines)]
pub fn optimize(
dag: &unparametrized::OperationDag,
config: Config,
@@ -251,6 +263,10 @@ pub fn optimize(
best_solution: None,
};
if dag.nb_luts == 0 {
return optimize_no_luts(state, &consts, &dag, search_space);
}
let mut caches = Caches {
blind_rotate: blind_rotate::for_security(security_level).cache(),
keyswitch: keyswitch::for_security(security_level).cache(),
@@ -264,20 +280,23 @@ pub fn optimize(
)
.get_variance()
};
let not_feasible =
|noise_modulus_switching| !dag.feasible(0.0, 0.0, 0.0, noise_modulus_switching);
let not_feasible = |input_noise_out, noise_modulus_switching| {
!dag.feasible(input_noise_out, 0.0, 0.0, noise_modulus_switching)
};
for &glwe_dim in &search_space.glwe_dimensions {
for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
let glwe_poly_size = 1 << glwe_log_poly_size;
let glwe_params = GlweParameters {
log2_polynomial_size: glwe_log_poly_size,
glwe_dimension: glwe_dim,
};
let input_noise_out = minimal_variance(&config, glwe_params);
for &internal_dim in &search_space.internal_lwe_dimensions {
let glwe_poly_size = 1 << glwe_log_poly_size;
let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim);
if CUTS && not_feasible(noise_modulus_switching) {
// assume this noise is increasing with internal_dim
if not_feasible(input_noise_out, noise_modulus_switching) {
// noise_modulus_switching is increasing with internal_dim
break;
}
update_best_solution_with_best_decompositions(
@@ -286,6 +305,7 @@ pub fn optimize(
&dag,
internal_dim,
glwe_params,
input_noise_out,
noise_modulus_switching,
&mut caches,
);