fix: dag with 1 tlu layer were sub-optimized assuming multi-layer in cuts

This commit is contained in:
rudy
2022-08-16 18:43:20 +02:00
committed by rudy-6-4
parent 915cbe6647
commit 320e3c1963
2 changed files with 103 additions and 20 deletions

View File

@@ -96,6 +96,8 @@ pub struct OperationDag {
// Collect all operators ouput variances
pub out_variances: Vec<SymbolicVariance>,
pub nb_luts: u64,
// True if all luts have noise with origin VarianceOrigin::Input
pub has_only_luts_with_inputs: bool,
// The full dag levelled complexity
pub levelled_complexity: LevelledComplexity,
// Dominating variances and bounds per precision
@@ -468,6 +470,9 @@ pub fn analyze(
&in_luts_variance,
noise_config,
);
let has_only_luts_with_inputs = in_luts_variance
.iter()
.all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input);
let result = OperationDag {
operators: dag.operators.clone(),
out_shapes,
@@ -476,6 +481,7 @@ pub fn analyze(
nb_luts,
levelled_complexity,
constraints_by_precisions,
has_only_luts_with_inputs,
};
assert_properties_correctness(&result);
result
@@ -911,4 +917,24 @@ mod tests {
prev_safe_noise_bound = ns.safe_variance_bound;
}
}
#[test]
fn test_1_layer_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let _lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
let _lut2 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
let analysis = analyze(&graph);
assert!(analysis.has_only_luts_with_inputs);
}
#[test]
fn test_2_layer_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
let _lut2 = graph.add_lut(lut1, FunctionTable::UNKWOWN, 1);
let analysis = analyze(&graph);
assert!(!analysis.has_only_luts_with_inputs);
}
}

View File

@@ -35,31 +35,64 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
let mut cut_complexity =
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64);
let mut cut_noise = safe_variance - noise_modulus_switching;
let input_noise_out = minimal_variance(
glwe_params,
consts.config.ciphertext_modulus_log,
consts.config.security_level,
)
.get_variance();
if dag.nb_luts == 0 {
cut_noise = f64::INFINITY;
cut_complexity = f64::INFINITY;
let no_luts = dag.nb_luts == 0;
// if no_luts we disable cuts, any parameters is acceptable in luts
let (cut_noise, cut_complexity) = if no_luts && CUTS {
(f64::INFINITY, f64::INFINITY)
} else {
(
safe_variance - noise_modulus_switching,
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0))
/ (dag.nb_luts as f64),
)
};
if input_noise_out > cut_noise {
// exact cut when has_only_luts_with_inputs, lower bound cut otherwise
return;
}
let br_pareto =
pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
// if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
// since it's never use inside the lut
let br_cut_noise = if dag.has_only_luts_with_inputs {
f64::INFINITY
} else {
cut_noise
};
let br_cut_complexity = cut_complexity;
let br_pareto = pareto_blind_rotate::<W>(
consts,
internal_dim,
glwe_params,
br_cut_complexity,
br_cut_noise,
);
if br_pareto.is_empty() {
return;
}
if PARETO_CUTS {
cut_noise -= br_pareto[br_pareto.len() - 1].noise;
cut_complexity -= br_pareto[0].complexity;
}
let worst_input_ks_noise = if dag.has_only_luts_with_inputs {
input_noise_out
} else {
br_pareto.last().unwrap().noise
};
let ks_cut_noise = cut_noise - worst_input_ks_noise;
let ks_cut_complexity = cut_complexity - br_pareto[0].complexity;
let ks_pareto = pareto_keyswitch::<W>(
consts,
input_lwe_dimension,
internal_dim,
cut_complexity,
cut_noise,
ks_cut_complexity,
ks_cut_noise,
);
if ks_pareto.is_empty() {
return;
@@ -67,12 +100,6 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
let i_max_ks = ks_pareto.len() - 1;
let mut i_current_max_ks = i_max_ks;
let input_noise_out = minimal_variance(
glwe_params,
consts.config.ciphertext_modulus_log,
consts.config.security_level,
)
.get_variance();
let mut best_br_noise = f64::INFINITY;
let mut best_ks_noise = f64::INFINITY;
@@ -571,6 +598,36 @@ mod tests {
}
}
fn lut_1_layer_has_better_complexity(precision: Precision) {
let dag_1_layer = {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, Shape::number());
let _lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
let _lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
dag
};
let dag_2_layer = {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, Shape::number());
let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
let _lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision);
dag
};
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
}
#[test]
fn test_lut_1_layer_is_better() {
// for some reason on 4, 5, 6, the complexity is already minimal
// this could be due to pre-defined pareto set
for precision in [1, 2, 3, 7, 8] {
lut_1_layer_has_better_complexity(precision);
}
}
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) {
let input = dag.add_input(precision, Shape::number());
let dot1 = dag.add_dot([input], [weight]);