diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index e5460f1c1..f29648958 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -96,6 +96,8 @@ pub struct OperationDag { // Collect all operators ouput variances pub out_variances: Vec, pub nb_luts: u64, + // True if all luts have noise with origin VarianceOrigin::Input + pub has_only_luts_with_inputs: bool, // The full dag levelled complexity pub levelled_complexity: LevelledComplexity, // Dominating variances and bounds per precision @@ -468,6 +470,9 @@ pub fn analyze( &in_luts_variance, noise_config, ); + let has_only_luts_with_inputs = in_luts_variance + .iter() + .all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input); let result = OperationDag { operators: dag.operators.clone(), out_shapes, @@ -476,6 +481,7 @@ pub fn analyze( nb_luts, levelled_complexity, constraints_by_precisions, + has_only_luts_with_inputs, }; assert_properties_correctness(&result); result @@ -911,4 +917,24 @@ mod tests { prev_safe_noise_bound = ns.safe_variance_bound; } } + + #[test] + fn test_1_layer_lut() { + let mut graph = unparametrized::OperationDag::new(); + let input1 = graph.add_input(1, Shape::number()); + let _lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1); + let _lut2 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1); + let analysis = analyze(&graph); + assert!(analysis.has_only_luts_with_inputs); + } + + #[test] + fn test_2_layer_lut() { + let mut graph = unparametrized::OperationDag::new(); + let input1 = graph.add_input(1, Shape::number()); + let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1); + let _lut2 = graph.add_lut(lut1, FunctionTable::UNKWOWN, 1); + let analysis = analyze(&graph); + assert!(!analysis.has_only_luts_with_inputs); + } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index c11aadc52..2e6e05d62 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -35,31 +35,64 @@ fn update_best_solution_with_best_decompositions( let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max); let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error); - let mut cut_complexity = - (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64); - let mut cut_noise = safe_variance - noise_modulus_switching; + let input_noise_out = minimal_variance( + glwe_params, + consts.config.ciphertext_modulus_log, + consts.config.security_level, + ) + .get_variance(); - if dag.nb_luts == 0 { - cut_noise = f64::INFINITY; - cut_complexity = f64::INFINITY; + let no_luts = dag.nb_luts == 0; + // if no_luts we disable cuts, any parameters is acceptable in luts + let (cut_noise, cut_complexity) = if no_luts && CUTS { + (f64::INFINITY, f64::INFINITY) + } else { + ( + safe_variance - noise_modulus_switching, + (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) + / (dag.nb_luts as f64), + ) + }; + + if input_noise_out > cut_noise { + // exact cut when has_only_luts_with_inputs, lower bound cut otherwise + return; } - let br_pareto = - pareto_blind_rotate::(consts, internal_dim, glwe_params, cut_complexity, cut_noise); + // if only one layer of luts, no cut inside pareto_blind_rotate based on br noise, + // since it's never use inside the lut + let br_cut_noise = if dag.has_only_luts_with_inputs { + f64::INFINITY + } else { + cut_noise + }; + let br_cut_complexity = cut_complexity; + + let br_pareto = pareto_blind_rotate::( + consts, + internal_dim, + glwe_params, + br_cut_complexity, + br_cut_noise, + ); if br_pareto.is_empty() { return; } - if PARETO_CUTS { - cut_noise -= br_pareto[br_pareto.len() - 1].noise; - cut_complexity -= br_pareto[0].complexity; - } + + let worst_input_ks_noise = if dag.has_only_luts_with_inputs { + input_noise_out + } else { + br_pareto.last().unwrap().noise + }; + let ks_cut_noise = cut_noise - worst_input_ks_noise; + let ks_cut_complexity = cut_complexity - br_pareto[0].complexity; let ks_pareto = pareto_keyswitch::( consts, input_lwe_dimension, internal_dim, - cut_complexity, - cut_noise, + ks_cut_complexity, + ks_cut_noise, ); if ks_pareto.is_empty() { return; @@ -67,12 +100,6 @@ fn update_best_solution_with_best_decompositions( let i_max_ks = ks_pareto.len() - 1; let mut i_current_max_ks = i_max_ks; - let input_noise_out = minimal_variance( - glwe_params, - consts.config.ciphertext_modulus_log, - consts.config.security_level, - ) - .get_variance(); let mut best_br_noise = f64::INFINITY; let mut best_ks_noise = f64::INFINITY; @@ -571,6 +598,36 @@ mod tests { } } + fn lut_1_layer_has_better_complexity(precision: Precision) { + let dag_1_layer = { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(precision as u8, Shape::number()); + let _lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); + let _lut2 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); + dag + }; + let dag_2_layer = { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(precision as u8, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision); + let _lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision); + dag + }; + + let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap(); + let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap(); + assert!(sol_1_layer.complexity < sol_2_layer.complexity); + } + + #[test] + fn test_lut_1_layer_is_better() { + // for some reason on 4, 5, 6, the complexity is already minimal + // this could be due to pre-defined pareto set + for precision in [1, 2, 3, 7, 8] { + lut_1_layer_has_better_complexity(precision); + } + } + fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) { let input = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input], [weight]);