diff --git a/concrete-optimizer/src/lib.rs b/concrete-optimizer/src/lib.rs index 13ef55924..b08495840 100644 --- a/concrete-optimizer/src/lib.rs +++ b/concrete-optimizer/src/lib.rs @@ -6,6 +6,7 @@ #![allow(clippy::cast_possible_truncation)] // u64 to usize #![allow(clippy::inline_always)] // needed by delegate #![allow(clippy::match_wildcard_for_single_variants)] +#![allow(clippy::manual_range_contains)] #![allow(clippy::missing_panics_doc)] #![allow(clippy::missing_const_for_fn)] #![allow(clippy::module_name_repetitions)] diff --git a/concrete-optimizer/src/noise_estimator/error.rs b/concrete-optimizer/src/noise_estimator/error.rs index 0a4c56fde..f85dc9d88 100644 --- a/concrete-optimizer/src/noise_estimator/error.rs +++ b/concrete-optimizer/src/noise_estimator/error.rs @@ -9,9 +9,13 @@ pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 { statrs::function::erf::erf_inv(p_in) * 2_f64.sqrt() } -pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 { +pub fn success_probability_of_sigma_scale(sigma_scale: f64) -> f64 { // https://en.wikipedia.org/wiki/Error_function#Applications - 1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt()) + statrs::function::erf::erf(sigma_scale / 2_f64.sqrt()) +} + +pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 { + 1.0 - success_probability_of_sigma_scale(sigma_scale) } const LEFT_PADDING_BITS: u64 = 1; diff --git a/concrete-optimizer/src/optimization/atomic_pattern.rs b/concrete-optimizer/src/optimization/atomic_pattern.rs index c6c3fceca..6af0f4689 100644 --- a/concrete-optimizer/src/optimization/atomic_pattern.rs +++ b/concrete-optimizer/src/optimization/atomic_pattern.rs @@ -36,6 +36,7 @@ pub struct Solution { pub complexity: f64, pub noise_max: f64, pub p_error: f64, // error probability + pub global_p_error: f64, } // Constants during optimisation of decompositions @@ -379,6 +380,7 @@ fn update_state_with_best_decompositions( noise_max, complexity, p_error, + global_p_error: f64::NAN, }); } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 948b04dac..7e6874475 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -6,7 +6,7 @@ use crate::dag::unparametrized; use crate::noise_estimator::error; use crate::optimization::config::NoiseBoundConfig; use crate::utils::square; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; // private short convention use DotKind as DK; @@ -107,52 +107,13 @@ pub struct VariancesAndBound { pub precision: Precision, pub safe_variance_bound: f64, pub nb_luts: u64, - // All final variance factor not entering a lut (usually final levelledOp) + // All dominating final variance factor not entering a lut (usually final levelledOp) pub pareto_output: Vec, - // All variance factor entering a lut + // All dominating variance factor entering a lut pub pareto_in_lut: Vec, -} - -impl OperationDag { - pub fn peek_p_error( - &self, - input_noise_out: f64, - blind_rotate_noise_out: f64, - noise_keyswitch: f64, - noise_modulus_switching: f64, - kappa: f64, - ) -> (f64, f64) { - peak_p_error( - self, - input_noise_out, - blind_rotate_noise_out, - noise_keyswitch, - noise_modulus_switching, - kappa, - ) - } - - pub fn feasible( - &self, - input_noise_out: f64, - blind_rotate_noise_out: f64, - noise_keyswitch: f64, - noise_modulus_switching: f64, - ) -> bool { - feasible( - self, - input_noise_out, - blind_rotate_noise_out, - noise_keyswitch, - noise_modulus_switching, - ) - } - - pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 { - let luts_cost = one_lut_cost * (self.nb_luts as f64); - let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension); - luts_cost + levelled_cost - } + // All counted variances for computing exact full dag error probability + pub all_output: Vec<(u64, SymbolicVariance)>, + pub all_in_lut: Vec<(u64, SymbolicVariance)>, } fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape { @@ -294,15 +255,16 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec fn extra_final_variances( dag: &unparametrized::OperationDag, + out_shapes: &[Shape], out_precisions: &[Precision], out_variances: &[SymbolicVariance], -) -> Vec<(Precision, SymbolicVariance)> { +) -> Vec<(Precision, Shape, SymbolicVariance)> { extra_final_values_to_check(dag) .iter() .enumerate() .filter_map(|(i, &is_final)| { if is_final { - Some((out_precisions[i], out_variances[i])) + Some((out_precisions[i], out_shapes[i].clone(), out_variances[i])) } else { None } @@ -312,17 +274,25 @@ fn extra_final_variances( fn in_luts_variance( dag: &unparametrized::OperationDag, + out_shapes: &[Shape], out_precisions: &[Precision], out_variances: &[SymbolicVariance], -) -> Vec<(Precision, SymbolicVariance)> { - let only_luts = |op| { - if let &Op::Lut { input, .. } = op { - Some((out_precisions[input.i], out_variances[input.i])) - } else { - None - } - }; - dag.operators.iter().filter_map(only_luts).collect() +) -> Vec<(Precision, Shape, SymbolicVariance)> { + dag.operators + .iter() + .enumerate() + .filter_map(|(i, op)| { + if let &Op::Lut { input, .. } = op { + Some(( + out_precisions[input.i], + out_shapes[i].clone(), + out_variances[input.i], + )) + } else { + None + } + }) + .collect() } fn op_levelled_complexity( @@ -378,8 +348,8 @@ fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f6 fn constraints_by_precisions( out_precisions: &[Precision], - final_variances: &[(Precision, SymbolicVariance)], - in_luts_variance: &[(Precision, SymbolicVariance)], + final_variances: &[(Precision, Shape, SymbolicVariance)], + in_luts_variance: &[(Precision, Shape, SymbolicVariance)], noise_config: &NoiseBoundConfig, ) -> Vec { let precisions: HashSet = out_precisions.iter().copied().collect(); @@ -397,11 +367,14 @@ fn constraints_by_precisions( precisions.iter().rev().map(to_noise_summary).collect() } -fn select_precision(target_precision: Precision, v: &[(Precision, T)]) -> Vec { +fn select_precision( + target_precision: Precision, + v: &[(Precision, T1, T2)], +) -> Vec<(T1, T2)> { v.iter() - .filter_map(|(p, t)| { + .filter_map(|(p, s, t)| { if *p == target_precision { - Some(*t) + Some((s.clone(), *t)) } else { None } @@ -409,16 +382,41 @@ fn select_precision(target_precision: Precision, v: &[(Precision, T)]) .collect() } +fn counted_symbolic_variance( + symbolic_variances: &[(Shape, SymbolicVariance)], +) -> Vec<(u64, SymbolicVariance)> { + pub fn exact_key(v: &SymbolicVariance) -> (u64, u64) { + (v.lut_coeff.to_bits(), v.input_coeff.to_bits()) + } + let mut count: HashMap<(u64, u64), u64> = HashMap::new(); + for (s, v) in symbolic_variances { + *count.entry(exact_key(v)).or_insert(0) += s.flat_size(); + } + let mut res = Vec::new(); + res.reserve_exact(count.len()); + for (_s, v) in symbolic_variances { + if let Some(c) = count.remove(&exact_key(v)) { + res.push((c, *v)); + } + } + res +} + fn constraint_for_one_precision( target_precision: Precision, - extra_final_variances: &[(Precision, SymbolicVariance)], - in_luts_variance: &[(Precision, SymbolicVariance)], + extra_final_variances: &[(Precision, Shape, SymbolicVariance)], + in_luts_variance: &[(Precision, Shape, SymbolicVariance)], safe_noise_bound: f64, ) -> VariancesAndBound { - let extra_final_variances = select_precision(target_precision, extra_final_variances); + let extra_finals_variance = select_precision(target_precision, extra_final_variances); let in_luts_variance = select_precision(target_precision, in_luts_variance); let nb_luts = in_luts_variance.len() as u64; - let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_final_variances); + let all_output = counted_symbolic_variance(&extra_finals_variance); + let all_in_lut = counted_symbolic_variance(&in_luts_variance); + let remove_shape = |t: &(Shape, SymbolicVariance)| t.1; + let extra_finals_variance = extra_finals_variance.iter().map(remove_shape).collect(); + let in_luts_variance = in_luts_variance.iter().map(remove_shape).collect(); + let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_finals_variance); let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance); VariancesAndBound { precision: target_precision, @@ -426,6 +424,8 @@ fn constraint_for_one_precision( nb_luts, pareto_output: pareto_vfs_final, pareto_in_lut: pareto_vfs_in_lut, + all_output, + all_in_lut, } } @@ -434,10 +434,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 { let out_shapes = out_shapes(dag); let out_precisions = out_precisions(dag); let out_variances = out_variances(dag, &out_shapes); - let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances); + let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances); let coeffs = in_luts_variance .iter() - .map(|(_precision, symbolic_variance)| { + .map(|(_precision, _shape, symbolic_variance)| { symbolic_variance.lut_coeff + symbolic_variance.input_coeff }) .filter(|v| *v >= 1.0); @@ -445,6 +445,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 { worst.log2() } +pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 { + lut_count(dag, &out_shapes(dag)) +} + pub fn analyze( dag: &unparametrized::OperationDag, noise_config: &NoiseBoundConfig, @@ -453,9 +457,10 @@ pub fn analyze( let out_shapes = out_shapes(dag); let out_precisions = out_precisions(dag); let out_variances = out_variances(dag, &out_shapes); - let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances); + let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances); let nb_luts = lut_count(dag, &out_shapes); - let extra_final_variances = extra_final_variances(dag, &out_precisions, &out_variances); + let extra_final_variances = + extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances); let levelled_complexity = levelled_complexity(dag, &out_shapes); let constraints_by_precisions = constraints_by_precisions( &out_precisions, @@ -544,48 +549,111 @@ fn peak_relative_variance( (max_relative_var, safe_noise) } -fn peak_p_error( - dag: &OperationDag, +fn p_success_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 { + let sigma_scale = kappa / relative_variance.sqrt(); + error::success_probability_of_sigma_scale(sigma_scale) +} + +fn p_success_per_constraint( + constraint: &VariancesAndBound, input_noise_out: f64, blind_rotate_noise_out: f64, noise_keyswitch: f64, noise_modulus_switching: f64, kappa: f64, -) -> (f64, f64) { - let (relative_var, variance_bound) = peak_relative_variance( - dag, - input_noise_out, - blind_rotate_noise_out, - noise_keyswitch, - noise_modulus_switching, - ); - let sigma_scale = kappa / relative_var.sqrt(); - ( - error::error_probability_of_sigma_scale(sigma_scale), - relative_var * variance_bound, - ) +) -> f64 { + // Note: no log probability to keep accuracy near 0, 0 is a fine answer when p_success is very small. + let mut p_success = 1.0; + for &(count, vf) in &constraint.all_output { + assert!(0 < count); + let variance = vf.eval(input_noise_out, blind_rotate_noise_out); + let relative_variance = variance / constraint.safe_variance_bound; + let vf_p_success = p_success_from_relative_variance(relative_variance, kappa); + p_success *= vf_p_success.powi(count as i32); + } + // the maximal variance encountered during a lut computation + for &(count, vf) in &constraint.all_in_lut { + assert!(0 < count); + let variance = vf.eval(input_noise_out, blind_rotate_noise_out); + let relative_variance = + (variance + noise_keyswitch + noise_modulus_switching) / constraint.safe_variance_bound; + let vf_p_success = p_success_from_relative_variance(relative_variance, kappa); + p_success *= vf_p_success.powi(count as i32); + } + p_success } -fn feasible( - dag: &OperationDag, - input_noise_out: f64, - blind_rotate_noise_out: f64, - noise_keyswitch: f64, - noise_modulus_switching: f64, -) -> bool { - for ns in &dag.constraints_by_precisions { - if peak_variance_per_constraint( - ns, +impl OperationDag { + pub fn peek_p_error( + &self, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, + kappa: f64, + ) -> (f64, f64) { + let (relative_var, variance_bound) = peak_relative_variance( + self, input_noise_out, blind_rotate_noise_out, noise_keyswitch, noise_modulus_switching, - ) > ns.safe_variance_bound - { - return false; - } + ); + ( + 1.0 - p_success_from_relative_variance(relative_var, kappa), + relative_var * variance_bound, + ) + } + pub fn global_p_error( + &self, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, + kappa: f64, + ) -> f64 { + let mut p_success = 1.0; + for ns in &self.constraints_by_precisions { + p_success *= p_success_per_constraint( + ns, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, + kappa, + ); + } + assert!(0.0 <= p_success && p_success <= 1.0); + 1.0 - p_success + } + + pub fn feasible( + &self, + input_noise_out: f64, + blind_rotate_noise_out: f64, + noise_keyswitch: f64, + noise_modulus_switching: f64, + ) -> bool { + for ns in &self.constraints_by_precisions { + if peak_variance_per_constraint( + ns, + input_noise_out, + blind_rotate_noise_out, + noise_keyswitch, + noise_modulus_switching, + ) > ns.safe_variance_bound + { + return false; + } + } + true + } + + pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 { + let luts_cost = one_lut_cost * (self.nb_luts as f64); + let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension); + luts_cost + levelled_cost } - true } #[cfg(test)] diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 735ad6a32..2082f5b2f 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -79,6 +79,8 @@ fn update_best_solution_with_best_decompositions( ) .get_variance(); + let mut best_br_noise = f64::INFINITY; + let mut best_ks_noise = f64::INFINITY; let mut best_br_i = 0; let mut best_ks_i = 0; let mut update_best_solution = false; @@ -154,6 +156,8 @@ fn update_best_solution_with_best_decompositions( best_complexity = complexity; best_p_error = peek_p_error; best_variance = variance; + best_br_noise = br_quantity.noise; + best_ks_noise = ks_quantity.noise; best_br_i = br_quantity.index; best_ks_i = ks_quantity.index; } @@ -180,6 +184,13 @@ fn update_best_solution_with_best_decompositions( br_decomposition_base_log: br_b, complexity: best_complexity, p_error: best_p_error, + global_p_error: dag.global_p_error( + input_noise_out, + best_br_noise, + best_ks_noise, + noise_modulus_switching, + consts.kappa, + ), noise_max: best_variance, }); } @@ -277,7 +288,9 @@ pub fn optimize( if let Some(sol) = state.best_solution { assert!(0.0 <= sol.p_error && sol.p_error <= 1.0); + assert!(0.0 <= sol.global_p_error && sol.global_p_error <= 1.0); assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA); + assert!(sol.p_error <= sol.global_p_error * REL_EPSILON_PROBA); } state @@ -320,6 +333,7 @@ pub fn optimize_v0( state } +#[allow(clippy::unnecessary_cast)] // unecessary warning on 'as Precision' #[cfg(test)] mod tests { use std::time::Instant; @@ -336,8 +350,9 @@ mod tests { } impl Solution { - fn assert_same(&self, other: Self) -> bool { + fn assert_same_pbs_solution(&self, other: Self) -> bool { let mut other = other; + other.global_p_error = self.global_p_error; if small_relative_diff(self.noise_max, other.noise_max) && small_relative_diff(self.p_error, other.p_error) { @@ -444,7 +459,10 @@ mod tests { } let sol = state.best_solution.unwrap(); let sol_ref = state_ref.best_solution.unwrap(); - assert!(sol.assert_same(sol_ref)); + assert!(sol.assert_same_pbs_solution(sol_ref)); + assert!(!sol.global_p_error.is_nan()); + assert!(sol.p_error <= sol.global_p_error); + assert!(sol.global_p_error <= 1.0); } #[test] @@ -508,7 +526,10 @@ mod tests { let sol = state.best_solution.unwrap(); let mut sol_ref = state_ref.best_solution.unwrap(); sol_ref.complexity *= 2.0 /* number of luts */; - assert!(sol.assert_same(sol_ref)); + assert!(sol.assert_same_pbs_solution(sol_ref)); + assert!(!sol.global_p_error.is_nan()); + assert!(sol.p_error <= sol.global_p_error); + assert!(sol.global_p_error <= 1.0); } fn no_lut_vs_lut(precision: Precision) { @@ -619,10 +640,13 @@ mod tests { let mut sol_multi = state_multi.best_solution.unwrap(); sol_multi.complexity /= 2.0; if sol_low.complexity < sol_high.complexity { - assert!(sol_high.assert_same(sol_multi)); + assert!(sol_high.assert_same_pbs_solution(sol_multi)); Some(true) } else { - assert!(sol_low.complexity < sol_multi.complexity || sol_low.assert_same(sol_multi)); + assert!( + sol_low.complexity < sol_multi.complexity + || sol_low.assert_same_pbs_solution(sol_multi) + ); Some(false) } } @@ -643,4 +667,151 @@ mod tests { prev = current; } } + + fn local_to_approx_global_p_error(local_p_error: f64, nb_pbs: u64) -> f64 { + #[allow(clippy::float_cmp)] + if local_p_error == 1f64 { + return 1.0; + } + #[allow(clippy::float_cmp)] + if local_p_error == 0f64 { + return 0.0; + } + let local_p_success = 1.0 - local_p_error; + assert!(local_p_success < 1.0); + let p_success = local_p_success.powi(nb_pbs as i32); + assert!(p_success < 1.0); + assert!(0.0 < p_success); + 1.0 - p_success + } + + #[test] + fn test_global_p_error_input() { + for precision in [4_u8, 8] { + for weight in [1, 3, 27, 243, 729] { + for dim in [1, 2, 16, 32] { + let _ = check_global_p_error_input(dim, weight, precision); + } + } + } + } + + fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 { + let shape = Shape::vector(dim); + let weights = Weights::number(weight); + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(precision as u8, shape); + let _dot1 = dag.add_dot([input1], weights); // this is just several multiply + let state = optimize(&dag); + let sol = state.best_solution.unwrap(); + let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim); + approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim); + sol.global_p_error + } + + #[test] + fn test_global_p_error_lut() { + for precision in [4_u8, 8] { + for weight in [1, 3, 27, 243, 729] { + for depth in [2, 16, 32] { + check_global_p_error_lut(depth, weight, precision); + } + } + } + } + + fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) { + let shape = Shape::number(); + let weights = Weights::number(weight); + let mut dag = unparametrized::OperationDag::new(); + let mut last_val = dag.add_input(precision as u8, shape); + for _i in 0..depth { + let dot = dag.add_dot([last_val], &weights); + last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision); + } + let state = optimize(&dag); + let sol = state.best_solution.unwrap(); + // the first lut on input has reduced impact on error probability + let lower_nb_dominating_lut = depth - 1; + let lower_global_p_error = + local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut); + let higher_global_p_error = + local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut + 1); + assert!(lower_global_p_error <= sol.global_p_error); + assert!(sol.global_p_error <= higher_global_p_error); + } + + fn dag_2_precisions_lut_chain( + depth: u64, + precision_low: Precision, + precision_high: Precision, + weight_low: u64, + weight_high: u64, + ) -> unparametrized::OperationDag { + let shape = Shape::number(); + let mut dag = unparametrized::OperationDag::new(); + let weights_low = Weights::number(weight_low); + let weights_high = Weights::number(weight_high); + let mut last_val_low = dag.add_input(precision_low as u8, &shape); + let mut last_val_high = dag.add_input(precision_high as u8, &shape); + for _i in 0..depth { + let dot_low = dag.add_dot([last_val_low], &weights_low); + last_val_low = dag.add_lut(dot_low, FunctionTable::UNKWOWN, precision_low); + let dot_high = dag.add_dot([last_val_high], &weights_high); + last_val_high = dag.add_lut(dot_high, FunctionTable::UNKWOWN, precision_high); + } + dag + } + + #[test] + fn test_global_p_error_dominating_lut() { + let depth = 128; + let weights_low = 1; + let weights_high = 1; + let precision_low = 6 as Precision; + let precision_high = 8 as Precision; + let dag = dag_2_precisions_lut_chain( + depth, + precision_low, + precision_high, + weights_low, + weights_high, + ); + let sol = optimize(&dag).best_solution.unwrap(); + // the 2 first luts and low precision/weight luts have little impact on error probability + let nb_dominating_lut = depth - 1; + let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); + // errors rate is approximated accurately + approx::assert_relative_eq!( + sol.global_p_error, + approx_global_p_error, + max_relative = 1e-01 + ); + } + + #[test] + fn test_global_p_error_non_dominating_lut() { + let depth = 128; + let weights_low = 1024 * 1024 * 3; + let weights_high = 1; + let precision_low = 6 as Precision; + let precision_high = 8 as Precision; + let dag = dag_2_precisions_lut_chain( + depth, + precision_low, + precision_high, + weights_low, + weights_high, + ); + let sol = optimize(&dag).best_solution.unwrap(); + // all intern luts have an impact on error probability almost equaly + let nb_dominating_lut = (2 * depth) - 1; + let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); + // errors rate is approximated accurately + approx::assert_relative_eq!( + sol.global_p_error, + approx_global_p_error, + max_relative = 0.05 + ); + } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index 70e6fab73..24b1fd744 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -34,6 +34,14 @@ fn max_precision(dag: &OperationDag) -> Precision { .unwrap_or(0) } +fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution { + let global_p_error = 1.0 - (1.0 - sol.p_error).powi(nb_luts as i32); + WopSolution { + global_p_error, + ..sol + } +} + pub fn optimize( dag: &OperationDag, security_level: u64, @@ -61,6 +69,7 @@ pub fn optimize( let fallback_16b_precision = 16; let default_log_norm = default_log_norm2_woppbs; let worst_log_norm = analyze::worst_log_norm(dag); + let nb_luts = analyze::lut_count_from_dag(dag); let log_norm = default_log_norm.min(worst_log_norm); let opt_sol = wop_optimize::( fallback_16b_precision, @@ -72,6 +81,6 @@ pub fn optimize( internal_lwe_dimensions, ) .best_solution; - opt_sol.map(Solution::WopSolution) + opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol))) } } diff --git a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs index eb77d121a..f9bc2ed62 100644 --- a/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs +++ b/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs @@ -65,6 +65,7 @@ pub struct Solution { pub complexity: f64, pub noise_max: f64, pub p_error: f64, + pub global_p_error: f64, // error probability pub cb_decomposition_level_count: u64, pub cb_decomposition_base_log: u64, @@ -84,6 +85,7 @@ impl Solution { complexity: 0., noise_max: 0.0, p_error: 0.0, + global_p_error: 0.0, cb_decomposition_level_count: 0, cb_decomposition_base_log: 0, } @@ -104,6 +106,7 @@ impl From for atomic_pattern::Solution { complexity: sol.complexity, noise_max: sol.noise_max, p_error: sol.p_error, + global_p_error: sol.global_p_error, } } } @@ -431,6 +434,7 @@ fn update_state_with_best_decompositions( noise_max: variance_max, complexity, p_error, + global_p_error: f64::NAN, cb_decomposition_level_count: circuit_pbs_decomposition_parameter.level, cb_decomposition_base_log: circuit_pbs_decomposition_parameter.log2_base, });