mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
fix: dag with 1 tlu layer were sub-optimized assuming multi-layer in cuts
This commit is contained in:
@@ -96,6 +96,8 @@ pub struct OperationDag {
|
||||
// Collect all operators ouput variances
|
||||
pub out_variances: Vec<SymbolicVariance>,
|
||||
pub nb_luts: u64,
|
||||
// True if all luts have noise with origin VarianceOrigin::Input
|
||||
pub has_only_luts_with_inputs: bool,
|
||||
// The full dag levelled complexity
|
||||
pub levelled_complexity: LevelledComplexity,
|
||||
// Dominating variances and bounds per precision
|
||||
@@ -468,6 +470,9 @@ pub fn analyze(
|
||||
&in_luts_variance,
|
||||
noise_config,
|
||||
);
|
||||
let has_only_luts_with_inputs = in_luts_variance
|
||||
.iter()
|
||||
.all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input);
|
||||
let result = OperationDag {
|
||||
operators: dag.operators.clone(),
|
||||
out_shapes,
|
||||
@@ -476,6 +481,7 @@ pub fn analyze(
|
||||
nb_luts,
|
||||
levelled_complexity,
|
||||
constraints_by_precisions,
|
||||
has_only_luts_with_inputs,
|
||||
};
|
||||
assert_properties_correctness(&result);
|
||||
result
|
||||
@@ -911,4 +917,24 @@ mod tests {
|
||||
prev_safe_noise_bound = ns.safe_variance_bound;
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_1_layer_lut() {
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let _lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
|
||||
let _lut2 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
|
||||
let analysis = analyze(&graph);
|
||||
assert!(analysis.has_only_luts_with_inputs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_2_layer_lut() {
|
||||
let mut graph = unparametrized::OperationDag::new();
|
||||
let input1 = graph.add_input(1, Shape::number());
|
||||
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1);
|
||||
let _lut2 = graph.add_lut(lut1, FunctionTable::UNKWOWN, 1);
|
||||
let analysis = analyze(&graph);
|
||||
assert!(!analysis.has_only_luts_with_inputs);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -35,31 +35,64 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
|
||||
let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error);
|
||||
|
||||
let mut cut_complexity =
|
||||
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64);
|
||||
let mut cut_noise = safe_variance - noise_modulus_switching;
|
||||
let input_noise_out = minimal_variance(
|
||||
glwe_params,
|
||||
consts.config.ciphertext_modulus_log,
|
||||
consts.config.security_level,
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
if dag.nb_luts == 0 {
|
||||
cut_noise = f64::INFINITY;
|
||||
cut_complexity = f64::INFINITY;
|
||||
let no_luts = dag.nb_luts == 0;
|
||||
// if no_luts we disable cuts, any parameters is acceptable in luts
|
||||
let (cut_noise, cut_complexity) = if no_luts && CUTS {
|
||||
(f64::INFINITY, f64::INFINITY)
|
||||
} else {
|
||||
(
|
||||
safe_variance - noise_modulus_switching,
|
||||
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0))
|
||||
/ (dag.nb_luts as f64),
|
||||
)
|
||||
};
|
||||
|
||||
if input_noise_out > cut_noise {
|
||||
// exact cut when has_only_luts_with_inputs, lower bound cut otherwise
|
||||
return;
|
||||
}
|
||||
|
||||
let br_pareto =
|
||||
pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
|
||||
// if only one layer of luts, no cut inside pareto_blind_rotate based on br noise,
|
||||
// since it's never use inside the lut
|
||||
let br_cut_noise = if dag.has_only_luts_with_inputs {
|
||||
f64::INFINITY
|
||||
} else {
|
||||
cut_noise
|
||||
};
|
||||
let br_cut_complexity = cut_complexity;
|
||||
|
||||
let br_pareto = pareto_blind_rotate::<W>(
|
||||
consts,
|
||||
internal_dim,
|
||||
glwe_params,
|
||||
br_cut_complexity,
|
||||
br_cut_noise,
|
||||
);
|
||||
if br_pareto.is_empty() {
|
||||
return;
|
||||
}
|
||||
if PARETO_CUTS {
|
||||
cut_noise -= br_pareto[br_pareto.len() - 1].noise;
|
||||
cut_complexity -= br_pareto[0].complexity;
|
||||
}
|
||||
|
||||
let worst_input_ks_noise = if dag.has_only_luts_with_inputs {
|
||||
input_noise_out
|
||||
} else {
|
||||
br_pareto.last().unwrap().noise
|
||||
};
|
||||
let ks_cut_noise = cut_noise - worst_input_ks_noise;
|
||||
let ks_cut_complexity = cut_complexity - br_pareto[0].complexity;
|
||||
|
||||
let ks_pareto = pareto_keyswitch::<W>(
|
||||
consts,
|
||||
input_lwe_dimension,
|
||||
internal_dim,
|
||||
cut_complexity,
|
||||
cut_noise,
|
||||
ks_cut_complexity,
|
||||
ks_cut_noise,
|
||||
);
|
||||
if ks_pareto.is_empty() {
|
||||
return;
|
||||
@@ -67,12 +100,6 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
|
||||
let i_max_ks = ks_pareto.len() - 1;
|
||||
let mut i_current_max_ks = i_max_ks;
|
||||
let input_noise_out = minimal_variance(
|
||||
glwe_params,
|
||||
consts.config.ciphertext_modulus_log,
|
||||
consts.config.security_level,
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
let mut best_br_noise = f64::INFINITY;
|
||||
let mut best_ks_noise = f64::INFINITY;
|
||||
@@ -571,6 +598,36 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
fn lut_1_layer_has_better_complexity(precision: Precision) {
|
||||
let dag_1_layer = {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision as u8, Shape::number());
|
||||
let _lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
|
||||
let _lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
|
||||
dag
|
||||
};
|
||||
let dag_2_layer = {
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision as u8, Shape::number());
|
||||
let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision);
|
||||
let _lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision);
|
||||
dag
|
||||
};
|
||||
|
||||
let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap();
|
||||
let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap();
|
||||
assert!(sol_1_layer.complexity < sol_2_layer.complexity);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_lut_1_layer_is_better() {
|
||||
// for some reason on 4, 5, 6, the complexity is already minimal
|
||||
// this could be due to pre-defined pareto set
|
||||
for precision in [1, 2, 3, 7, 8] {
|
||||
lut_1_layer_has_better_complexity(precision);
|
||||
}
|
||||
}
|
||||
|
||||
fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) {
|
||||
let input = dag.add_input(precision, Shape::number());
|
||||
let dot1 = dag.add_dot([input], [weight]);
|
||||
|
||||
Reference in New Issue
Block a user