Mirror of https://github.com/zama-ai/concrete.git, synced 2026-02-08 19:44:57 -05:00
chore: clarify no luts optimization and cuts
@@ -655,11 +655,15 @@ impl OperationDag
         true
     }
 
-    pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
+    pub fn complexity(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
         let luts_cost = one_lut_cost * (self.nb_luts as f64);
         let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
         luts_cost + levelled_cost
     }
+
+    pub fn levelled_complexity(&self, input_lwe_dimension: u64) -> f64 {
+        self.levelled_complexity.cost(input_lwe_dimension)
+    }
 }
 
 #[cfg(test)]
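
The rename makes the two cost accessors explicit: `complexity` returns the full cost (levelled part plus one fixed cost per LUT), while the new `levelled_complexity` exposes only the LUT-free part, which the no-LUT path added later in this commit relies on. An illustrative sketch of the cost model, with a simplified stand-in for the crate's `LevelledComplexity` (the linear model and field name here are assumptions, not the real type):

// Simplified stand-in for the optimizer's levelled cost.
struct LevelledComplexity {
    cost_per_lwe_dim: f64,
}

impl LevelledComplexity {
    fn cost(&self, input_lwe_dimension: u64) -> f64 {
        self.cost_per_lwe_dim * input_lwe_dimension as f64
    }
}

struct OperationDag {
    nb_luts: u64,
    levelled_complexity: LevelledComplexity,
}

impl OperationDag {
    // Total cost = levelled cost + nb_luts * cost of a single LUT evaluation.
    fn complexity(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
        let luts_cost = one_lut_cost * (self.nb_luts as f64);
        luts_cost + self.levelled_complexity.cost(input_lwe_dimension)
    }
}

fn main() {
    let dag = OperationDag {
        nb_luts: 1,
        levelled_complexity: LevelledComplexity { cost_per_lwe_dim: 1.0 },
    };
    // 1024 levelled + 100 for the single LUT.
    assert_eq!(dag.complexity(1024, 100.0), 1124.0);
}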
@@ -701,7 +705,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
         assert_eq!(analysis.out_shapes[input1.i], Shape::number());
@@ -724,7 +728,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
         assert!(analysis.out_shapes[lut1.i] == Shape::number());
@@ -750,7 +754,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         let expected_var = SymbolicVariance {
             input_coeff: norm2,
@@ -781,7 +785,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         assert!(analysis.out_variances[dot.i].origin() == VO::Input);
         assert_eq!(analysis.out_precisions[dot.i], 3);
@@ -810,7 +814,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         let expected_var_dot1 = SymbolicVariance {
             input_coeff: weights.square_norm2() as f64,
@@ -858,7 +862,7 @@ mod tests {
         let analysis = analyze(&graph);
         let one_lut_cost = 100.0;
         let lwe_dim = 1024;
-        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+        let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost);
 
         let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost;
         assert_f64_eq(expected_cost, complexity_cost);
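
The last test pins the cost model down numerically: a graph with two LUTs and a levelled complexity of `2 * lwe_dim` must cost `2 * lwe_dim + 2 * one_lut_cost`. With the test's values this is easy to check by hand (a sketch reusing the test's constants):

fn main() {
    let lwe_dim: u64 = 1024;
    let one_lut_cost = 100.0;
    let nb_luts = 2.0;
    let complexity_cost = (2 * lwe_dim) as f64 + nb_luts * one_lut_cost;
    assert_eq!(complexity_cost, 2248.0); // 2048 levelled + 200 for the two LUTs
}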
@@ -7,16 +7,12 @@ use crate::optimization::atomic_pattern::{
     Caches, OptimizationDecompositionsConsts, OptimizationState, Solution,
 };
 use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
-use crate::optimization::decomposition::{blind_rotate, cut_complexity_noise, keyswitch};
+use crate::optimization::decomposition::{blind_rotate, keyswitch};
 use crate::parameters::GlweParameters;
-use crate::security::glwe::minimal_variance;
+use crate::security;
 
 use concrete_commons::dispersion::DispersionParameter;
 
-const CUTS: bool = true;
-const PARETO_CUTS: bool = true;
-const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true;
-
 #[allow(clippy::too_many_lines)]
 fn update_best_solution_with_best_decompositions(
     state: &mut OptimizationState,
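
The three deleted booleans were debug switches that let the pruning ("cuts") be compared against an exhaustive scan; since they were permanently `true`, the commit hard-wires that choice and the dead branches disappear. The dropped `cut_complexity_noise` helper pre-filtered a Pareto front against a noise bound and a complexity bound before it was scanned. A sketch of what such a filter computes (illustrative only, not the crate's actual signature):

#[derive(Clone, Copy)]
struct Quantity {
    complexity: f64,
    noise: f64,
}

// Keep only candidates that could still beat the current best solution:
// cheap enough and quiet enough.
fn cut_complexity_noise(max_complexity: f64, max_noise: f64, pareto: &[Quantity]) -> Vec<Quantity> {
    pareto
        .iter()
        .copied()
        .filter(|q| q.complexity <= max_complexity && q.noise <= max_noise)
        .collect()
}

In the new code this pre-filtering is gone: the same bounds are enforced inside the scan loops, where `best_complexity` keeps improving as solutions are found.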
@@ -24,10 +20,11 @@ fn update_best_solution_with_best_decompositions(
     dag: &analyze::OperationDag,
     internal_dim: u64,
     glwe_params: GlweParameters,
+    input_noise_out: f64,
     noise_modulus_switching: f64,
     caches: &mut Caches,
 ) {
     let safe_variance = consts.safe_variance;
+    assert!(dag.nb_luts > 0);
     let glwe_poly_size = glwe_params.polynomial_size();
     let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
@@ -35,134 +32,52 @@ fn update_best_solution_with_best_decompositions(
     let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
     let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
 
-    let input_noise_out = minimal_variance(
-        glwe_params,
-        consts.config.ciphertext_modulus_log,
-        consts.config.security_level,
-    )
-    .get_variance();
-
-    let no_luts = dag.nb_luts == 0;
-    // if no_luts we disable cuts, any parameters is acceptable in luts
-    let (cut_noise, cut_complexity) = if no_luts && CUTS {
-        (f64::INFINITY, f64::INFINITY)
-    } else {
-        (
-            safe_variance - noise_modulus_switching,
-            (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0))
-                / (dag.nb_luts as f64),
-        )
-    };
-
-    if input_noise_out > cut_noise {
-        // exact cut when has_only_luts_with_inputs, lower bound cut otherwise
-        return;
-    }
-
-    // if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
-    // since it's never use inside the lut
-    let br_cut_noise = if dag.has_only_luts_with_inputs {
-        f64::INFINITY
-    } else {
-        cut_noise
-    };
-    let br_cut_complexity = cut_complexity;
-
     let br_pareto = caches
         .blind_rotate
         .pareto_quantities(glwe_params, internal_dim);
-    // if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
-    // since this noise is never used inside a lut
-    let br_pareto = if dag.has_only_luts_with_inputs {
-        br_pareto
-    } else {
-        cut_complexity_noise(br_cut_complexity, br_cut_noise, br_pareto)
-    };
-
-    if br_pareto.is_empty() {
-        return;
-    }
-
-    let worst_input_ks_noise = if dag.has_only_luts_with_inputs {
-        input_noise_out
-    } else {
-        br_pareto.last().unwrap().noise
-    };
-    let ks_cut_noise = cut_noise - worst_input_ks_noise;
-    let ks_cut_complexity = cut_complexity - br_pareto[0].complexity;
 
     let ks_pareto = caches
         .keyswitch
         .pareto_quantities(glwe_params, internal_dim);
-    let ks_pareto = cut_complexity_noise(ks_cut_complexity, ks_cut_noise, ks_pareto);
-
-    if ks_pareto.is_empty() {
-        return;
-    }
-
-    let i_max_ks = ks_pareto.len() - 1;
-    let mut i_current_max_ks = i_max_ks;
-    let input_noise_out = minimal_variance(
-        glwe_params,
-        consts.config.ciphertext_modulus_log,
-        consts.config.security_level,
-    )
-    .get_variance();
 
+    // by construction br_pareto and ks_pareto are non-empty
     let mut best_br = br_pareto[0];
     let mut best_ks = ks_pareto[0];
     let mut update_best_solution = false;
 
     for &br_quantity in br_pareto {
         // increasing complexity, decreasing variance
+        let one_lut_cost = br_quantity.complexity;
+        let complexity = dag.complexity(input_lwe_dimension, one_lut_cost);
+        if complexity > best_complexity {
+            // Since br_pareto is scanned by increasing complexity, we can stop
+            break;
+        }
         let not_feasible = !dag.feasible(
             input_noise_out,
             br_quantity.noise,
             0.0,
             noise_modulus_switching,
         );
-        if not_feasible && CUTS {
+        if not_feasible {
             continue;
         }
-        let one_lut_cost = br_quantity.complexity;
-        let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
-        if complexity > best_complexity {
-            // As best can evolves it is complementary to blind_rotate_quantities cuts.
-            if PARETO_CUTS {
-                break;
-            } else if CUTS {
-                continue;
-            }
-        }
-        for i_ks_pareto in (0..=i_current_max_ks).rev() {
+        for &ks_quantity in ks_pareto.iter().rev() {
             // increasing variance, decreasing complexity
-            let ks_quantity = ks_pareto[i_ks_pareto];
+            let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
+            let complexity = dag.complexity(input_lwe_dimension, one_lut_cost);
+            let worse_complexity = complexity > best_complexity;
+            if worse_complexity {
+                continue;
+            }
             let not_feasible = !dag.feasible(
                 input_noise_out,
                 br_quantity.noise,
                 ks_quantity.noise,
                 noise_modulus_switching,
             );
-            // let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching;
             if not_feasible {
-                if CROSS_PARETO_CUTS {
-                    // the pareto of 2 added pareto is scanned linearly
-                    // but with all cuts, pre-computing => no gain
-                    i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks);
-                    break;
-                    // it's compatible with next i_br but with the worst complexity
-                } else if PARETO_CUTS {
-                    // increasing variance => we can skip all remaining
-                    break;
-                }
-                continue;
+                // Since ks_pareto is scanned by increasing noise, we can stop
+                break;
             }
-
-            let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
-            let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
-            let worse_complexity = complexity > best_complexity;
-            if worse_complexity {
-                continue;
-            }
 
             let (peek_p_error, variance) = dag.peek_p_error(
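
The restructured scan is a bi-objective prune over two Pareto fronts: blind-rotate candidates are visited by increasing complexity, so the first one that already exceeds the best known cost ends the outer loop; key-switch candidates are visited by increasing noise (reverse complexity order), so the first infeasible one ends the inner loop. A self-contained sketch of the pattern, with made-up `Quantity` values and a flat noise budget rather than the optimizer's real feasibility test:

#[derive(Clone, Copy, Debug)]
struct Quantity {
    complexity: f64,
    noise: f64,
}

// `br` is sorted by increasing complexity (hence decreasing noise);
// `ks` is sorted the same way and scanned in reverse, i.e. by increasing noise.
fn best_pair(br: &[Quantity], ks: &[Quantity], noise_budget: f64) -> Option<(Quantity, Quantity)> {
    let mut best_cost = f64::INFINITY;
    let mut best = None;
    for &b in br {
        if b.complexity >= best_cost {
            break; // increasing complexity: nothing cheaper follows
        }
        for &k in ks.iter().rev() {
            let cost = b.complexity + k.complexity;
            if cost >= best_cost {
                continue; // decreasing complexity: a later pair may still be cheaper
            }
            if b.noise + k.noise > noise_budget {
                break; // increasing noise: all remaining pairs are noisier
            }
            best_cost = cost;
            best = Some((b, k));
        }
    }
    best
}

The real loop additionally evaluates `dag.peek_p_error` on each surviving pair before committing a `Solution`, which is what the `peek_p_error` call at the end of the hunk does.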
@@ -215,6 +130,103 @@ fn update_best_solution_with_best_decompositions(
 
 const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
 
+fn update_no_luts_solution(
+    state: &mut OptimizationState,
+    consts: &OptimizationDecompositionsConsts,
+    dag: &analyze::OperationDag,
+    glwe_params: GlweParameters,
+    input_noise_out: f64,
+) {
+    const CHECKED_IGNORED_NOISE: f64 = f64::MAX;
+    const UNDEFINED_PARAM: u64 = 0;
+
+    let input_lwe_dimension = glwe_params.sample_extract_lwe_dimension();
+
+    let best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
+    let best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
+
+    let complexity = if dag.levelled_complexity == LevelledComplexity::ZERO {
+        // The compiler has given a 0 levelled complexity.
+        // There is no way to compare solutions.
+        // Assuming linear complexity.
+        input_lwe_dimension as f64
+    } else {
+        dag.levelled_complexity(input_lwe_dimension)
+    };
+
+    if complexity > best_complexity {
+        return;
+    }
+
+    let (p_error, variance) = dag.peek_p_error(
+        input_noise_out,
+        CHECKED_IGNORED_NOISE,
+        CHECKED_IGNORED_NOISE,
+        CHECKED_IGNORED_NOISE,
+        consts.kappa,
+    );
+
+    #[allow(clippy::float_cmp)]
+    let same_complexity_no_few_errors = complexity == best_complexity && p_error >= best_p_error;
+    if same_complexity_no_few_errors {
+        return;
+    }
+    // The complexity is either better or equivalent with less errors
+    state.best_solution = Some(Solution {
+        input_lwe_dimension,
+        internal_ks_output_lwe_dimension: UNDEFINED_PARAM,
+        ks_decomposition_level_count: UNDEFINED_PARAM,
+        ks_decomposition_base_log: UNDEFINED_PARAM,
+        glwe_polynomial_size: glwe_params.polynomial_size(),
+        glwe_dimension: glwe_params.glwe_dimension,
+        br_decomposition_level_count: UNDEFINED_PARAM,
+        br_decomposition_base_log: UNDEFINED_PARAM,
+        complexity,
+        p_error,
+        global_p_error: dag.global_p_error(
+            input_noise_out,
+            CHECKED_IGNORED_NOISE,
+            CHECKED_IGNORED_NOISE,
+            CHECKED_IGNORED_NOISE,
+            consts.kappa,
+        ),
+        noise_max: variance,
+    });
+}
+
+fn minimal_variance(config: &Config, glwe_params: GlweParameters) -> f64 {
+    security::glwe::minimal_variance(
+        glwe_params,
+        config.ciphertext_modulus_log,
+        config.security_level,
+    )
+    .get_variance()
+}
+
+fn optimize_no_luts(
+    mut state: OptimizationState,
+    consts: &OptimizationDecompositionsConsts,
+    dag: &analyze::OperationDag,
+    search_space: &SearchSpace,
+) -> OptimizationState {
+    let not_feasible = |input_noise_out| !dag.feasible(input_noise_out, 0.0, 0.0, 0.0);
+    for &glwe_dim in &search_space.glwe_dimensions {
+        for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
+            let glwe_params = GlweParameters {
+                log2_polynomial_size: glwe_log_poly_size,
+                glwe_dimension: glwe_dim,
+            };
+            let input_noise_out = minimal_variance(&consts.config, glwe_params);
+            if not_feasible(input_noise_out) {
+                continue;
+            }
+            update_no_luts_solution(&mut state, consts, dag, glwe_params, input_noise_out);
+        }
+    }
+    state
+}
+
 #[allow(clippy::too_many_lines)]
 pub fn optimize(
     dag: &unparametrized::OperationDag,
     config: Config,
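
With no LUT in the DAG there is no key switch or blind rotate to parameterize, which is why those `Solution` fields are filled with `UNDEFINED_PARAM` and only the GLWE parameters are searched. The selection rule is: strictly smaller complexity wins, and at equal complexity a strictly smaller error probability wins. A toy sketch of that rule with simplified types (not the crate's `Solution`):

#[derive(Clone, Copy, Debug)]
struct Candidate {
    complexity: f64,
    p_error: f64,
}

// Keep a candidate only if it is strictly cheaper, or equally cheap
// with strictly fewer errors; this mirrors update_no_luts_solution's rule.
#[allow(clippy::float_cmp)]
fn keep_better(best: &mut Option<Candidate>, new: Candidate) {
    let (best_complexity, best_p_error) =
        best.map_or((f64::INFINITY, f64::INFINITY), |b| (b.complexity, b.p_error));
    if new.complexity > best_complexity {
        return;
    }
    if new.complexity == best_complexity && new.p_error >= best_p_error {
        return;
    }
    *best = Some(new);
}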
@@ -251,6 +263,10 @@ pub fn optimize(
         best_solution: None,
     };
 
+    if dag.nb_luts == 0 {
+        return optimize_no_luts(state, &consts, &dag, search_space);
+    }
+
     let mut caches = Caches {
         blind_rotate: blind_rotate::for_security(security_level).cache(),
         keyswitch: keyswitch::for_security(security_level).cache(),
@@ -264,20 +280,23 @@ pub fn optimize(
         )
         .get_variance()
     };
-    let not_feasible =
-        |noise_modulus_switching| !dag.feasible(0.0, 0.0, 0.0, noise_modulus_switching);
+
+    let not_feasible = |input_noise_out, noise_modulus_switching| {
+        !dag.feasible(input_noise_out, 0.0, 0.0, noise_modulus_switching)
+    };
 
     for &glwe_dim in &search_space.glwe_dimensions {
         for &glwe_log_poly_size in &search_space.glwe_log_polynomial_sizes {
+            let glwe_poly_size = 1 << glwe_log_poly_size;
             let glwe_params = GlweParameters {
                 log2_polynomial_size: glwe_log_poly_size,
                 glwe_dimension: glwe_dim,
             };
+            let input_noise_out = minimal_variance(&config, glwe_params);
             for &internal_dim in &search_space.internal_lwe_dimensions {
-                let glwe_poly_size = 1 << glwe_log_poly_size;
                 let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim);
-                if CUTS && not_feasible(noise_modulus_switching) {
-                    // assume this noise is increasing with internal_dim
+                if not_feasible(input_noise_out, noise_modulus_switching) {
+                    // noise_modulus_switching is increasing with internal_dim
                     break;
                 }
                 update_best_solution_with_best_decompositions(
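
Using `break` rather than `continue` on the first infeasible `internal_dim` is sound only because the modulus-switching noise grows monotonically with `internal_dim`, as the updated comment states. A minimal illustration of this monotone pruning, with a made-up noise function (not the crate's model):

// Toy monotone noise model: noise grows with the internal LWE dimension.
fn noise_modulus_switching(glwe_poly_size: u64, internal_dim: u64) -> f64 {
    internal_dim as f64 / (glwe_poly_size as f64 * 1e6)
}

fn main() {
    let noise_budget = 4e-7;
    let glwe_poly_size = 1 << 11;
    let internal_dims = [512u64, 640, 768, 896, 1024];
    for &internal_dim in &internal_dims {
        let noise = noise_modulus_switching(glwe_poly_size, internal_dim);
        if noise > noise_budget {
            // Monotonicity: every larger internal_dim is also infeasible.
            break;
        }
        println!("internal_dim {internal_dim} is feasible (noise {noise:e})");
    }
}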
@@ -286,6 +305,7 @@ pub fn optimize(
                     &dag,
                     internal_dim,
                     glwe_params,
+                    input_noise_out,
                     noise_modulus_switching,
                     &mut caches,
                 );