feat(optimizer): improve repeat_p_error speed

This commit is contained in:
Mayeul@Zama
2022-09-22 16:14:52 +02:00
committed by mayeul-zama
parent 48962811b9
commit 8ec0d4f3bd

View File

@@ -4,6 +4,45 @@ pub fn combine_errors(p_error1: f64, p_error2: f64) -> f64 {
}
pub fn repeat_p_error(p_error: f64, count: u64) -> f64 {
if p_error * count as f64 > 1. {
iterative_repeat_p_error(p_error, count)
} else {
binomial_decomposition_repeat_p_error(p_error, count)
}
}
// (1 - global_p_error) = (1 - p_error)^count
// global_p_error = 1 - (1-p)^N = 1 - (1 - N p + N(N-1)/2 p^2 - N(N-1)(N-2)/(2*3) p^3...)
// global_p_error = N p - N(N-1)/2 p^2 + N(N-1)(N-2)/(2*3) p^3...
fn binomial_decomposition_repeat_p_error(p_error: f64, count: u64) -> f64 {
// This guarantees abs(factor) is decreasing
// Without that, factors grow and lose precision
assert!(p_error * (count as f64) <= 1.0);
let mut global_p_error = 0.0;
let mut factor = -1.0;
for i in 1..=count {
factor *= -p_error * (count - i + 1) as f64 / i as f64;
let new_global_p_error = global_p_error + factor;
#[allow(clippy::float_cmp)]
//if factor is too small to make a difference
if new_global_p_error == global_p_error {
// abs(factor) is decreasing and factor sign alternates
// so the remaining series is bounded (in absolute value) by abs(factor) which makes no difference
break;
}
global_p_error = new_global_p_error;
}
global_p_error
}
fn iterative_repeat_p_error(p_error: f64, count: u64) -> f64 {
let mut global_p_error = 0.0;
for _ in 0..count {
@@ -12,3 +51,29 @@ pub fn repeat_p_error(p_error: f64, count: u64) -> f64 {
global_p_error
}
#[cfg(test)]
mod tests {
use super::*;
#[allow(clippy::float_cmp)]
fn assert_eq_both_repeat_p_error(p_error: f64, count: u64) {
let iterative = iterative_repeat_p_error(p_error, count);
let binomial = binomial_decomposition_repeat_p_error(p_error, count);
assert!(((iterative - binomial) / (iterative + binomial)).abs() < 0.000_000_1);
}
#[test]
#[allow(clippy::float_cmp)]
fn test_repeat_p_error() {
assert_eq!(repeat_p_error(0.5, 1), 0.5);
assert_eq!(repeat_p_error(0.5, 2), 0.75);
assert_eq_both_repeat_p_error(0.00001, 10000);
assert_eq_both_repeat_p_error(0.001, 100);
assert_eq_both_repeat_p_error(0.000_000_000_01, 100);
}
}