mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(dag): compute global p-error after the local one is optimized
Resolves zama-ai/products#302
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#![allow(clippy::cast_possible_truncation)] // u64 to usize
|
||||
#![allow(clippy::inline_always)] // needed by delegate
|
||||
#![allow(clippy::match_wildcard_for_single_variants)]
|
||||
#![allow(clippy::manual_range_contains)]
|
||||
#![allow(clippy::missing_panics_doc)]
|
||||
#![allow(clippy::missing_const_for_fn)]
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
@@ -9,9 +9,13 @@ pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 {
|
||||
statrs::function::erf::erf_inv(p_in) * 2_f64.sqrt()
|
||||
}
|
||||
|
||||
pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
|
||||
pub fn success_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
|
||||
// https://en.wikipedia.org/wiki/Error_function#Applications
|
||||
1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
|
||||
statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
|
||||
}
|
||||
|
||||
pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
|
||||
1.0 - success_probability_of_sigma_scale(sigma_scale)
|
||||
}
|
||||
|
||||
const LEFT_PADDING_BITS: u64 = 1;
|
||||
|
||||
@@ -36,6 +36,7 @@ pub struct Solution {
|
||||
pub complexity: f64,
|
||||
pub noise_max: f64,
|
||||
pub p_error: f64, // error probability
|
||||
pub global_p_error: f64,
|
||||
}
|
||||
|
||||
// Constants during optimisation of decompositions
|
||||
@@ -379,6 +380,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
|
||||
noise_max,
|
||||
complexity,
|
||||
p_error,
|
||||
global_p_error: f64::NAN,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::dag::unparametrized;
|
||||
use crate::noise_estimator::error;
|
||||
use crate::optimization::config::NoiseBoundConfig;
|
||||
use crate::utils::square;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
// private short convention
|
||||
use DotKind as DK;
|
||||
@@ -107,52 +107,13 @@ pub struct VariancesAndBound {
|
||||
pub precision: Precision,
|
||||
pub safe_variance_bound: f64,
|
||||
pub nb_luts: u64,
|
||||
// All final variance factor not entering a lut (usually final levelledOp)
|
||||
// All dominating final variance factor not entering a lut (usually final levelledOp)
|
||||
pub pareto_output: Vec<SymbolicVariance>,
|
||||
// All variance factor entering a lut
|
||||
// All dominating variance factor entering a lut
|
||||
pub pareto_in_lut: Vec<SymbolicVariance>,
|
||||
}
|
||||
|
||||
impl OperationDag {
|
||||
pub fn peek_p_error(
|
||||
&self,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
kappa: f64,
|
||||
) -> (f64, f64) {
|
||||
peak_p_error(
|
||||
self,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
kappa,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn feasible(
|
||||
&self,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
) -> bool {
|
||||
feasible(
|
||||
self,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
)
|
||||
}
|
||||
|
||||
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
|
||||
let luts_cost = one_lut_cost * (self.nb_luts as f64);
|
||||
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
|
||||
luts_cost + levelled_cost
|
||||
}
|
||||
// All counted variances for computing exact full dag error probability
|
||||
pub all_output: Vec<(u64, SymbolicVariance)>,
|
||||
pub all_in_lut: Vec<(u64, SymbolicVariance)>,
|
||||
}
|
||||
|
||||
fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
|
||||
@@ -294,15 +255,16 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
|
||||
|
||||
fn extra_final_variances(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
out_precisions: &[Precision],
|
||||
out_variances: &[SymbolicVariance],
|
||||
) -> Vec<(Precision, SymbolicVariance)> {
|
||||
) -> Vec<(Precision, Shape, SymbolicVariance)> {
|
||||
extra_final_values_to_check(dag)
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, &is_final)| {
|
||||
if is_final {
|
||||
Some((out_precisions[i], out_variances[i]))
|
||||
Some((out_precisions[i], out_shapes[i].clone(), out_variances[i]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -312,17 +274,25 @@ fn extra_final_variances(
|
||||
|
||||
fn in_luts_variance(
|
||||
dag: &unparametrized::OperationDag,
|
||||
out_shapes: &[Shape],
|
||||
out_precisions: &[Precision],
|
||||
out_variances: &[SymbolicVariance],
|
||||
) -> Vec<(Precision, SymbolicVariance)> {
|
||||
let only_luts = |op| {
|
||||
if let &Op::Lut { input, .. } = op {
|
||||
Some((out_precisions[input.i], out_variances[input.i]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
dag.operators.iter().filter_map(only_luts).collect()
|
||||
) -> Vec<(Precision, Shape, SymbolicVariance)> {
|
||||
dag.operators
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, op)| {
|
||||
if let &Op::Lut { input, .. } = op {
|
||||
Some((
|
||||
out_precisions[input.i],
|
||||
out_shapes[i].clone(),
|
||||
out_variances[input.i],
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn op_levelled_complexity(
|
||||
@@ -378,8 +348,8 @@ fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f6
|
||||
|
||||
fn constraints_by_precisions(
|
||||
out_precisions: &[Precision],
|
||||
final_variances: &[(Precision, SymbolicVariance)],
|
||||
in_luts_variance: &[(Precision, SymbolicVariance)],
|
||||
final_variances: &[(Precision, Shape, SymbolicVariance)],
|
||||
in_luts_variance: &[(Precision, Shape, SymbolicVariance)],
|
||||
noise_config: &NoiseBoundConfig,
|
||||
) -> Vec<VariancesAndBound> {
|
||||
let precisions: HashSet<Precision> = out_precisions.iter().copied().collect();
|
||||
@@ -397,11 +367,14 @@ fn constraints_by_precisions(
|
||||
precisions.iter().rev().map(to_noise_summary).collect()
|
||||
}
|
||||
|
||||
fn select_precision<T: Copy>(target_precision: Precision, v: &[(Precision, T)]) -> Vec<T> {
|
||||
fn select_precision<T1: Clone, T2: Copy>(
|
||||
target_precision: Precision,
|
||||
v: &[(Precision, T1, T2)],
|
||||
) -> Vec<(T1, T2)> {
|
||||
v.iter()
|
||||
.filter_map(|(p, t)| {
|
||||
.filter_map(|(p, s, t)| {
|
||||
if *p == target_precision {
|
||||
Some(*t)
|
||||
Some((s.clone(), *t))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -409,16 +382,41 @@ fn select_precision<T: Copy>(target_precision: Precision, v: &[(Precision, T)])
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn counted_symbolic_variance(
|
||||
symbolic_variances: &[(Shape, SymbolicVariance)],
|
||||
) -> Vec<(u64, SymbolicVariance)> {
|
||||
pub fn exact_key(v: &SymbolicVariance) -> (u64, u64) {
|
||||
(v.lut_coeff.to_bits(), v.input_coeff.to_bits())
|
||||
}
|
||||
let mut count: HashMap<(u64, u64), u64> = HashMap::new();
|
||||
for (s, v) in symbolic_variances {
|
||||
*count.entry(exact_key(v)).or_insert(0) += s.flat_size();
|
||||
}
|
||||
let mut res = Vec::new();
|
||||
res.reserve_exact(count.len());
|
||||
for (_s, v) in symbolic_variances {
|
||||
if let Some(c) = count.remove(&exact_key(v)) {
|
||||
res.push((c, *v));
|
||||
}
|
||||
}
|
||||
res
|
||||
}
|
||||
|
||||
fn constraint_for_one_precision(
|
||||
target_precision: Precision,
|
||||
extra_final_variances: &[(Precision, SymbolicVariance)],
|
||||
in_luts_variance: &[(Precision, SymbolicVariance)],
|
||||
extra_final_variances: &[(Precision, Shape, SymbolicVariance)],
|
||||
in_luts_variance: &[(Precision, Shape, SymbolicVariance)],
|
||||
safe_noise_bound: f64,
|
||||
) -> VariancesAndBound {
|
||||
let extra_final_variances = select_precision(target_precision, extra_final_variances);
|
||||
let extra_finals_variance = select_precision(target_precision, extra_final_variances);
|
||||
let in_luts_variance = select_precision(target_precision, in_luts_variance);
|
||||
let nb_luts = in_luts_variance.len() as u64;
|
||||
let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_final_variances);
|
||||
let all_output = counted_symbolic_variance(&extra_finals_variance);
|
||||
let all_in_lut = counted_symbolic_variance(&in_luts_variance);
|
||||
let remove_shape = |t: &(Shape, SymbolicVariance)| t.1;
|
||||
let extra_finals_variance = extra_finals_variance.iter().map(remove_shape).collect();
|
||||
let in_luts_variance = in_luts_variance.iter().map(remove_shape).collect();
|
||||
let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_finals_variance);
|
||||
let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance);
|
||||
VariancesAndBound {
|
||||
precision: target_precision,
|
||||
@@ -426,6 +424,8 @@ fn constraint_for_one_precision(
|
||||
nb_luts,
|
||||
pareto_output: pareto_vfs_final,
|
||||
pareto_in_lut: pareto_vfs_in_lut,
|
||||
all_output,
|
||||
all_in_lut,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -434,10 +434,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
|
||||
let out_shapes = out_shapes(dag);
|
||||
let out_precisions = out_precisions(dag);
|
||||
let out_variances = out_variances(dag, &out_shapes);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let coeffs = in_luts_variance
|
||||
.iter()
|
||||
.map(|(_precision, symbolic_variance)| {
|
||||
.map(|(_precision, _shape, symbolic_variance)| {
|
||||
symbolic_variance.lut_coeff + symbolic_variance.input_coeff
|
||||
})
|
||||
.filter(|v| *v >= 1.0);
|
||||
@@ -445,6 +445,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
|
||||
worst.log2()
|
||||
}
|
||||
|
||||
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
|
||||
lut_count(dag, &out_shapes(dag))
|
||||
}
|
||||
|
||||
pub fn analyze(
|
||||
dag: &unparametrized::OperationDag,
|
||||
noise_config: &NoiseBoundConfig,
|
||||
@@ -453,9 +457,10 @@ pub fn analyze(
|
||||
let out_shapes = out_shapes(dag);
|
||||
let out_precisions = out_precisions(dag);
|
||||
let out_variances = out_variances(dag, &out_shapes);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances);
|
||||
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let nb_luts = lut_count(dag, &out_shapes);
|
||||
let extra_final_variances = extra_final_variances(dag, &out_precisions, &out_variances);
|
||||
let extra_final_variances =
|
||||
extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances);
|
||||
let levelled_complexity = levelled_complexity(dag, &out_shapes);
|
||||
let constraints_by_precisions = constraints_by_precisions(
|
||||
&out_precisions,
|
||||
@@ -544,48 +549,111 @@ fn peak_relative_variance(
|
||||
(max_relative_var, safe_noise)
|
||||
}
|
||||
|
||||
fn peak_p_error(
|
||||
dag: &OperationDag,
|
||||
fn p_success_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
|
||||
let sigma_scale = kappa / relative_variance.sqrt();
|
||||
error::success_probability_of_sigma_scale(sigma_scale)
|
||||
}
|
||||
|
||||
fn p_success_per_constraint(
|
||||
constraint: &VariancesAndBound,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
kappa: f64,
|
||||
) -> (f64, f64) {
|
||||
let (relative_var, variance_bound) = peak_relative_variance(
|
||||
dag,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
);
|
||||
let sigma_scale = kappa / relative_var.sqrt();
|
||||
(
|
||||
error::error_probability_of_sigma_scale(sigma_scale),
|
||||
relative_var * variance_bound,
|
||||
)
|
||||
) -> f64 {
|
||||
// Note: no log probability to keep accuracy near 0, 0 is a fine answer when p_success is very small.
|
||||
let mut p_success = 1.0;
|
||||
for &(count, vf) in &constraint.all_output {
|
||||
assert!(0 < count);
|
||||
let variance = vf.eval(input_noise_out, blind_rotate_noise_out);
|
||||
let relative_variance = variance / constraint.safe_variance_bound;
|
||||
let vf_p_success = p_success_from_relative_variance(relative_variance, kappa);
|
||||
p_success *= vf_p_success.powi(count as i32);
|
||||
}
|
||||
// the maximal variance encountered during a lut computation
|
||||
for &(count, vf) in &constraint.all_in_lut {
|
||||
assert!(0 < count);
|
||||
let variance = vf.eval(input_noise_out, blind_rotate_noise_out);
|
||||
let relative_variance =
|
||||
(variance + noise_keyswitch + noise_modulus_switching) / constraint.safe_variance_bound;
|
||||
let vf_p_success = p_success_from_relative_variance(relative_variance, kappa);
|
||||
p_success *= vf_p_success.powi(count as i32);
|
||||
}
|
||||
p_success
|
||||
}
|
||||
|
||||
fn feasible(
|
||||
dag: &OperationDag,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
) -> bool {
|
||||
for ns in &dag.constraints_by_precisions {
|
||||
if peak_variance_per_constraint(
|
||||
ns,
|
||||
impl OperationDag {
|
||||
pub fn peek_p_error(
|
||||
&self,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
kappa: f64,
|
||||
) -> (f64, f64) {
|
||||
let (relative_var, variance_bound) = peak_relative_variance(
|
||||
self,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
) > ns.safe_variance_bound
|
||||
{
|
||||
return false;
|
||||
}
|
||||
);
|
||||
(
|
||||
1.0 - p_success_from_relative_variance(relative_var, kappa),
|
||||
relative_var * variance_bound,
|
||||
)
|
||||
}
|
||||
pub fn global_p_error(
|
||||
&self,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
kappa: f64,
|
||||
) -> f64 {
|
||||
let mut p_success = 1.0;
|
||||
for ns in &self.constraints_by_precisions {
|
||||
p_success *= p_success_per_constraint(
|
||||
ns,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
kappa,
|
||||
);
|
||||
}
|
||||
assert!(0.0 <= p_success && p_success <= 1.0);
|
||||
1.0 - p_success
|
||||
}
|
||||
|
||||
pub fn feasible(
|
||||
&self,
|
||||
input_noise_out: f64,
|
||||
blind_rotate_noise_out: f64,
|
||||
noise_keyswitch: f64,
|
||||
noise_modulus_switching: f64,
|
||||
) -> bool {
|
||||
for ns in &self.constraints_by_precisions {
|
||||
if peak_variance_per_constraint(
|
||||
ns,
|
||||
input_noise_out,
|
||||
blind_rotate_noise_out,
|
||||
noise_keyswitch,
|
||||
noise_modulus_switching,
|
||||
) > ns.safe_variance_bound
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
|
||||
let luts_cost = one_lut_cost * (self.nb_luts as f64);
|
||||
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
|
||||
luts_cost + levelled_cost
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
||||
@@ -79,6 +79,8 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
)
|
||||
.get_variance();
|
||||
|
||||
let mut best_br_noise = f64::INFINITY;
|
||||
let mut best_ks_noise = f64::INFINITY;
|
||||
let mut best_br_i = 0;
|
||||
let mut best_ks_i = 0;
|
||||
let mut update_best_solution = false;
|
||||
@@ -154,6 +156,8 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
best_complexity = complexity;
|
||||
best_p_error = peek_p_error;
|
||||
best_variance = variance;
|
||||
best_br_noise = br_quantity.noise;
|
||||
best_ks_noise = ks_quantity.noise;
|
||||
best_br_i = br_quantity.index;
|
||||
best_ks_i = ks_quantity.index;
|
||||
}
|
||||
@@ -180,6 +184,13 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
br_decomposition_base_log: br_b,
|
||||
complexity: best_complexity,
|
||||
p_error: best_p_error,
|
||||
global_p_error: dag.global_p_error(
|
||||
input_noise_out,
|
||||
best_br_noise,
|
||||
best_ks_noise,
|
||||
noise_modulus_switching,
|
||||
consts.kappa,
|
||||
),
|
||||
noise_max: best_variance,
|
||||
});
|
||||
}
|
||||
@@ -277,7 +288,9 @@ pub fn optimize<W: UnsignedInteger>(
|
||||
|
||||
if let Some(sol) = state.best_solution {
|
||||
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
|
||||
assert!(0.0 <= sol.global_p_error && sol.global_p_error <= 1.0);
|
||||
assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA);
|
||||
assert!(sol.p_error <= sol.global_p_error * REL_EPSILON_PROBA);
|
||||
}
|
||||
|
||||
state
|
||||
@@ -320,6 +333,7 @@ pub fn optimize_v0<W: UnsignedInteger>(
|
||||
state
|
||||
}
|
||||
|
||||
#[allow(clippy::unnecessary_cast)] // unecessary warning on 'as Precision'
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Instant;
|
||||
@@ -336,8 +350,9 @@ mod tests {
|
||||
}
|
||||
|
||||
impl Solution {
|
||||
fn assert_same(&self, other: Self) -> bool {
|
||||
fn assert_same_pbs_solution(&self, other: Self) -> bool {
|
||||
let mut other = other;
|
||||
other.global_p_error = self.global_p_error;
|
||||
if small_relative_diff(self.noise_max, other.noise_max)
|
||||
&& small_relative_diff(self.p_error, other.p_error)
|
||||
{
|
||||
@@ -444,7 +459,10 @@ mod tests {
|
||||
}
|
||||
let sol = state.best_solution.unwrap();
|
||||
let sol_ref = state_ref.best_solution.unwrap();
|
||||
assert!(sol.assert_same(sol_ref));
|
||||
assert!(sol.assert_same_pbs_solution(sol_ref));
|
||||
assert!(!sol.global_p_error.is_nan());
|
||||
assert!(sol.p_error <= sol.global_p_error);
|
||||
assert!(sol.global_p_error <= 1.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -508,7 +526,10 @@ mod tests {
|
||||
let sol = state.best_solution.unwrap();
|
||||
let mut sol_ref = state_ref.best_solution.unwrap();
|
||||
sol_ref.complexity *= 2.0 /* number of luts */;
|
||||
assert!(sol.assert_same(sol_ref));
|
||||
assert!(sol.assert_same_pbs_solution(sol_ref));
|
||||
assert!(!sol.global_p_error.is_nan());
|
||||
assert!(sol.p_error <= sol.global_p_error);
|
||||
assert!(sol.global_p_error <= 1.0);
|
||||
}
|
||||
|
||||
fn no_lut_vs_lut(precision: Precision) {
|
||||
@@ -619,10 +640,13 @@ mod tests {
|
||||
let mut sol_multi = state_multi.best_solution.unwrap();
|
||||
sol_multi.complexity /= 2.0;
|
||||
if sol_low.complexity < sol_high.complexity {
|
||||
assert!(sol_high.assert_same(sol_multi));
|
||||
assert!(sol_high.assert_same_pbs_solution(sol_multi));
|
||||
Some(true)
|
||||
} else {
|
||||
assert!(sol_low.complexity < sol_multi.complexity || sol_low.assert_same(sol_multi));
|
||||
assert!(
|
||||
sol_low.complexity < sol_multi.complexity
|
||||
|| sol_low.assert_same_pbs_solution(sol_multi)
|
||||
);
|
||||
Some(false)
|
||||
}
|
||||
}
|
||||
@@ -643,4 +667,151 @@ mod tests {
|
||||
prev = current;
|
||||
}
|
||||
}
|
||||
|
||||
fn local_to_approx_global_p_error(local_p_error: f64, nb_pbs: u64) -> f64 {
|
||||
#[allow(clippy::float_cmp)]
|
||||
if local_p_error == 1f64 {
|
||||
return 1.0;
|
||||
}
|
||||
#[allow(clippy::float_cmp)]
|
||||
if local_p_error == 0f64 {
|
||||
return 0.0;
|
||||
}
|
||||
let local_p_success = 1.0 - local_p_error;
|
||||
assert!(local_p_success < 1.0);
|
||||
let p_success = local_p_success.powi(nb_pbs as i32);
|
||||
assert!(p_success < 1.0);
|
||||
assert!(0.0 < p_success);
|
||||
1.0 - p_success
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_input() {
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for dim in [1, 2, 16, 32] {
|
||||
let _ = check_global_p_error_input(dim, weight, precision);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 {
|
||||
let shape = Shape::vector(dim);
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let input1 = dag.add_input(precision as u8, shape);
|
||||
let _dot1 = dag.add_dot([input1], weights); // this is just several multiply
|
||||
let state = optimize(&dag);
|
||||
let sol = state.best_solution.unwrap();
|
||||
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
|
||||
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
|
||||
sol.global_p_error
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_lut() {
|
||||
for precision in [4_u8, 8] {
|
||||
for weight in [1, 3, 27, 243, 729] {
|
||||
for depth in [2, 16, 32] {
|
||||
check_global_p_error_lut(depth, weight, precision);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) {
|
||||
let shape = Shape::number();
|
||||
let weights = Weights::number(weight);
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let mut last_val = dag.add_input(precision as u8, shape);
|
||||
for _i in 0..depth {
|
||||
let dot = dag.add_dot([last_val], &weights);
|
||||
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
|
||||
}
|
||||
let state = optimize(&dag);
|
||||
let sol = state.best_solution.unwrap();
|
||||
// the first lut on input has reduced impact on error probability
|
||||
let lower_nb_dominating_lut = depth - 1;
|
||||
let lower_global_p_error =
|
||||
local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut);
|
||||
let higher_global_p_error =
|
||||
local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut + 1);
|
||||
assert!(lower_global_p_error <= sol.global_p_error);
|
||||
assert!(sol.global_p_error <= higher_global_p_error);
|
||||
}
|
||||
|
||||
fn dag_2_precisions_lut_chain(
|
||||
depth: u64,
|
||||
precision_low: Precision,
|
||||
precision_high: Precision,
|
||||
weight_low: u64,
|
||||
weight_high: u64,
|
||||
) -> unparametrized::OperationDag {
|
||||
let shape = Shape::number();
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let weights_low = Weights::number(weight_low);
|
||||
let weights_high = Weights::number(weight_high);
|
||||
let mut last_val_low = dag.add_input(precision_low as u8, &shape);
|
||||
let mut last_val_high = dag.add_input(precision_high as u8, &shape);
|
||||
for _i in 0..depth {
|
||||
let dot_low = dag.add_dot([last_val_low], &weights_low);
|
||||
last_val_low = dag.add_lut(dot_low, FunctionTable::UNKWOWN, precision_low);
|
||||
let dot_high = dag.add_dot([last_val_high], &weights_high);
|
||||
last_val_high = dag.add_lut(dot_high, FunctionTable::UNKWOWN, precision_high);
|
||||
}
|
||||
dag
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_dominating_lut() {
|
||||
let depth = 128;
|
||||
let weights_low = 1;
|
||||
let weights_high = 1;
|
||||
let precision_low = 6 as Precision;
|
||||
let precision_high = 8 as Precision;
|
||||
let dag = dag_2_precisions_lut_chain(
|
||||
depth,
|
||||
precision_low,
|
||||
precision_high,
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
// the 2 first luts and low precision/weight luts have little impact on error probability
|
||||
let nb_dominating_lut = depth - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
// errors rate is approximated accurately
|
||||
approx::assert_relative_eq!(
|
||||
sol.global_p_error,
|
||||
approx_global_p_error,
|
||||
max_relative = 1e-01
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_global_p_error_non_dominating_lut() {
|
||||
let depth = 128;
|
||||
let weights_low = 1024 * 1024 * 3;
|
||||
let weights_high = 1;
|
||||
let precision_low = 6 as Precision;
|
||||
let precision_high = 8 as Precision;
|
||||
let dag = dag_2_precisions_lut_chain(
|
||||
depth,
|
||||
precision_low,
|
||||
precision_high,
|
||||
weights_low,
|
||||
weights_high,
|
||||
);
|
||||
let sol = optimize(&dag).best_solution.unwrap();
|
||||
// all intern luts have an impact on error probability almost equaly
|
||||
let nb_dominating_lut = (2 * depth) - 1;
|
||||
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
|
||||
// errors rate is approximated accurately
|
||||
approx::assert_relative_eq!(
|
||||
sol.global_p_error,
|
||||
approx_global_p_error,
|
||||
max_relative = 0.05
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +34,14 @@ fn max_precision(dag: &OperationDag) -> Precision {
|
||||
.unwrap_or(0)
|
||||
}
|
||||
|
||||
fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
|
||||
let global_p_error = 1.0 - (1.0 - sol.p_error).powi(nb_luts as i32);
|
||||
WopSolution {
|
||||
global_p_error,
|
||||
..sol
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize<W: UnsignedInteger>(
|
||||
dag: &OperationDag,
|
||||
security_level: u64,
|
||||
@@ -61,6 +69,7 @@ pub fn optimize<W: UnsignedInteger>(
|
||||
let fallback_16b_precision = 16;
|
||||
let default_log_norm = default_log_norm2_woppbs;
|
||||
let worst_log_norm = analyze::worst_log_norm(dag);
|
||||
let nb_luts = analyze::lut_count_from_dag(dag);
|
||||
let log_norm = default_log_norm.min(worst_log_norm);
|
||||
let opt_sol = wop_optimize::<W>(
|
||||
fallback_16b_precision,
|
||||
@@ -72,6 +81,6 @@ pub fn optimize<W: UnsignedInteger>(
|
||||
internal_lwe_dimensions,
|
||||
)
|
||||
.best_solution;
|
||||
opt_sol.map(Solution::WopSolution)
|
||||
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -65,6 +65,7 @@ pub struct Solution {
|
||||
pub complexity: f64,
|
||||
pub noise_max: f64,
|
||||
pub p_error: f64,
|
||||
pub global_p_error: f64,
|
||||
// error probability
|
||||
pub cb_decomposition_level_count: u64,
|
||||
pub cb_decomposition_base_log: u64,
|
||||
@@ -84,6 +85,7 @@ impl Solution {
|
||||
complexity: 0.,
|
||||
noise_max: 0.0,
|
||||
p_error: 0.0,
|
||||
global_p_error: 0.0,
|
||||
cb_decomposition_level_count: 0,
|
||||
cb_decomposition_base_log: 0,
|
||||
}
|
||||
@@ -104,6 +106,7 @@ impl From<Solution> for atomic_pattern::Solution {
|
||||
complexity: sol.complexity,
|
||||
noise_max: sol.noise_max,
|
||||
p_error: sol.p_error,
|
||||
global_p_error: sol.global_p_error,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -431,6 +434,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
|
||||
noise_max: variance_max,
|
||||
complexity,
|
||||
p_error,
|
||||
global_p_error: f64::NAN,
|
||||
cb_decomposition_level_count: circuit_pbs_decomposition_parameter.level,
|
||||
cb_decomposition_base_log: circuit_pbs_decomposition_parameter.log2_base,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user