feat(dag): compute global p-error after the local one is optimized

Resolves zama-ai/products#302
rudy
2022-07-18 10:06:42 +02:00
committed by rudy-6-4
parent 517ab218dc
commit b7c148257b
7 changed files with 366 additions and 107 deletions

View File

@@ -6,6 +6,7 @@
#![allow(clippy::cast_possible_truncation)] // u64 to usize
#![allow(clippy::inline_always)] // needed by delegate
#![allow(clippy::match_wildcard_for_single_variants)]
#![allow(clippy::manual_range_contains)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_const_for_fn)]
#![allow(clippy::module_name_repetitions)]

View File

@@ -9,9 +9,13 @@ pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 {
statrs::function::erf::erf_inv(p_in) * 2_f64.sqrt()
}
pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
pub fn success_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
// https://en.wikipedia.org/wiki/Error_function#Applications
1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
}
pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
1.0 - success_probability_of_sigma_scale(sigma_scale)
}
const LEFT_PADDING_BITS: u64 = 1;
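As a quick sanity check of these helpers (a minimal sketch, assuming the three functions above are in scope): the scale/probability conversions are mutual inverses, and the success/error probabilities are complements by construction.

fn main() {
    // For a Gaussian, P(|X| > 3.29 * sigma) ≈ 1e-3, so the round trip starts there.
    let p_error = 1e-3;
    let scale = sigma_scale_of_error_probability(p_error);
    assert!((scale - 3.2905).abs() < 1e-2);
    // error_probability_of_sigma_scale inverts sigma_scale_of_error_probability
    let back = error_probability_of_sigma_scale(scale);
    assert!((back - p_error).abs() < 1e-6);
    // success + error = 1 by construction of the new helper
    let total = success_probability_of_sigma_scale(scale) + error_probability_of_sigma_scale(scale);
    assert!((total - 1.0).abs() < 1e-12);
}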

View File

@@ -36,6 +36,7 @@ pub struct Solution {
pub complexity: f64,
pub noise_max: f64,
pub p_error: f64, // error probability
pub global_p_error: f64,
}
// Constants during optimisation of decompositions
@@ -379,6 +380,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
noise_max,
complexity,
p_error,
global_p_error: f64::NAN, // placeholder, overwritten once the final decomposition is known
});
}
}

View File

@@ -6,7 +6,7 @@ use crate::dag::unparametrized;
use crate::noise_estimator::error;
use crate::optimization::config::NoiseBoundConfig;
use crate::utils::square;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
// private short convention
use DotKind as DK;
@@ -107,52 +107,13 @@ pub struct VariancesAndBound {
pub precision: Precision,
pub safe_variance_bound: f64,
pub nb_luts: u64,
// All final variance factors not entering a lut (usually the final levelledOp)
// All dominating final variance factors not entering a lut (usually the final levelledOp)
pub pareto_output: Vec<SymbolicVariance>,
// All variance factors entering a lut
// All dominating variance factors entering a lut
pub pareto_in_lut: Vec<SymbolicVariance>,
}
impl OperationDag {
pub fn peek_p_error(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
kappa: f64,
) -> (f64, f64) {
peak_p_error(
self,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
kappa,
)
}
pub fn feasible(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
) -> bool {
feasible(
self,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
)
}
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
let luts_cost = one_lut_cost * (self.nb_luts as f64);
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
luts_cost + levelled_cost
}
// All counted variances, for computing the exact full-dag error probability
pub all_output: Vec<(u64, SymbolicVariance)>,
pub all_in_lut: Vec<(u64, SymbolicVariance)>,
}
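In each of these entries the u64 is a multiplicity: identical variance factors are merged and weighted by the flat size of the value's shape (see counted_symbolic_variance below), so the global success probability can raise each per-value success probability to the right power.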
fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
@@ -294,15 +255,16 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool>
fn extra_final_variances(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
out_precisions: &[Precision],
out_variances: &[SymbolicVariance],
) -> Vec<(Precision, SymbolicVariance)> {
) -> Vec<(Precision, Shape, SymbolicVariance)> {
extra_final_values_to_check(dag)
.iter()
.enumerate()
.filter_map(|(i, &is_final)| {
if is_final {
Some((out_precisions[i], out_variances[i]))
Some((out_precisions[i], out_shapes[i].clone(), out_variances[i]))
} else {
None
}
@@ -312,17 +274,25 @@ fn extra_final_variances(
fn in_luts_variance(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
out_precisions: &[Precision],
out_variances: &[SymbolicVariance],
) -> Vec<(Precision, SymbolicVariance)> {
let only_luts = |op| {
if let &Op::Lut { input, .. } = op {
Some((out_precisions[input.i], out_variances[input.i]))
} else {
None
}
};
dag.operators.iter().filter_map(only_luts).collect()
) -> Vec<(Precision, Shape, SymbolicVariance)> {
dag.operators
.iter()
.enumerate()
.filter_map(|(i, op)| {
if let &Op::Lut { input, .. } = op {
Some((
out_precisions[input.i],
out_shapes[i].clone(),
out_variances[input.i],
))
} else {
None
}
})
.collect()
}
fn op_levelled_complexity(
@@ -378,8 +348,8 @@ fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f6
fn constraints_by_precisions(
out_precisions: &[Precision],
final_variances: &[(Precision, SymbolicVariance)],
in_luts_variance: &[(Precision, SymbolicVariance)],
final_variances: &[(Precision, Shape, SymbolicVariance)],
in_luts_variance: &[(Precision, Shape, SymbolicVariance)],
noise_config: &NoiseBoundConfig,
) -> Vec<VariancesAndBound> {
let precisions: HashSet<Precision> = out_precisions.iter().copied().collect();
@@ -397,11 +367,14 @@ fn constraints_by_precisions(
precisions.iter().rev().map(to_noise_summary).collect()
}
fn select_precision<T: Copy>(target_precision: Precision, v: &[(Precision, T)]) -> Vec<T> {
fn select_precision<T1: Clone, T2: Copy>(
target_precision: Precision,
v: &[(Precision, T1, T2)],
) -> Vec<(T1, T2)> {
v.iter()
.filter_map(|(p, t)| {
.filter_map(|(p, s, t)| {
if *p == target_precision {
Some(*t)
Some((s.clone(), *t))
} else {
None
}
@@ -409,16 +382,41 @@ fn select_precision<T: Copy>(target_precision: Precision, v: &[(Precision, T)])
.collect()
}
fn counted_symbolic_variance(
symbolic_variances: &[(Shape, SymbolicVariance)],
) -> Vec<(u64, SymbolicVariance)> {
pub fn exact_key(v: &SymbolicVariance) -> (u64, u64) {
(v.lut_coeff.to_bits(), v.input_coeff.to_bits())
}
let mut count: HashMap<(u64, u64), u64> = HashMap::new();
for (s, v) in symbolic_variances {
*count.entry(exact_key(v)).or_insert(0) += s.flat_size();
}
let mut res = Vec::new();
res.reserve_exact(count.len());
for (_s, v) in symbolic_variances {
if let Some(c) = count.remove(&exact_key(v)) {
res.push((c, *v));
}
}
res
}
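f64 implements neither Eq nor Hash, so the function above keys each SymbolicVariance by the exact bit patterns of its two coefficients, then drains the map in a second pass over the input so the result keeps first-seen order. The same pattern in isolation (a sketch reduced to one f64 per entry; the real code adds shape.flat_size() instead of 1):

use std::collections::HashMap;

// Group equal f64 values (by exact bits) and count occurrences,
// preserving first-seen order in the result.
fn count_by_bits(values: &[f64]) -> Vec<(u64, f64)> {
    let mut count: HashMap<u64, u64> = HashMap::new();
    for v in values {
        *count.entry(v.to_bits()).or_insert(0) += 1;
    }
    let mut res = Vec::with_capacity(count.len());
    for v in values {
        // remove() yields Some only the first time each distinct value is seen
        if let Some(c) = count.remove(&v.to_bits()) {
            res.push((c, *v));
        }
    }
    res
}

// count_by_bits(&[1.0, 2.0, 1.0]) == vec![(2, 1.0), (1, 2.0)]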
fn constraint_for_one_precision(
target_precision: Precision,
extra_final_variances: &[(Precision, SymbolicVariance)],
in_luts_variance: &[(Precision, SymbolicVariance)],
extra_final_variances: &[(Precision, Shape, SymbolicVariance)],
in_luts_variance: &[(Precision, Shape, SymbolicVariance)],
safe_noise_bound: f64,
) -> VariancesAndBound {
let extra_final_variances = select_precision(target_precision, extra_final_variances);
let extra_finals_variance = select_precision(target_precision, extra_final_variances);
let in_luts_variance = select_precision(target_precision, in_luts_variance);
let nb_luts = in_luts_variance.len() as u64;
let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_final_variances);
let all_output = counted_symbolic_variance(&extra_finals_variance);
let all_in_lut = counted_symbolic_variance(&in_luts_variance);
let remove_shape = |t: &(Shape, SymbolicVariance)| t.1;
let extra_finals_variance = extra_finals_variance.iter().map(remove_shape).collect();
let in_luts_variance = in_luts_variance.iter().map(remove_shape).collect();
let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(extra_finals_variance);
let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance);
VariancesAndBound {
precision: target_precision,
@@ -426,6 +424,8 @@ fn constraint_for_one_precision(
nb_luts,
pareto_output: pareto_vfs_final,
pareto_in_lut: pareto_vfs_in_lut,
all_output,
all_in_lut,
}
}
@@ -434,10 +434,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
let out_shapes = out_shapes(dag);
let out_precisions = out_precisions(dag);
let out_variances = out_variances(dag, &out_shapes);
let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances);
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
let coeffs = in_luts_variance
.iter()
.map(|(_precision, symbolic_variance)| {
.map(|(_precision, _shape, symbolic_variance)| {
symbolic_variance.lut_coeff + symbolic_variance.input_coeff
})
.filter(|v| *v >= 1.0);
@@ -445,6 +445,10 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 {
worst.log2()
}
pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
lut_count(dag, &out_shapes(dag))
}
pub fn analyze(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
@@ -453,9 +457,10 @@ pub fn analyze(
let out_shapes = out_shapes(dag);
let out_precisions = out_precisions(dag);
let out_variances = out_variances(dag, &out_shapes);
let in_luts_variance = in_luts_variance(dag, &out_precisions, &out_variances);
let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances);
let nb_luts = lut_count(dag, &out_shapes);
let extra_final_variances = extra_final_variances(dag, &out_precisions, &out_variances);
let extra_final_variances =
extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances);
let levelled_complexity = levelled_complexity(dag, &out_shapes);
let constraints_by_precisions = constraints_by_precisions(
&out_precisions,
@@ -544,48 +549,111 @@ fn peak_relative_variance(
(max_relative_var, safe_noise)
}
fn peak_p_error(
dag: &OperationDag,
fn p_success_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
let sigma_scale = kappa / relative_variance.sqrt();
error::success_probability_of_sigma_scale(sigma_scale)
}
fn p_success_per_constraint(
constraint: &VariancesAndBound,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
kappa: f64,
) -> (f64, f64) {
let (relative_var, variance_bound) = peak_relative_variance(
dag,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
);
let sigma_scale = kappa / relative_var.sqrt();
(
error::error_probability_of_sigma_scale(sigma_scale),
relative_var * variance_bound,
)
) -> f64 {
// Note: we avoid log-probabilities to keep accuracy near 0; 0 is a fine answer when p_success is very small.
let mut p_success = 1.0;
for &(count, vf) in &constraint.all_output {
assert!(0 < count);
let variance = vf.eval(input_noise_out, blind_rotate_noise_out);
let relative_variance = variance / constraint.safe_variance_bound;
let vf_p_success = p_success_from_relative_variance(relative_variance, kappa);
p_success *= vf_p_success.powi(count as i32);
}
// variances encountered when entering a lut (keyswitch and modulus-switching noise included)
for &(count, vf) in &constraint.all_in_lut {
assert!(0 < count);
let variance = vf.eval(input_noise_out, blind_rotate_noise_out);
let relative_variance =
(variance + noise_keyswitch + noise_modulus_switching) / constraint.safe_variance_bound;
let vf_p_success = p_success_from_relative_variance(relative_variance, kappa);
p_success *= vf_p_success.powi(count as i32);
}
p_success
}
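The model here treats each output value as failing independently, so the constraint-wide success probability is the product of the per-value success probabilities raised to their multiplicities. A toy illustration with hypothetical numbers:

// With 100 independent values, each succeeding with probability 1 - 1e-5,
// the dag-wide success probability is (1 - 1e-5)^100, i.e. a global
// p_error just under the naive sum 100 * 1e-5.
fn main() {
    let per_value_p_success: f64 = 1.0 - 1e-5;
    let global_p_success = per_value_p_success.powi(100);
    let global_p_error = 1.0 - global_p_success;
    assert!((global_p_error - 9.995e-4).abs() < 1e-7);
}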
fn feasible(
dag: &OperationDag,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
) -> bool {
for ns in &dag.constraints_by_precisions {
if peak_variance_per_constraint(
ns,
impl OperationDag {
pub fn peek_p_error(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
kappa: f64,
) -> (f64, f64) {
let (relative_var, variance_bound) = peak_relative_variance(
self,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
) > ns.safe_variance_bound
{
return false;
}
);
(
1.0 - p_success_from_relative_variance(relative_var, kappa),
relative_var * variance_bound,
)
}
pub fn global_p_error(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
kappa: f64,
) -> f64 {
let mut p_success = 1.0;
for ns in &self.constraints_by_precisions {
p_success *= p_success_per_constraint(
ns,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
kappa,
);
}
assert!(0.0 <= p_success && p_success <= 1.0);
1.0 - p_success
}
pub fn feasible(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
) -> bool {
for ns in &self.constraints_by_precisions {
if peak_variance_per_constraint(
ns,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
) > ns.safe_variance_bound
{
return false;
}
}
true
}
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
let luts_cost = one_lut_cost * (self.nb_luts as f64);
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
luts_cost + levelled_cost
}
true
}
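The restructured impl OperationDag separates the two roles: feasible and peek_p_error drive the local search, while global_p_error only becomes meaningful once a decomposition is fixed. A hedged sketch of a consumer (illustrative names; the real call sites are in the atomic-pattern hunk below):

// Illustrative only: evaluates one (br, ks) candidate against the dag.
// Returns (local p_error, global p_error) if the candidate is feasible.
fn evaluate_candidate(
    dag: &OperationDag,
    input_noise: f64,
    br_noise: f64,
    ks_noise: f64,
    ms_noise: f64,
    kappa: f64,
) -> Option<(f64, f64)> {
    if !dag.feasible(input_noise, br_noise, ks_noise, ms_noise) {
        return None; // violates a per-precision noise bound
    }
    // worst single-value error, used to rank candidates locally
    let (p_error, _noise_max) = dag.peek_p_error(input_noise, br_noise, ks_noise, ms_noise, kappa);
    // whole-dag error, reported on the final solution
    let global_p_error = dag.global_p_error(input_noise, br_noise, ks_noise, ms_noise, kappa);
    Some((p_error, global_p_error))
}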
#[cfg(test)]

View File

@@ -79,6 +79,8 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
)
.get_variance();
let mut best_br_noise = f64::INFINITY;
let mut best_ks_noise = f64::INFINITY;
let mut best_br_i = 0;
let mut best_ks_i = 0;
let mut update_best_solution = false;
@@ -154,6 +156,8 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
best_complexity = complexity;
best_p_error = peek_p_error;
best_variance = variance;
best_br_noise = br_quantity.noise;
best_ks_noise = ks_quantity.noise;
best_br_i = br_quantity.index;
best_ks_i = ks_quantity.index;
}
@@ -180,6 +184,13 @@ fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
br_decomposition_base_log: br_b,
complexity: best_complexity,
p_error: best_p_error,
global_p_error: dag.global_p_error(
input_noise_out,
best_br_noise,
best_ks_noise,
noise_modulus_switching,
consts.kappa,
),
noise_max: best_variance,
});
}
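Worth noting: br_quantity.noise and ks_quantity.noise are captured only for the winning candidate (best_br_noise / best_ks_noise), so dag.global_p_error runs once per accepted solution rather than inside the decomposition search loop; this is the "after the local one is optimized" part of the commit.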
@@ -277,7 +288,9 @@ pub fn optimize<W: UnsignedInteger>(
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
assert!(0.0 <= sol.global_p_error && sol.global_p_error <= 1.0);
assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA);
assert!(sol.p_error <= sol.global_p_error * REL_EPSILON_PROBA);
}
state
@@ -320,6 +333,7 @@ pub fn optimize_v0<W: UnsignedInteger>(
state
}
#[allow(clippy::unnecessary_cast)] // unnecessary warning on 'as Precision'
#[cfg(test)]
mod tests {
use std::time::Instant;
@@ -336,8 +350,9 @@ mod tests {
}
impl Solution {
fn assert_same(&self, other: Self) -> bool {
fn assert_same_pbs_solution(&self, other: Self) -> bool {
let mut other = other;
other.global_p_error = self.global_p_error;
if small_relative_diff(self.noise_max, other.noise_max)
&& small_relative_diff(self.p_error, other.p_error)
{
@@ -444,7 +459,10 @@ mod tests {
}
let sol = state.best_solution.unwrap();
let sol_ref = state_ref.best_solution.unwrap();
assert!(sol.assert_same(sol_ref));
assert!(sol.assert_same_pbs_solution(sol_ref));
assert!(!sol.global_p_error.is_nan());
assert!(sol.p_error <= sol.global_p_error);
assert!(sol.global_p_error <= 1.0);
}
#[test]
@@ -508,7 +526,10 @@ mod tests {
let sol = state.best_solution.unwrap();
let mut sol_ref = state_ref.best_solution.unwrap();
sol_ref.complexity *= 2.0 /* number of luts */;
assert!(sol.assert_same(sol_ref));
assert!(sol.assert_same_pbs_solution(sol_ref));
assert!(!sol.global_p_error.is_nan());
assert!(sol.p_error <= sol.global_p_error);
assert!(sol.global_p_error <= 1.0);
}
fn no_lut_vs_lut(precision: Precision) {
@@ -619,10 +640,13 @@ mod tests {
let mut sol_multi = state_multi.best_solution.unwrap();
sol_multi.complexity /= 2.0;
if sol_low.complexity < sol_high.complexity {
assert!(sol_high.assert_same(sol_multi));
assert!(sol_high.assert_same_pbs_solution(sol_multi));
Some(true)
} else {
assert!(sol_low.complexity < sol_multi.complexity || sol_low.assert_same(sol_multi));
assert!(
sol_low.complexity < sol_multi.complexity
|| sol_low.assert_same_pbs_solution(sol_multi)
);
Some(false)
}
}
@@ -643,4 +667,151 @@ mod tests {
prev = current;
}
}
fn local_to_approx_global_p_error(local_p_error: f64, nb_pbs: u64) -> f64 {
#[allow(clippy::float_cmp)]
if local_p_error == 1f64 {
return 1.0;
}
#[allow(clippy::float_cmp)]
if local_p_error == 0f64 {
return 0.0;
}
let local_p_success = 1.0 - local_p_error;
assert!(local_p_success < 1.0);
let p_success = local_p_success.powi(nb_pbs as i32);
assert!(p_success < 1.0);
assert!(0.0 < p_success);
1.0 - p_success
}
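For small local p_error this approximation behaves like nb_pbs × p_error: e.g. 1 − (1 − 1e-6)^1000 ≈ 9.995e-4, just under the naive sum 1e-3.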
#[test]
fn test_global_p_error_input() {
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for dim in [1, 2, 16, 32] {
let _ = check_global_p_error_input(dim, weight, precision);
}
}
}
}
fn check_global_p_error_input(dim: u64, weight: u64, precision: u8) -> f64 {
let shape = Shape::vector(dim);
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, shape);
let _dot1 = dag.add_dot([input1], weights); // this is just several multiplications
let state = optimize(&dag);
let sol = state.best_solution.unwrap();
let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim);
approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim);
sol.global_p_error
}
#[test]
fn test_global_p_error_lut() {
for precision in [4_u8, 8] {
for weight in [1, 3, 27, 243, 729] {
for depth in [2, 16, 32] {
check_global_p_error_lut(depth, weight, precision);
}
}
}
}
fn check_global_p_error_lut(depth: u64, weight: u64, precision: u8) {
let shape = Shape::number();
let weights = Weights::number(weight);
let mut dag = unparametrized::OperationDag::new();
let mut last_val = dag.add_input(precision as u8, shape);
for _i in 0..depth {
let dot = dag.add_dot([last_val], &weights);
last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision);
}
let state = optimize(&dag);
let sol = state.best_solution.unwrap();
// the first lut, applied directly to a fresh input, has a reduced impact on the error probability
let lower_nb_dominating_lut = depth - 1;
let lower_global_p_error =
local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut);
let higher_global_p_error =
local_to_approx_global_p_error(sol.p_error, lower_nb_dominating_lut + 1);
assert!(lower_global_p_error <= sol.global_p_error);
assert!(sol.global_p_error <= higher_global_p_error);
}
fn dag_2_precisions_lut_chain(
depth: u64,
precision_low: Precision,
precision_high: Precision,
weight_low: u64,
weight_high: u64,
) -> unparametrized::OperationDag {
let shape = Shape::number();
let mut dag = unparametrized::OperationDag::new();
let weights_low = Weights::number(weight_low);
let weights_high = Weights::number(weight_high);
let mut last_val_low = dag.add_input(precision_low as u8, &shape);
let mut last_val_high = dag.add_input(precision_high as u8, &shape);
for _i in 0..depth {
let dot_low = dag.add_dot([last_val_low], &weights_low);
last_val_low = dag.add_lut(dot_low, FunctionTable::UNKWOWN, precision_low);
let dot_high = dag.add_dot([last_val_high], &weights_high);
last_val_high = dag.add_lut(dot_high, FunctionTable::UNKWOWN, precision_high);
}
dag
}
#[test]
fn test_global_p_error_dominating_lut() {
let depth = 128;
let weights_low = 1;
let weights_high = 1;
let precision_low = 6 as Precision;
let precision_high = 8 as Precision;
let dag = dag_2_precisions_lut_chain(
depth,
precision_low,
precision_high,
weights_low,
weights_high,
);
let sol = optimize(&dag).best_solution.unwrap();
// the first two luts and the low-precision/low-weight luts have little impact on the error probability
let nb_dominating_lut = depth - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
// the error rate is approximated accurately
approx::assert_relative_eq!(
sol.global_p_error,
approx_global_p_error,
max_relative = 1e-01
);
}
#[test]
fn test_global_p_error_non_dominating_lut() {
let depth = 128;
let weights_low = 1024 * 1024 * 3;
let weights_high = 1;
let precision_low = 6 as Precision;
let precision_high = 8 as Precision;
let dag = dag_2_precisions_lut_chain(
depth,
precision_low,
precision_high,
weights_low,
weights_high,
);
let sol = optimize(&dag).best_solution.unwrap();
// all internal luts have an almost equal impact on the error probability
let nb_dominating_lut = (2 * depth) - 1;
let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut);
// the error rate is approximated accurately
approx::assert_relative_eq!(
sol.global_p_error,
approx_global_p_error,
max_relative = 0.05
);
}
}

View File

@@ -34,6 +34,14 @@ fn max_precision(dag: &OperationDag) -> Precision {
.unwrap_or(0)
}
fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
let global_p_error = 1.0 - (1.0 - sol.p_error).powi(nb_luts as i32);
WopSolution {
global_p_error,
..sol
}
}
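Since the wop-pbs search returns a single parameter set, every lut shares the same local p_error, and the dag-wide estimate reduces exactly to 1 − (1 − p_error)^nb_luts; no per-value counting is needed here.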
pub fn optimize<W: UnsignedInteger>(
dag: &OperationDag,
security_level: u64,
@@ -61,6 +69,7 @@ pub fn optimize<W: UnsignedInteger>(
let fallback_16b_precision = 16;
let default_log_norm = default_log_norm2_woppbs;
let worst_log_norm = analyze::worst_log_norm(dag);
let nb_luts = analyze::lut_count_from_dag(dag);
let log_norm = default_log_norm.min(worst_log_norm);
let opt_sol = wop_optimize::<W>(
fallback_16b_precision,
@@ -72,6 +81,6 @@ pub fn optimize<W: UnsignedInteger>(
internal_lwe_dimensions,
)
.best_solution;
opt_sol.map(Solution::WopSolution)
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
}
}

View File

@@ -65,6 +65,7 @@ pub struct Solution {
pub complexity: f64,
pub noise_max: f64,
pub p_error: f64, // error probability
pub global_p_error: f64,
pub cb_decomposition_level_count: u64,
pub cb_decomposition_base_log: u64,
@@ -84,6 +85,7 @@ impl Solution {
complexity: 0.,
noise_max: 0.0,
p_error: 0.0,
global_p_error: 0.0,
cb_decomposition_level_count: 0,
cb_decomposition_base_log: 0,
}
@@ -104,6 +106,7 @@ impl From<Solution> for atomic_pattern::Solution {
complexity: sol.complexity,
noise_max: sol.noise_max,
p_error: sol.p_error,
global_p_error: sol.global_p_error,
}
}
}
@@ -431,6 +434,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
noise_max: variance_max,
complexity,
p_error,
global_p_error: f64::NAN, // placeholder, overwritten once the final decomposition is known
cb_decomposition_level_count: circuit_pbs_decomposition_parameter.level,
cb_decomposition_base_log: circuit_pbs_decomposition_parameter.log2_base,
});