From b3e3a10f2245f1e050e1811cd560f3ba66ce3e2f Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 25 May 2022 17:15:26 +0200 Subject: [PATCH] feat: precision specific noise summary --- .../src/noise_estimator/error.rs | 2 +- .../src/optimization/atomic_pattern.rs | 4 +- concrete-optimizer/src/optimization/config.rs | 6 + .../src/optimization/dag/solo_key/analyze.rs | 406 ++++++++++++++---- .../src/optimization/dag/solo_key/optimize.rs | 258 ++++++----- .../dag/solo_key/symbolic_variance.rs | 37 +- concrete-optimizer/src/optimization/mod.rs | 1 + .../wop_atomic_pattern/optimize.rs | 1 - 8 files changed, 496 insertions(+), 219 deletions(-) create mode 100644 concrete-optimizer/src/optimization/config.rs diff --git a/concrete-optimizer/src/noise_estimator/error.rs b/concrete-optimizer/src/noise_estimator/error.rs index 55f2fa48e..e7b26c05f 100644 --- a/concrete-optimizer/src/noise_estimator/error.rs +++ b/concrete-optimizer/src/noise_estimator/error.rs @@ -23,7 +23,7 @@ pub fn fatal_noise_limit(precision: u64, ciphertext_modulus_log: u64) -> f64 { 2_f64.powi(noise_bits as i32) } -pub fn variance_max( +pub fn safe_variance_bound( precision: u64, ciphertext_modulus_log: u64, maximum_acceptable_error_probability: f64, diff --git a/concrete-optimizer/src/optimization/atomic_pattern.rs b/concrete-optimizer/src/optimization/atomic_pattern.rs index bd0e7d22b..5c031c9b3 100644 --- a/concrete-optimizer/src/optimization/atomic_pattern.rs +++ b/concrete-optimizer/src/optimization/atomic_pattern.rs @@ -34,7 +34,6 @@ pub struct Solution { pub br_decomposition_level_count: u64, //l(BR) pub br_decomposition_base_log: u64, //b(BR) pub complexity: f64, - pub lut_complexity: f64, pub noise_max: f64, pub p_error: f64, // error probability } @@ -342,7 +341,6 @@ fn update_state_with_best_decompositions( br_decomposition_base_log: br_b, noise_max, complexity, - lut_complexity: complexity_keyswitch + complexity_pbs, p_error, }); } @@ -476,7 +474,7 @@ pub fn optimize_one( // the blind rotate decomposition let ciphertext_modulus_log = W::BITS as u64; - let safe_variance = error::variance_max( + let safe_variance = error::safe_variance_bound( precision, ciphertext_modulus_log, maximum_acceptable_error_probability, diff --git a/concrete-optimizer/src/optimization/config.rs b/concrete-optimizer/src/optimization/config.rs new file mode 100644 index 000000000..0147596a2 --- /dev/null +++ b/concrete-optimizer/src/optimization/config.rs @@ -0,0 +1,6 @@ +#[derive(Clone, Copy, Debug)] +pub struct NoiseBoundConfig { + pub security_level: u64, + pub maximum_acceptable_error_probability: f64, + pub ciphertext_modulus_log: u64, +} diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 75ddef397..0f6a1348e 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -3,7 +3,10 @@ use crate::dag::operator::{ dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape, }; use crate::dag::unparametrized; +use crate::noise_estimator::error; +use crate::optimization::config::NoiseBoundConfig; use crate::utils::square; +use std::collections::HashSet; // private short convention use DotKind as DK; @@ -58,8 +61,8 @@ fn assert_valid_variances(dag: &OperationDag) { for &out_variance in &dag.out_variances { assert!( SymbolicVariance::ZERO == out_variance // Special case of multiply by 0 - || 1.0 <= out_variance.input_vf - || 1.0 <= out_variance.lut_vf + || 1.0 <= out_variance.input_coeff + || 1.0 <= out_variance.lut_coeff ); } } @@ -95,27 +98,48 @@ pub struct OperationDag { pub nb_luts: u64, // The full dag levelled complexity pub levelled_complexity: LevelledComplexity, - // Global summaries of worst noise cases - pub noise_summary: NoiseSummary, + // Dominating variances and bounds per precision + pub constraints_by_precisions: Vec, } #[derive(Clone, Debug)] -pub struct NoiseSummary { +pub struct VariancesAndBound { + pub precision: Precision, + pub safe_variance_bound: f64, + pub nb_luts: u64, // All final variance factor not entering a lut (usually final levelledOp) - pub pareto_vfs_final: Vec, + pub pareto_output: Vec, // All variance factor entering a lut - pub pareto_vfs_in_lut: Vec, + pub pareto_in_lut: Vec, } impl OperationDag { - pub fn peek_variance( + pub fn peek_p_error( &self, input_noise_out: f64, blind_rotate_noise_out: f64, noise_keyswitch: f64, noise_modulus_switching: f64, - ) -> f64 { - peek_variance( + kappa: f64, + ) -> (f64, f64) { + peak_p_error( + self, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, + kappa, + ) + } + + pub fn feasible( + &self, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, + ) -> bool { + feasible( self, input_noise_out, blind_rotate_noise_out, @@ -263,14 +287,15 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec fn extra_final_variances( dag: &unparametrized::OperationDag, + out_precisions: &[Precision], out_variances: &[SymbolicVariance], -) -> Vec { +) -> Vec<(Precision, SymbolicVariance)> { extra_final_values_to_check(dag) .iter() .enumerate() .filter_map(|(i, &is_final)| { if is_final { - Some(out_variances[i]) + Some((out_precisions[i], out_variances[i])) } else { None } @@ -280,11 +305,12 @@ fn extra_final_variances( fn in_luts_variance( dag: &unparametrized::OperationDag, + out_precisions: &[Precision], out_variances: &[SymbolicVariance], -) -> Vec { +) -> Vec<(Precision, SymbolicVariance)> { let only_luts = |op| { if let &Op::Lut { input, .. } = op { - Some(out_variances[input.i]) + Some((out_precisions[input.i], out_variances[input.i])) } else { None } @@ -324,34 +350,85 @@ fn levelled_complexity( levelled_complexity } -fn max_update(current: &mut f64, candidate: f64) { - if candidate > *current { - *current = candidate; - } +fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f64 { + error::safe_variance_bound( + precision as u64, + noise_config.ciphertext_modulus_log, + noise_config.maximum_acceptable_error_probability, + ) } -fn noise_summary( - final_variances: Vec, - in_luts_variance: Vec, -) -> NoiseSummary { - let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(final_variances); +fn constraints_by_precisions( + out_precisions: &[Precision], + final_variances: &[(Precision, SymbolicVariance)], + in_luts_variance: &[(Precision, SymbolicVariance)], + noise_config: &NoiseBoundConfig, +) -> Vec { + let precisions: HashSet = out_precisions.iter().copied().collect(); + let mut precisions: Vec = precisions.iter().copied().collect(); + let to_noise_summary = |precision: &Precision| { + constraint_for_one_precision( + *precision as Precision, + final_variances, + in_luts_variance, + safe_noise_bound(*precision as Precision, noise_config), + ) + }; + // High precision first + precisions.sort_unstable(); + precisions.iter().rev().map(to_noise_summary).collect() +} + +fn select_precision(target_precision: Precision, v: &[(Precision, T)]) -> Vec { + v.iter() + .filter_map(|(p, t)| { + if *p == target_precision { + Some(*t) + } else { + None + } + }) + .collect() +} + +fn constraint_for_one_precision( + target_precision: Precision, + extra_final_variances: &[(Precision, SymbolicVariance)], + in_luts_variance: &[(Precision, SymbolicVariance)], + safe_noise_bound: f64, +) -> VariancesAndBound { + let extra_final_variances = select_precision(target_precision, extra_final_variances); + let in_luts_variance = select_precision(target_precision, in_luts_variance); + let nb_luts = in_luts_variance.len() as u64; + let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_final_variances); let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance); - NoiseSummary { - pareto_vfs_final, - pareto_vfs_in_lut, + VariancesAndBound { + precision: target_precision, + safe_variance_bound: safe_noise_bound, + nb_luts, + pareto_output: pareto_vfs_final, + pareto_in_lut: pareto_vfs_in_lut, } } -pub fn analyze(dag: &unparametrized::OperationDag) -> OperationDag { +pub fn analyze( + dag: &unparametrized::OperationDag, + noise_config: &NoiseBoundConfig, +) -> OperationDag { assert_dag_correctness(dag); let out_shapes = out_shapes(dag); let out_precisions = out_precisions(dag); let out_variances = out_variances(dag, &out_shapes); - let in_luts_variance = in_luts_variance(dag, &out_variances); + let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances); let nb_luts = in_luts_variance.len() as u64; - let extra_final_variances = extra_final_variances(dag, &out_variances); + let extra_final_variances = extra_final_variances(dag, &out_precisions, &out_variances); let levelled_complexity = levelled_complexity(dag, &out_shapes); - let noise_summary = noise_summary(extra_final_variances, in_luts_variance); + let constraints_by_precisions = constraints_by_precisions( + &out_precisions, + &extra_final_variances, + &in_luts_variance, + noise_config, + ); let result = OperationDag { operators: dag.operators.clone(), out_shapes, @@ -359,39 +436,122 @@ pub fn analyze(dag: &unparametrized::OperationDag) -> OperationDag { out_variances, nb_luts, levelled_complexity, - noise_summary, + constraints_by_precisions, }; assert_properties_correctness(&result); result } +fn max_update(current: &mut f64, candidate: f64) { + if candidate > *current { + *current = candidate; + } +} + // Compute the maximum attained variance for the full dag // TODO take a noise summary => peek_error or global error -fn peek_variance( - dag: &OperationDag, +fn peak_variance_per_constraint( + constraint: &VariancesAndBound, input_noise_out: f64, blind_rotate_noise_out: f64, noise_keyswitch: f64, noise_modulus_switching: f64, ) -> f64 { - assert!(input_noise_out < blind_rotate_noise_out); - let mut variance_peek_final = 0.0; // updated by the loop - for vf in &dag.noise_summary.pareto_vfs_final { + assert!(input_noise_out < blind_rotate_noise_out || blind_rotate_noise_out == 0.0); + // the maximal variance encountered as an output that can be decrypted + let mut variance_output = 0.0; + for vf in &constraint.pareto_output { max_update( - &mut variance_peek_final, + &mut variance_output, vf.eval(input_noise_out, blind_rotate_noise_out), ); } + if constraint.pareto_in_lut.is_empty() { + return variance_output; + } + // the maximal variance encountered during a lut computation + let mut variance_in_lut = 0.0; + for vf in &constraint.pareto_in_lut { + max_update( + &mut variance_in_lut, + vf.eval(input_noise_out, blind_rotate_noise_out), + ); + } + let peek_in_lut = variance_in_lut + noise_keyswitch + noise_modulus_switching; + peek_in_lut.max(variance_output) +} - let mut variance_peek_in_lut = 0.0; // updated by the loop - for vf in &dag.noise_summary.pareto_vfs_in_lut { - max_update( - &mut variance_peek_in_lut, - vf.eval(input_noise_out, blind_rotate_noise_out), +// Compute the maximum attained relative variance for the full dag +fn peak_relative_variance( + dag: &OperationDag, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, +) -> (f64, f64) { + assert!(!dag.constraints_by_precisions.is_empty()); + assert!(input_noise_out <= blind_rotate_noise_out); + let mut max_relative_var = 0.0; + let mut safe_noise = 0.0; + for ns in &dag.constraints_by_precisions { + let variance_max = peak_variance_per_constraint( + ns, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, ); + let relative_var = variance_max / ns.safe_variance_bound; + if max_relative_var < relative_var { + max_relative_var = relative_var; + safe_noise = ns.safe_variance_bound; + } } - let peek_in_lut = variance_peek_in_lut + noise_keyswitch + noise_modulus_switching; - peek_in_lut.max(variance_peek_final) + (max_relative_var, safe_noise) +} + +fn peak_p_error( + dag: &OperationDag, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, + kappa: f64, +) -> (f64, f64) { + let (relative_var, variance_bound) = peak_relative_variance( + dag, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, + ); + let sigma_scale = kappa / relative_var.sqrt(); + ( + error::error_probability_of_sigma_scale(sigma_scale), + relative_var * variance_bound, + ) +} + +fn feasible( + dag: &OperationDag, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, +) -> bool { + for ns in &dag.constraints_by_precisions { + if peak_variance_per_constraint( + ns, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, + ) > ns.safe_variance_bound + { + return false; + } + } + true } #[cfg(test)] @@ -406,6 +566,26 @@ mod tests { approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON); } + impl OperationDag { + pub fn constraint(&self) -> VariancesAndBound { + assert!(!self.constraints_by_precisions.is_empty()); + assert_eq!(self.constraints_by_precisions.len(), 1); + self.constraints_by_precisions[0].clone() + } + } + + const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; + + const CONFIG: NoiseBoundConfig = NoiseBoundConfig { + security_level: 128, + ciphertext_modulus_log: 64, + maximum_acceptable_error_probability: _4_SIGMA, + }; + + fn analyze(dag: &unparametrized::OperationDag) -> super::OperationDag { + super::analyze(dag, &CONFIG) + } + #[test] fn test_1_input() { let mut graph = unparametrized::OperationDag::new(); @@ -421,11 +601,11 @@ mod tests { assert_eq!(analysis.out_precisions[input1.i], 1); assert_f64_eq(complexity_cost, 0.0); assert!(analysis.nb_luts == 0); - let summary = analysis.noise_summary; - assert!(summary.pareto_vfs_final.len() == 1); - assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 1.0); - assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0); - assert!(summary.pareto_vfs_in_lut.is_empty()); + let constraint = analysis.constraint(); + assert!(constraint.pareto_output.len() == 1); + assert_f64_eq(constraint.pareto_output[0].input_coeff, 1.0); + assert_f64_eq(constraint.pareto_output[0].lut_coeff, 0.0); + assert!(constraint.pareto_in_lut.is_empty()); } #[test] @@ -443,13 +623,13 @@ mod tests { assert!(analysis.levelled_complexity == LevelledComplexity::ZERO); assert_eq!(analysis.out_precisions[lut1.i], 8); assert_f64_eq(one_lut_cost, complexity_cost); - let summary = analysis.noise_summary; - assert!(summary.pareto_vfs_final.len() == 1); - assert!(summary.pareto_vfs_in_lut.len() == 1); - assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 0.0); - assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0); - assert_f64_eq(summary.pareto_vfs_in_lut[0].input_vf, 1.0); - assert_f64_eq(summary.pareto_vfs_in_lut[0].lut_vf, 0.0); + let constraint = analysis.constraint(); + assert!(constraint.pareto_output.len() == 1); + assert!(constraint.pareto_in_lut.len() == 1); + assert_f64_eq(constraint.pareto_output[0].input_coeff, 0.0); + assert_f64_eq(constraint.pareto_output[0].lut_coeff, 1.0); + assert_f64_eq(constraint.pareto_in_lut[0].input_coeff, 1.0); + assert_f64_eq(constraint.pareto_in_lut[0].lut_coeff, 0.0); } #[test] @@ -465,8 +645,8 @@ mod tests { let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); let expected_var = SymbolicVariance { - input_vf: norm2, - lut_vf: 0.0, + input_coeff: norm2, + lut_coeff: 0.0, }; assert!(analysis.out_variances[dot.i] == expected_var); assert!(analysis.out_shapes[dot.i] == Shape::number()); @@ -474,11 +654,11 @@ mod tests { assert_eq!(analysis.out_precisions[dot.i], 1); let expected_dot_cost = (2 * lwe_dim) as f64; assert_f64_eq(expected_dot_cost, complexity_cost); - let summary = analysis.noise_summary; - assert!(summary.pareto_vfs_in_lut.is_empty()); - assert!(summary.pareto_vfs_final.len() == 1); - assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0); - assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0); + let constraint = analysis.constraint(); + assert!(constraint.pareto_in_lut.is_empty()); + assert!(constraint.pareto_output.len() == 1); + assert_f64_eq(constraint.pareto_output[0].input_coeff, 5.0); + assert_f64_eq(constraint.pareto_output[0].lut_coeff, 0.0); } #[test] @@ -497,16 +677,16 @@ mod tests { assert!(analysis.out_variances[dot.i].origin() == VO::Input); assert_eq!(analysis.out_precisions[dot.i], 3); let expected_square_norm2 = weights.square_norm2() as f64; - let actual_square_norm2 = analysis.out_variances[dot.i].input_vf; + let actual_square_norm2 = analysis.out_variances[dot.i].input_coeff; // Due to call on log2() to compute manp the result is not exact assert_f64_eq(actual_square_norm2, expected_square_norm2); assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION); assert_f64_eq(lwe_dim as f64, complexity_cost); - let summary = analysis.noise_summary; - assert!(summary.pareto_vfs_in_lut.is_empty()); - assert!(summary.pareto_vfs_final.len() == 1); - assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Input); - assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0); + let constraint = analysis.constraint(); + assert!(constraint.pareto_in_lut.is_empty()); + assert!(constraint.pareto_output.len() == 1); + assert_eq!(constraint.pareto_output[0].origin(), VO::Input); + assert_f64_eq(constraint.pareto_output[0].input_coeff, 5.0); } #[test] @@ -524,20 +704,20 @@ mod tests { let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost); let expected_var_dot1 = SymbolicVariance { - input_vf: weights.square_norm2() as f64, - lut_vf: 0.0, + input_coeff: weights.square_norm2() as f64, + lut_coeff: 0.0, }; let expected_var_lut1 = SymbolicVariance { - input_vf: 0.0, - lut_vf: 1.0, + input_coeff: 0.0, + lut_coeff: 1.0, }; let expected_var_dot2 = SymbolicVariance { - input_vf: 0.0, - lut_vf: weights.square_norm2() as f64, + input_coeff: 0.0, + lut_coeff: weights.square_norm2() as f64, }; let expected_var_lut2 = SymbolicVariance { - input_vf: 0.0, - lut_vf: 1.0, + input_coeff: 0.0, + lut_coeff: 1.0, }; assert!(analysis.out_variances[dot1.i] == expected_var_dot1); assert!(analysis.out_variances[lut1.i] == expected_var_lut1); @@ -546,14 +726,14 @@ mod tests { assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 4); let expected_cost = (lwe_dim * 4) as f64 + 2.0 * one_lut_cost; assert_f64_eq(expected_cost, complexity_cost); - let summary = analysis.noise_summary; - assert_eq!(summary.pareto_vfs_final.len(), 1); - assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Lut); - assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0); - assert_eq!(summary.pareto_vfs_in_lut.len(), 1); - assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Lut); + let constraint = analysis.constraint(); + assert_eq!(constraint.pareto_output.len(), 1); + assert_eq!(constraint.pareto_output[0].origin(), VO::Lut); + assert_f64_eq(constraint.pareto_output[0].lut_coeff, 1.0); + assert_eq!(constraint.pareto_in_lut.len(), 1); + assert_eq!(constraint.pareto_in_lut[0].origin(), VO::Lut); assert_f64_eq( - summary.pareto_vfs_in_lut[0].lut_vf, + constraint.pareto_in_lut[0].lut_coeff, weights.square_norm2() as f64, ); } @@ -574,14 +754,56 @@ mod tests { let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost; assert_f64_eq(expected_cost, complexity_cost); let expected_mixed = SymbolicVariance { - input_vf: square(weights.values[0] as f64), - lut_vf: square(weights.values[1] as f64), + input_coeff: square(weights.values[0] as f64), + lut_coeff: square(weights.values[1] as f64), }; - let summary = analysis.noise_summary; - assert_eq!(summary.pareto_vfs_final.len(), 1); - assert_eq!(summary.pareto_vfs_final[0], SymbolicVariance::LUT); - assert_eq!(summary.pareto_vfs_in_lut.len(), 1); - assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Mixed); - assert_eq!(summary.pareto_vfs_in_lut[0], expected_mixed); + let constraint = analysis.constraint(); + assert_eq!(constraint.pareto_output.len(), 1); + assert_eq!(constraint.pareto_output[0], SymbolicVariance::LUT); + assert_eq!(constraint.pareto_in_lut.len(), 1); + assert_eq!(constraint.pareto_in_lut[0].origin(), VO::Mixed); + assert_eq!(constraint.pareto_in_lut[0], expected_mixed); + } + + #[test] + fn test_multi_precision_input() { + let mut graph = unparametrized::OperationDag::new(); + let max_precision = 5_usize; + for i in 1..=max_precision { + let _ = graph.add_input(i as u8, Shape::number()); + } + let analysis = analyze(&graph); + assert!(analysis.constraints_by_precisions.len() == max_precision); + let mut prev_safe_noise_bound = 0.0; + for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() { + assert_eq!(ns.precision, (max_precision - i) as u8); + assert_f64_eq(ns.pareto_output[0].input_coeff, 1.0); + assert!(prev_safe_noise_bound < ns.safe_variance_bound); + prev_safe_noise_bound = ns.safe_variance_bound; + } + } + + #[test] + fn test_multi_precision_lut() { + let mut graph = unparametrized::OperationDag::new(); + let max_precision = 5_usize; + for i in 1..=max_precision { + let input = graph.add_input(i as u8, Shape::number()); + let _lut = graph.add_lut(input, FunctionTable::UNKWOWN); + } + let analysis = analyze(&graph); + assert!(analysis.constraints_by_precisions.len() == max_precision); + let mut prev_safe_noise_bound = 0.0; + for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() { + assert_eq!(ns.precision, (max_precision - i) as u8); + assert_eq!(ns.pareto_output.len(), 1); + assert_eq!(ns.pareto_in_lut.len(), 1); + assert_f64_eq(ns.pareto_output[0].input_coeff, 0.0); + assert_f64_eq(ns.pareto_output[0].lut_coeff, 1.0); + assert_f64_eq(ns.pareto_in_lut[0].input_coeff, 1.0); + assert_f64_eq(ns.pareto_in_lut[0].lut_coeff, 0.0); + assert!(prev_safe_noise_bound < ns.safe_variance_bound); + prev_safe_noise_bound = ns.safe_variance_bound; + } } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 5ca4c5630..a8887a611 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -1,10 +1,9 @@ -use concrete_commons::dispersion::{DispersionParameter, Variance}; +use concrete_commons::dispersion::DispersionParameter; use concrete_commons::numeric::UnsignedInteger; use crate::dag::operator::LevelledComplexity; use crate::dag::unparametrized; use crate::noise_estimator::error; -use crate::noise_estimator::error::error_probability_of_sigma_scale; use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; use crate::optimization::atomic_pattern::{ @@ -12,6 +11,7 @@ use crate::optimization::atomic_pattern::{ Solution, }; +use crate::optimization::config::NoiseBoundConfig; use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositionParameters}; use crate::pareto; use crate::security::glwe::minimal_variance; @@ -37,10 +37,8 @@ fn update_best_solution_with_best_decompositions( let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size; let mut best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity); - let mut best_lut_complexity = state - .best_solution - .map_or(f64::INFINITY, |s| s.lut_complexity); let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max); + let mut best_p_error = state.best_solution.map_or(f64::INFINITY, |s| s.p_error); let mut cut_complexity = (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64); @@ -87,17 +85,17 @@ fn update_best_solution_with_best_decompositions( for br_quantity in br_pareto { // increasing complexity, decreasing variance - let peek_variance = dag.peek_variance( + let not_feasible = !dag.feasible( input_noise_out, br_quantity.noise, 0.0, noise_modulus_switching, ); - if peek_variance > safe_variance && CUTS { + if not_feasible && CUTS { continue; } - let one_pbs_cost = br_quantity.complexity; - let complexity = dag.complexity_cost(input_lwe_dimension, one_pbs_cost); + let one_lut_cost = br_quantity.complexity; + let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost); if complexity > best_complexity { // As best can evolves it is complementary to blind_rotate_quantities cuts. if PARETO_CUTS { @@ -109,14 +107,14 @@ fn update_best_solution_with_best_decompositions( for i_ks_pareto in (0..=i_current_max_ks).rev() { // increasing variance, decreasing complexity let ks_quantity = ks_pareto[i_ks_pareto]; - let peek_variance = dag.peek_variance( + let not_feasible = !dag.feasible( input_noise_out, br_quantity.noise, ks_quantity.noise, noise_modulus_switching, ); // let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching; - if peek_variance > safe_variance { + if not_feasible { if CROSS_PARETO_CUTS { // the pareto of 2 added pareto is scanned linearly // but with all cuts, pre-computing => no gain @@ -129,28 +127,39 @@ fn update_best_solution_with_best_decompositions( } continue; } + let one_lut_cost = ks_quantity.complexity + br_quantity.complexity; let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost); - - let better_complexity = complexity < best_complexity; - #[allow(clippy::float_cmp)] - let same_complexity_with_less_errors = - complexity == best_complexity && peek_variance < best_variance; - if better_complexity || same_complexity_with_less_errors { - best_lut_complexity = one_lut_cost; - best_complexity = complexity; - best_variance = peek_variance; - best_br_i = br_quantity.index; - best_ks_i = ks_quantity.index; - update_best_solution = true; + let worse_complexity = complexity > best_complexity; + if worse_complexity { + continue; } + + let (peek_p_error, variance) = dag.peek_p_error( + input_noise_out, + br_quantity.noise, + ks_quantity.noise, + noise_modulus_switching, + consts.kappa, + ); + #[allow(clippy::float_cmp)] + let same_comlexity_no_few_errors = + complexity == best_complexity && peek_p_error >= best_p_error; + if same_comlexity_no_few_errors { + continue; + } + + // The complexity is either better or equivalent with less errors + update_best_solution = true; + best_complexity = complexity; + best_p_error = peek_p_error; + best_variance = variance; + best_br_i = br_quantity.index; + best_ks_i = ks_quantity.index; } } // br ks if update_best_solution { - let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa; - let sigma_scale = sigma / Variance(best_variance).get_standard_dev(); - let p_error = error_probability_of_sigma_scale(sigma_scale); let BrDecompositionParameters { level: br_l, log2_base: br_b, @@ -159,6 +168,7 @@ fn update_best_solution_with_best_decompositions( level: ks_l, log2_base: ks_b, } = consts.keyswitch_decompositions[best_ks_i]; + state.best_solution = Some(Solution { input_lwe_dimension, internal_ks_output_lwe_dimension: internal_dim, @@ -168,10 +178,9 @@ fn update_best_solution_with_best_decompositions( glwe_dimension: glwe_params.glwe_dimension, br_decomposition_level_count: br_l, br_decomposition_base_log: br_b, - noise_max: best_variance, complexity: best_complexity, - lut_complexity: best_lut_complexity, - p_error, + p_error: best_p_error, + noise_max: best_variance, }); } } @@ -188,12 +197,17 @@ pub fn optimize( internal_lwe_dimensions: &[u64], ) -> OptimizationState { let ciphertext_modulus_log = W::BITS as u64; - let dag = analyze::analyze(dag); + let noise_config = NoiseBoundConfig { + security_level, + maximum_acceptable_error_probability, + ciphertext_modulus_log, + }; + let dag = analyze::analyze(dag, &noise_config); - let &max_precision = dag.out_precisions.iter().max().unwrap(); + let &min_precision = dag.out_precisions.iter().min().unwrap(); - let safe_variance = error::variance_max( - max_precision as u64, + let safe_variance = error::safe_variance_bound( + min_precision as u64, ciphertext_modulus_log, maximum_acceptable_error_probability, ); @@ -230,6 +244,8 @@ pub fn optimize( ) .get_variance() }; + let not_feasible = + |noise_modulus_switching| !dag.feasible(0.0, 0.0, 0.0, noise_modulus_switching); for &glwe_dim in glwe_dimensions { for &glwe_log_poly_size in glwe_log_polynomial_sizes { @@ -240,7 +256,7 @@ pub fn optimize( }; for &internal_dim in internal_lwe_dimensions { let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim); - if CUTS && noise_modulus_switching > consts.safe_variance { + if CUTS && not_feasible(noise_modulus_switching) { // assume this noise is increasing with internal_dim break; } @@ -310,7 +326,6 @@ mod tests { use crate::dag::operator::{FunctionTable, Shape, Weights}; use crate::global_parameters::DEFAUT_DOMAINS; use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin; - use crate::utils::square; use super::*; use crate::optimization::atomic_pattern; @@ -320,7 +335,7 @@ mod tests { } impl Solution { - fn same(&self, other: Self) -> bool { + fn assert_same(&self, other: Self) -> bool { let mut other = other; if small_relative_diff(self.noise_max, other.noise_max) && small_relative_diff(self.p_error, other.p_error) @@ -328,12 +343,38 @@ mod tests { other.noise_max = self.noise_max; other.p_error = self.p_error; } + assert_eq!(self, &other); self == &other } } const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; + const CONFIG: NoiseBoundConfig = NoiseBoundConfig { + security_level: 128, + ciphertext_modulus_log: 64, + maximum_acceptable_error_probability: _4_SIGMA, + }; + + fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState { + let security_level = 128; + let maximum_acceptable_error_probability = _4_SIGMA; + let glwe_log_polynomial_sizes: Vec = DEFAUT_DOMAINS + .glwe_pbs_constrained + .log2_polynomial_size + .as_vec(); + let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec(); + let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec(); + super::optimize::( + dag, + security_level, + maximum_acceptable_error_probability, + &glwe_log_polynomial_sizes, + &glwe_dimensions, + &internal_lwe_dimensions, + ) + } + struct Times { worst_time: u128, dag_time: u128, @@ -367,15 +408,14 @@ mod tests { let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec(); let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec(); let sum_size = 1; - let maximum_acceptable_error_probability = _4_SIGMA; let chrono = Instant::now(); let state = optimize_v0::( sum_size, precision, - security_level, + CONFIG.security_level, weight as f64, - maximum_acceptable_error_probability, + CONFIG.maximum_acceptable_error_probability, &glwe_log_polynomial_sizes, &glwe_dimensions, &internal_lwe_dimensions, @@ -387,7 +427,7 @@ mod tests { precision, security_level, weight as f64, - maximum_acceptable_error_probability, + CONFIG.maximum_acceptable_error_probability, &glwe_log_polynomial_sizes, &glwe_dimensions, &internal_lwe_dimensions, @@ -403,7 +443,7 @@ mod tests { } let sol = state.best_solution.unwrap(); let sol_ref = state_ref.best_solution.unwrap(); - assert!(sol.same(sol_ref)); + assert!(sol.assert_same(sol_ref)); } #[test] @@ -426,15 +466,15 @@ mod tests { let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN); } { - let dag2 = analyze::analyze(&dag); - let summary = dag2.noise_summary; - assert_eq!(summary.pareto_vfs_final.len(), 1); - assert_eq!(summary.pareto_vfs_in_lut.len(), 1); - assert_eq!(summary.pareto_vfs_final[0].origin(), VarianceOrigin::Lut); - assert_f64_eq(1.0, summary.pareto_vfs_final[0].lut_vf); - assert!(summary.pareto_vfs_in_lut.len() == 1); - assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VarianceOrigin::Lut); - assert_f64_eq(square(weight) as f64, summary.pareto_vfs_in_lut[0].lut_vf); + let dag2 = analyze::analyze(&dag, &CONFIG); + let constraint = dag2.constraint(); + assert_eq!(constraint.pareto_output.len(), 1); + assert_eq!(constraint.pareto_in_lut.len(), 1); + assert_eq!(constraint.pareto_output[0].origin(), VarianceOrigin::Lut); + assert_f64_eq(1.0, constraint.pareto_output[0].lut_coeff); + assert!(constraint.pareto_in_lut.len() == 1); + assert_eq!(constraint.pareto_in_lut[0].origin(), VarianceOrigin::Lut); + assert_f64_eq(square(weight) as f64, constraint.pareto_in_lut[0].lut_coeff); } let security_level = 128; @@ -445,14 +485,7 @@ mod tests { .as_vec(); let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec(); let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec(); - let state = optimize::( - &dag, - security_level, - maximum_acceptable_error_probability, - &glwe_log_polynomial_sizes, - &glwe_dimensions, - &internal_lwe_dimensions, - ); + let state = optimize(&dag); let state_ref = atomic_pattern::optimize_one::( 1, precision, @@ -474,7 +507,7 @@ mod tests { let sol = state.best_solution.unwrap(); let mut sol_ref = state_ref.best_solution.unwrap(); sol_ref.complexity *= 2.0 /* number of luts */; - assert!(sol.same(sol_ref)); + assert!(sol.assert_same(sol_ref)); } fn no_lut_vs_lut(precision: u64) { @@ -485,28 +518,8 @@ mod tests { let mut dag_no_lut = unparametrized::OperationDag::new(); let _input2 = dag_no_lut.add_input(precision as u8, Shape::number()); - let security_level = 128; - let maximum_acceptable_error_probability = _4_SIGMA; - let glwe_log_polynomial_sizes: Vec = DEFAUT_DOMAINS - .glwe_pbs_constrained - .log2_polynomial_size - .as_vec(); - let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec(); - let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec(); - - let opt = |dag: &unparametrized::OperationDag| { - optimize::( - dag, - security_level, - maximum_acceptable_error_probability, - &glwe_log_polynomial_sizes, - &glwe_dimensions, - &internal_lwe_dimensions, - ) - }; - - let state_no_lut = opt(&dag_no_lut); - let state_lut = opt(&dag_lut); + let state_no_lut = optimize(&dag_no_lut); + let state_lut = optimize(&dag_lut); assert_eq!( state_no_lut.best_solution.is_some(), state_lut.best_solution.is_some() @@ -546,28 +559,8 @@ mod tests { let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN); } - let security_level = 128; - let maximum_acceptable_error_probability = _4_SIGMA; - let glwe_log_polynomial_sizes: Vec = DEFAUT_DOMAINS - .glwe_pbs_constrained - .log2_polynomial_size - .as_vec(); - let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec(); - let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec(); - - let opt = |dag: &unparametrized::OperationDag| { - optimize::( - dag, - security_level, - maximum_acceptable_error_probability, - &glwe_log_polynomial_sizes, - &glwe_dimensions, - &internal_lwe_dimensions, - ) - }; - - let state_1 = opt(&dag_1); - let state_2 = opt(&dag_2); + let state_1 = optimize(&dag_1); + let state_2 = optimize(&dag_2); if state_1.best_solution.is_none() { assert!(state_2.best_solution.is_none()); @@ -587,4 +580,63 @@ mod tests { } } } + + fn circuit(dag: &mut unparametrized::OperationDag, precision: u8, weight: u64) { + let input = dag.add_input(precision, Shape::number()); + let dot1 = dag.add_dot([input], [weight]); + let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN); + let dot2 = dag.add_dot([lut1], [weight]); + let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN); + } + + fn assert_multi_precision_dominate_single(weight: u64) -> Option { + let low_precision = 4u8; + let high_precision = 5u8; + let mut dag_low = unparametrized::OperationDag::new(); + let mut dag_high = unparametrized::OperationDag::new(); + let mut dag_multi = unparametrized::OperationDag::new(); + + { + circuit(&mut dag_low, low_precision, weight); + circuit(&mut dag_high, high_precision, 1); + circuit(&mut dag_multi, low_precision, weight); + circuit(&mut dag_multi, high_precision, 1); + } + let state_multi = optimize(&dag_multi); + #[allow(clippy::question_mark)] // question mark doesn't work here + if state_multi.best_solution.is_none() { + return None; + } + let state_low = optimize(&dag_low); + let state_high = optimize(&dag_high); + + let sol_low = state_low.best_solution.unwrap(); + let sol_high = state_high.best_solution.unwrap(); + let mut sol_multi = state_multi.best_solution.unwrap(); + sol_multi.complexity /= 2.0; + if sol_low.complexity < sol_high.complexity { + assert!(sol_high.assert_same(sol_multi)); + Some(true) + } else { + assert!(sol_low.complexity < sol_multi.complexity || sol_low.assert_same(sol_multi)); + Some(false) + } + } + + #[test] + fn test_multi_precision_dominate_single() { + let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false + for log2_weight in 0..29 { + let weight = 1 << log2_weight; + let current = assert_multi_precision_dominate_single(weight); + #[allow(clippy::match_like_matches_macro)] // less readable + let authorized = match (prev, current) { + (Some(false), Some(true)) => false, + (None, Some(_)) => false, + _ => true, + }; + assert!(authorized); + prev = current; + } + } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs index abcfc4ed2..f336c0423 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs @@ -3,7 +3,6 @@ use derive_more::{Add, AddAssign, Sum}; * A variance that is represented as a linear combination of base variances. * Only the linear coefficient are known. * The base variances are unknown. - * Each linear coefficients is a variance factor. * * Only 2 base variances are possible in the solo key setup: * - from input, @@ -15,11 +14,11 @@ use derive_more::{Add, AddAssign, Sum}; */ #[derive(Clone, Copy, Add, AddAssign, Sum, Debug, PartialEq, PartialOrd)] pub struct SymbolicVariance { - pub lut_vf: f64, - pub input_vf: f64, - // variance = vf.lut_vf * lut_out_noise - // + vf.input_vf * input_out_noise - // E.g. variance(dot([lut, input], [3, 4])) = VariancesFactors {lut_vf:9, input_vf: 16} + pub lut_coeff: f64, + pub input_coeff: f64, + // variance = vf.lut_coeff * lut_out_noise + // + vf.input_coeff * input_out_noise + // E.g. variance(dot([lut, input], [3, 4])) = VariancesFactors {lut_coeff:9, input_coeff: 16} // NOTE: lut_base_noise is the first field since it has higher impact, // see pareto sorting and dominate_or_equal @@ -36,8 +35,8 @@ impl std::ops::Mul for SymbolicVariance { type Output = Self; fn mul(self, sq_weight: f64) -> Self { Self { - input_vf: self.input_vf * sq_weight, - lut_vf: self.lut_vf * sq_weight, + input_coeff: self.input_coeff * sq_weight, + lut_coeff: self.lut_coeff * sq_weight, } } } @@ -51,22 +50,22 @@ impl std::ops::Mul for SymbolicVariance { impl SymbolicVariance { pub const ZERO: Self = Self { - input_vf: 0.0, - lut_vf: 0.0, + input_coeff: 0.0, + lut_coeff: 0.0, }; pub const INPUT: Self = Self { - input_vf: 1.0, - lut_vf: 0.0, + input_coeff: 1.0, + lut_coeff: 0.0, }; pub const LUT: Self = Self { - input_vf: 0.0, - lut_vf: 1.0, + input_coeff: 0.0, + lut_coeff: 1.0, }; pub fn origin(&self) -> VarianceOrigin { - if self.lut_vf == 0.0 { + if self.lut_coeff == 0.0 { VarianceOrigin::Input - } else if self.input_vf == 0.0 { + } else if self.input_coeff == 0.0 { VarianceOrigin::Lut } else { VarianceOrigin::Mixed @@ -78,12 +77,12 @@ impl SymbolicVariance { } pub fn dominate_or_equal(&self, other: &Self) -> bool { - let extra_other_minimal_base_noise = 0.0_f64.max(other.input_vf - self.input_vf); - other.lut_vf + extra_other_minimal_base_noise <= self.lut_vf + let extra_other_minimal_base_noise = 0.0_f64.max(other.input_coeff - self.input_coeff); + other.lut_coeff + extra_other_minimal_base_noise <= self.lut_coeff } pub fn eval(&self, minimal_base_noise: f64, lut_base_noise: f64) -> f64 { - minimal_base_noise * self.input_vf + lut_base_noise * self.lut_vf + minimal_base_noise * self.input_coeff + lut_base_noise * self.lut_coeff } pub fn reduce_to_pareto_front(mut vfs: Vec) -> Vec { diff --git a/concrete-optimizer/src/optimization/mod.rs b/concrete-optimizer/src/optimization/mod.rs index 38011bd76..ba02c377f 100644 --- a/concrete-optimizer/src/optimization/mod.rs +++ b/concrete-optimizer/src/optimization/mod.rs @@ -1,3 +1,4 @@ pub mod atomic_pattern; +pub mod config; pub mod dag; pub mod wop_atomic_pattern; diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index e03b14dae..e3917b4e6 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -533,7 +533,6 @@ pub fn optimize_one( br_decomposition_level_count: sol.br_decomposition_level_count, br_decomposition_base_log: sol.br_decomposition_base_log, complexity: sol.complexity, - lut_complexity: sol.complexity, noise_max: sol.noise_max, p_error: sol.p_error, });