feat(optimizer): multi-parameter optimization

rudy
2023-03-23 11:55:18 +01:00
committed by Quentin Bourgerie
parent 361244abd0
commit 3e05aa47a4
15 changed files with 2042 additions and 38 deletions

View File

@@ -4,6 +4,7 @@
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_precision_loss)] // u64 to f64
#![allow(clippy::cast_possible_truncation)] // u64 to usize
#![allow(clippy::question_mark)]
#![allow(clippy::match_wildcard_for_single_variants)]
#![allow(clippy::manual_range_contains)]
#![allow(clippy::missing_panics_doc)]

View File

@@ -1,3 +1,5 @@
use std::collections::HashSet;
use crate::dag::operator::{
dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
};
@@ -14,6 +16,8 @@ use crate::optimization::dag::solo_key::analyze::{
extra_final_values_to_check, first, safe_noise_bound,
};
use super::complexity::OperationsCount;
use super::operations_value::OperationsValue;
use super::variance_constraint::VarianceConstraint;
use crate::utils::square;
@@ -35,31 +39,37 @@ pub struct AnalyzedDag {
pub variance_constraints: Vec<VarianceConstraint>,
// Undominated variance constraints
pub undominated_variance_constraints: Vec<VarianceConstraint>,
pub operations_count_per_instrs: Vec<OperationsCount>,
pub operations_count: OperationsCount,
pub p_cut: PrecisionCut,
}
pub fn analyze(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
p_cut: &PrecisionCut,
p_cut: &Option<PrecisionCut>,
default_partition: PartitionIndex,
) -> AnalyzedDag {
assert!(
p_cut.p_cut.len() <= 1,
"Multi-parameter can only be used 0 or 1 precision cut"
);
let dag = expand_round(dag);
let levelled_complexity = LevelledComplexity::ZERO;
// The precision cut is chosen to work well with rounded pbs
// Note: this is temporary
let partitions = partitionning_with_preferred(&dag, p_cut, default_partition);
#[allow(clippy::option_if_let_else)]
let p_cut = match p_cut {
Some(p_cut) => p_cut.clone(),
None => maximal_p_cut(&dag),
};
let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
let instrs_partition = partitions.instrs_partition;
let nb_partitions = partitions.nb_partitions;
let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);
let variance_constraints =
collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances);
let undominated_variance_constraints =
VarianceConstraint::remove_dominated(&variance_constraints);
let operations_count_per_instrs =
collect_operations_count(&dag, nb_partitions, &instrs_partition);
let operations_count = sum_operations_count(&operations_count_per_instrs);
AnalyzedDag {
operators: dag.operators,
nb_partitions,
@@ -68,6 +78,9 @@ pub fn analyze(
levelled_complexity,
variance_constraints,
undominated_variance_constraints,
operations_count_per_instrs,
operations_count,
p_cut,
}
}
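
In short, `analyze` now also records the per-instruction operation counts, their sum, and the precision cut it actually used (either the one supplied, or one derived from the dag). A minimal usage sketch, assuming the `CONFIG` noise-bound configuration used in the tests below:

// Sketch: pass None to let the analysis derive the precision cut itself.
let analyzed = analyze(&dag, &CONFIG, &None, 0 /* default partition */);
assert!(analyzed.nb_partitions >= 1);
eprintln!("{}", analyzed.operations_count);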
@@ -240,8 +253,65 @@ fn collect_all_variance_constraints(
constraints
}
#[allow(clippy::match_on_vec_items)]
fn operations_counts(
dag: &unparametrized::OperationDag,
op: &unparametrized::UnparameterizedOperator,
nb_partitions: usize,
instr_partition: &InstructionPartition,
) -> OperationsCount {
let mut counts = OperationsValue::zero(nb_partitions);
if let Op::Lut { input, .. } = op {
let partition = instr_partition.instruction_partition;
let nb_lut = dag.out_shapes[input.i].flat_size() as f64;
let src_partition = match instr_partition.inputs_transition[0] {
Some(Transition::Internal { src_partition }) => src_partition,
Some(Transition::Additional { .. }) | None => partition,
};
*counts.ks(src_partition, partition) += nb_lut;
*counts.pbs(partition) += nb_lut;
for &conv_partition in &instr_partition.alternative_output_representation {
*counts.fks(partition, conv_partition) += nb_lut;
}
}
OperationsCount { counts }
}
fn collect_operations_count(
dag: &unparametrized::OperationDag,
nb_partitions: usize,
instrs_partition: &[InstructionPartition],
) -> Vec<OperationsCount> {
dag.operators
.iter()
.enumerate()
.map(|(i, op)| operations_counts(dag, op, nb_partitions, &instrs_partition[i]))
.collect()
}
fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount {
let mut sum_counts = OperationsValue::zero(all_counts[0].counts.nb_partitions());
for OperationsCount { counts } in all_counts {
sum_counts += counts;
}
OperationsCount { counts: sum_counts }
}
fn maximal_p_cut(dag: &unparametrized::OperationDag) -> PrecisionCut {
let mut lut_in_precisions: HashSet<_> = HashSet::default();
for op in &dag.operators {
if let Op::Lut { input, .. } = op {
_ = lut_in_precisions.insert(dag.out_precisions[input.i]);
}
}
let mut p_cut: Vec<_> = lut_in_precisions.iter().copied().collect();
p_cut.sort_unstable();
_ = p_cut.pop();
PrecisionCut { p_cut }
}
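
When no cut is supplied, `maximal_p_cut` derives the finest useful cut from the dag itself: it collects the distinct Lut input precisions, sorts them, and drops the largest (no cut is needed above the highest precision). A sketch of the resulting behavior, assuming a dag whose Luts consume inputs at precisions 3, 6 and 8:

// Hypothetical illustration: distinct Lut input precisions {3, 6, 8}
let mut p_cut = vec![3u8, 6, 8];
p_cut.sort_unstable();
let _ = p_cut.pop(); // the highest precision needs no cut above it
assert_eq!(p_cut, vec![3, 6]); // PrecisionCut { p_cut: [3, 6] } -> 3 partitions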
#[cfg(test)]
mod tests {
pub mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, Shape};
use crate::dag::unparametrized;
@@ -250,16 +320,16 @@ mod tests {
};
use crate::optimization::dag::solo_key::analyze::tests::CONFIG;
fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
pub fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
analyze_with_preferred(dag, LOW_PRECISION_PARTITION)
}
fn analyze_with_preferred(
pub fn analyze_with_preferred(
dag: &unparametrized::OperationDag,
default_partition: PartitionIndex,
) -> AnalyzedDag {
let p_cut = PrecisionCut { p_cut: vec![2] };
super::analyze(dag, &CONFIG, &p_cut, default_partition)
super::analyze(dag, &CONFIG, &Some(p_cut), default_partition)
}
#[allow(clippy::float_cmp)]
@@ -658,4 +728,76 @@ mod tests {
);
}
}
#[test]
fn test_rounded_v3_classic_first_layer_second_layer_complexity() {
let acc_precision = 7;
let precision = 4;
let mut dag = unparametrized::OperationDag::new();
let free_input1 = dag.add_input(precision, Shape::number());
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
let rounded1 = dag.add_expanded_round(input1, precision);
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
let old_dag = dag;
let dag = analyze(&old_dag);
// Partition 0
let instrs_counts: Vec<_> = dag
.operations_count_per_instrs
.iter()
.map(OperationsCount::to_string)
.collect();
#[rustfmt::skip] // nightly and stable are inconsistent here
let expected_counts = [
"ZERO x ¢", // free_input1
"1¢K[1] + 1¢Br[1] + 1¢FK[1→0]", // input1
"ZERO x ¢", // shift
"ZERO x ¢", // cast
"1¢K[0] + 1¢Br[0]", // extract (lut)
"ZERO x ¢", // erase (dot)
"ZERO x ¢", // cast
"ZERO x ¢", // shift
"ZERO x ¢", // cast
"1¢K[0] + 1¢Br[0]", // extract (lut)
"ZERO x ¢", // erase (dot)
"ZERO x ¢", // cast
"ZERO x ¢", // shift
"ZERO x ¢", // cast
"1¢K[0] + 1¢Br[0]", // extract (lut)
"ZERO x ¢", // erase (dot)
"ZERO x ¢", // cast
"1¢K[0→1] + 1¢Br[1]", // _lut1
];
for ((c, ec), op) in instrs_counts.iter().zip(expected_counts).zip(dag.operators) {
assert!(
c == ec,
"\nBad count on {op}\nActual: {c}\nTruth : {ec} (expected)\n"
);
}
eprintln!("{}", dag.operations_count);
assert!(
format!("{}", dag.operations_count)
== "3¢K[0] + 1¢K[0→1] + 1¢K[1] + 3¢Br[0] + 2¢Br[1] + 1¢FK[1→0]"
);
}
#[test]
fn test_high_partition_number() {
let mut dag = unparametrized::OperationDag::new();
let max_precision = 10;
let mut lut_input = dag.add_input(max_precision, Shape::number());
let mut p_cut = vec![];
for out_precision in (1..=max_precision).rev() {
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1);
for out_precision in 1..max_precision {
p_cut.push(out_precision);
}
eprintln!("{}", dag.dump());
let p_cut = PrecisionCut { p_cut };
eprintln!("{p_cut}");
let p_cut = Some(p_cut);
let dag = super::analyze(&dag, &CONFIG, &p_cut, LOW_PRECISION_PARTITION);
assert!(dag.nb_partitions == p_cut.unwrap().p_cut.len() + 1);
}
}

View File

@@ -0,0 +1,77 @@
use std::fmt;
use crate::utils::f64::f64_dot;
use super::operations_value::OperationsValue;
#[derive(Clone, Debug)]
pub struct OperationsCount {
pub counts: OperationsValue,
}
#[derive(Clone, Debug)]
pub struct OperationsCost {
pub costs: OperationsValue,
}
#[derive(Clone, Debug)]
pub struct Complexity {
pub counts: OperationsValue,
}
impl fmt::Display for OperationsCount {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut add_plus = "";
let counts = &self.counts;
let nb_partitions = counts.nb_partitions();
let index = counts.index;
for src_partition in 0..nb_partitions {
for dst_partition in 0..nb_partitions {
let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)];
if coeff != 0.0 {
if src_partition == dst_partition {
write!(f, "{add_plus}{coeff}¢K[{src_partition}]")?;
} else {
write!(f, "{add_plus}{coeff}¢K[{src_partition}→{dst_partition}]")?;
}
add_plus = " + ";
}
}
}
for src_partition in 0..nb_partitions {
assert!(counts.values[index.input(src_partition)] == 0.0);
let coeff = counts.values[index.pbs(src_partition)];
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}¢Br[{src_partition}]")?;
add_plus = " + ";
}
for dst_partition in 0..nb_partitions {
let coeff = counts.values[index.keyswitch_to_big(src_partition, dst_partition)];
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}¢FK[{src_partition}→{dst_partition}]")?;
add_plus = " + ";
}
}
}
for partition in 0..nb_partitions {
assert!(counts.values[index.modulus_switching(partition)] == 0.0);
}
if add_plus.is_empty() {
write!(f, "ZERO x ¢")?;
}
Ok(())
}
}
impl Complexity {
pub fn of(counts: &OperationsCount) -> Self {
Self {
counts: counts.counts.clone(),
}
}
pub fn complexity(&self, costs: &OperationsValue) -> f64 {
f64_dot(&self.counts, costs)
}
}
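
The display notation reads as follows, from the indices used above: ¢K[i] is a keyswitch within partition i, ¢K[i→j] a cross-partition keyswitch, ¢Br[i] a pbs (blind rotation) in partition i, and ¢FK[i→j] a fast keyswitch between big keys. `Complexity::complexity` then reduces the counts against per-operation unit costs with a dot product. A minimal sketch with made-up values:

// Hypothetical: 3 KS, 1 Br, 2 FK against assumed unit costs (same ordering).
let counts = [3.0_f64, 1.0, 2.0];
let costs = [10.0_f64, 100.0, 50.0];
let total: f64 = counts.iter().zip(&costs).map(|(&c, &k)| c * k).sum();
assert_eq!(total, 230.0); // 3*10 + 1*100 + 2*50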

View File

@@ -0,0 +1,87 @@
// TODO: move to cache with pareto check
use concrete_cpu_noise_model::gaussian_noise::conversion::modular_variance_to_variance;
// TODO: move to concrete-cpu
use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::parameters::{GlweParameters, KsDecompositionParameters};
use serde::{Deserialize, Serialize};
// output glwe is incorporated in KsComplexityNoise
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct FksComplexityNoise {
// 1 -> 0
// k1⋅N1>k0⋅N0, k0⋅N0≥k1⋅N1
pub decomp: KsDecompositionParameters,
pub noise: f64,
pub complexity: f64,
}
// Copy & paste from concrete-cpu
const FFT_SCALING_WEIGHT: f64 = -2.577_224_94;
fn fft_noise_variance_external_product_glwe(
glwe_dimension: u64,
polynomial_size: u64,
log2_base: u64,
level: u64,
ciphertext_modulus_log: u32,
) -> f64 {
// https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/noise_formulas/bootstrap.py#L25
let b = 2_f64.powi(log2_base as i32);
let l = level as f64;
let big_n = polynomial_size as f64;
let k = glwe_dimension;
assert!(k > 0, "k = {k}");
// 22 = 2 × 11, 11 = 64 - 53 (bits beyond the f64 mantissa)
let scale_margin = (1_u64 << 22) as f64;
let res =
f64::exp2(FFT_SCALING_WEIGHT) * scale_margin * l * b * b * big_n.powi(2) * (k as f64 + 1.);
modular_variance_to_variance(res, ciphertext_modulus_log)
}
#[allow(non_snake_case)]
fn upper_k0(input_glwe: &GlweParameters, output_glwe: &GlweParameters) -> u64 {
let k1 = input_glwe.glwe_dimension;
let N1 = input_glwe.polynomial_size();
let k0 = output_glwe.glwe_dimension;
let N0 = output_glwe.polynomial_size();
assert!(k1 * N1 >= k0 * N0);
// candidate * N0 >= k1 * N1
let f_upper_k0 = (k1 * N1) as f64 / N0 as f64;
#[allow(clippy::cast_sign_loss)]
let upper_k0 = f_upper_k0.ceil() as u64;
upper_k0
}
#[allow(non_snake_case)]
pub fn complexity(input_glwe: &GlweParameters, output_glwe: &GlweParameters, level: u64) -> f64 {
let k0 = output_glwe.glwe_dimension;
let N0 = output_glwe.polynomial_size();
let upper_k0 = upper_k0(input_glwe, output_glwe);
#[allow(clippy::cast_sign_loss)]
let log2_N0 = (N0 as f64).log2().ceil() as u64;
let size0 = (k0 + 1) * N0 * log2_N0;
let mul_count = size0 * upper_k0 * level;
let add_count = size0 * (upper_k0 * level - 1);
(add_count + mul_count) as f64
}
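
A worked numeric check of `upper_k0` and `complexity`, under assumed parameters (not taken from this commit): converting from an input key with k1 = 4, N1 = 1024 to an output key with k0 = 1, N0 = 2048 at level 3:

let (k1, n1) = (4u64, 1024u64); // input: k1*N1 = 4096
let (k0, n0) = (1u64, 2048u64); // output
let upper_k0 = ((k1 * n1) as f64 / n0 as f64).ceil() as u64; // = 2
let log2_n0 = (n0 as f64).log2().ceil() as u64; // = 11
let size0 = (k0 + 1) * n0 * log2_n0; // = 45056
let level = 3u64;
let mul_count = size0 * upper_k0 * level; // = 270336
let add_count = size0 * (upper_k0 * level - 1); // = 225280
assert_eq!(add_count + mul_count, 495_616);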
#[allow(non_snake_case)]
pub fn noise(
ks: &KsComplexityNoise,
input_glwe: &GlweParameters,
output_glwe: &GlweParameters,
) -> f64 {
let N0 = output_glwe.polynomial_size();
let upper_k0 = upper_k0(input_glwe, output_glwe);
ks.noise(input_glwe.sample_extract_lwe_dimension())
+ fft_noise_variance_external_product_glwe(
upper_k0,
N0,
ks.decomp.log2_base,
ks.decomp.level,
64,
)
}

View File

@@ -0,0 +1,122 @@
use crate::noise_estimator::p_error::{combine_errors, repeat_p_error};
use crate::optimization::dag::multi_parameters::variance_constraint::VarianceConstraint;
use crate::optimization::dag::solo_key::analyze::p_error_from_relative_variance;
use crate::utils::f64::f64_dot;
use super::operations_value::OperationsValue;
use super::partitions::PartitionIndex;
pub struct Feasible {
// TODO: move kappa here
pub constraints: Vec<VarianceConstraint>,
pub undominated_constraints: Vec<VarianceConstraint>,
pub kappa: f64, // to convert variance to local probabilities
pub global_p_error: Option<f64>,
}
impl Feasible {
pub fn of(constraints: &[VarianceConstraint], kappa: f64, global_p_error: Option<f64>) -> Self {
let undominated_constraints = VarianceConstraint::remove_dominated(constraints);
Self {
kappa,
constraints: constraints.into(),
undominated_constraints,
global_p_error,
}
}
pub fn feasible(&self, operations_variance: &OperationsValue) -> bool {
if self.global_p_error.is_none() {
self.local_feasible(operations_variance)
} else {
self.global_feasible(operations_variance)
}
}
fn local_feasible(&self, operations_variance: &OperationsValue) -> bool {
for constraint in &self.undominated_constraints {
if f64_dot(operations_variance, &constraint.variance.coeffs)
> constraint.safe_variance_bound
{
return false;
};
}
true
}
fn global_feasible(&self, operations_variance: &OperationsValue) -> bool {
self.global_p_error_with_cut(operations_variance, self.global_p_error.unwrap_or(1.0))
.is_some()
}
pub fn worst_constraint(
&self,
operations_variance: &OperationsValue,
) -> (f64, f64, &VarianceConstraint) {
let mut worst_constraint = &self.undominated_constraints[0];
let mut worst_relative_variance = 0.0;
let mut worst_variance = 0.0;
for constraint in &self.undominated_constraints {
let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
let relative_variance = variance / constraint.safe_variance_bound;
if relative_variance > worst_relative_variance {
worst_relative_variance = relative_variance;
worst_variance = variance;
worst_constraint = constraint;
}
}
(worst_variance, worst_relative_variance, worst_constraint)
}
pub fn p_error(&self, operations_variance: &OperationsValue) -> f64 {
let (_, relative_variance, _) = self.worst_constraint(operations_variance);
p_error_from_relative_variance(relative_variance, self.kappa)
}
fn global_p_error_with_cut(
&self,
operations_variance: &OperationsValue,
cut: f64,
) -> Option<f64> {
let mut global_p_error = 0.0;
for constraint in &self.constraints {
let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
let relative_variance = variance / constraint.safe_variance_bound;
let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
global_p_error = combine_errors(
global_p_error,
repeat_p_error(p_error, constraint.nb_constraints),
);
if global_p_error > cut {
return None;
}
}
Some(global_p_error)
}
pub fn global_p_error(&self, operations_variance: &OperationsValue) -> f64 {
self.global_p_error_with_cut(operations_variance, 1.0)
.unwrap_or(1.0)
}
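
The accumulation treats per-constraint failures as independent events; `combine_errors` and `repeat_p_error` are assumed here to combine independent error probabilities. A sketch under that assumption (stand-in definitions, not the crate's):

fn combine_errors(p1: f64, p2: f64) -> f64 {
    p1 + p2 - p1 * p2 // = 1 - (1 - p1)(1 - p2)
}
fn repeat_p_error(p: f64, n: u64) -> f64 {
    1.0 - (1.0 - p).powi(n as i32)
}
// Small probabilities accumulate roughly additively:
// combine_errors(repeat_p_error(1e-5, 3), 2e-5) ≈ 5e-5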
pub fn filter_constraints(&self, partition: PartitionIndex) -> Self {
let nb_partitions = self.constraints[0].variance.nb_partitions();
let touch_any_ks = |constraint: &VarianceConstraint, i| {
let variance = &constraint.variance;
variance.coeff_keyswitch_to_small(partition, i) > 0.0
|| variance.coeff_keyswitch_to_small(i, partition) > 0.0
|| variance.coeff_partition_keyswitch_to_big(partition, i) > 0.0
|| variance.coeff_partition_keyswitch_to_big(i, partition) > 0.0
};
let partition_constraints: Vec<_> = self
.constraints
.iter()
.filter(|constraint| {
constraint.partition == partition
|| (0..nb_partitions).any(|i| touch_any_ks(constraint, i))
})
.map(VarianceConstraint::clone)
.collect();
Self::of(&partition_constraints, self.kappa, self.global_p_error)
}
}

View File

@@ -1,9 +1,13 @@
pub mod analyze;
mod analyze;
mod complexity;
mod fast_keyswitch;
mod feasible;
pub mod keys_spec;
pub(crate) mod operations_value;
pub(crate) mod partitionning;
pub(crate) mod partitions;
pub(crate) mod precision_cut;
pub(crate) mod symbolic_variance;
pub(crate) mod union_find;
pub(crate) mod variance_constraint;
mod operations_value;
pub mod optimize;
mod partitionning;
mod partitions;
mod precision_cut;
mod symbolic_variance;
mod union_find;
mod variance_constraint;

View File

@@ -2,6 +2,7 @@ use crate::dag::operator::{Operator, Precision};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
#[derive(Clone, Debug)]
pub struct PrecisionCut {
// partition0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ...
// precisions are in the sense of Lut input precision and are sorted
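
A sketch of the partition lookup this comment implies (hypothetical helper, not part of the commit): with p_cut = [2], precisions 1 and 2 map to partition 0 and anything higher to partition 1.

fn partition_of(p_cut: &[u8], lut_in_precision: u8) -> usize {
    // first cut the precision fits under; past the last cut -> last partition
    p_cut
        .iter()
        .position(|&cut| lut_in_precision <= cut)
        .unwrap_or(p_cut.len())
}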

View File

@@ -0,0 +1,538 @@
// part of optimize.rs
#[cfg(test)]
mod tests {
#![allow(clippy::float_cmp)]
use once_cell::sync::Lazy;
use super::*;
use crate::computing_cost::cpu::CpuComplexity;
use crate::config;
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
use crate::dag::unparametrized;
use crate::optimization::dag::solo_key;
use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag};
use crate::optimization::decomposition;
static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
let processing_unit = config::ProcessingUnit::Cpu;
decomposition::cache(128, processing_unit, None, true)
});
const _4_SIGMA: f64 = 0.000_063_342_483_999_973;
const LOW_PARTITION: PartitionIndex = 0;
fn optimize(
dag: &unparametrized::OperationDag,
p_cut: &Option<PrecisionCut>,
default_partition: usize,
) -> Option<Parameters> {
let config = Config {
security_level: 128,
maximum_acceptable_error_probability: _4_SIGMA,
ciphertext_modulus_log: 64,
complexity_model: &CpuComplexity::default(),
};
let search_space = SearchSpace::default_cpu();
super::optimize(
dag,
config,
&search_space,
&SHARED_CACHES,
p_cut,
default_partition,
)
}
fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
optimize(dag, &Some(PrecisionCut { p_cut: vec![] }), LOW_PARTITION)
}
fn equiv_single(dag: &unparametrized::OperationDag) -> Option<bool> {
let sol_mono = solo_key::optimize::tests::optimize(dag);
let sol_multi = optimize_single(dag);
if sol_mono.best_solution.is_none() != sol_multi.is_none() {
eprintln!("Not same feasibility");
return Some(false);
};
if sol_multi.is_none() {
return None;
}
let equiv = sol_mono.best_solution.unwrap().complexity
== sol_multi.as_ref().unwrap().complexity;
if !equiv {
eprintln!("Not same complexity");
eprintln!("Single: {:?}", sol_mono.best_solution.unwrap());
eprintln!(
"Multi: {:?}",
sol_multi.clone().unwrap().complexity
);
eprintln!("Multi: {:?}", sol_multi.unwrap());
}
Some(equiv)
}
#[test]
fn optimize_simple_parameter_v0_dag() {
for precision in 1..11 {
for manp in 1..25 {
eprintln!("P M {precision} {manp}");
let dag = v0_dag(0, precision, manp as f64);
if let Some(equiv) = equiv_single(&dag) {
assert!(equiv);
} else {
break;
}
}
}
}
#[test]
fn optimize_simple_parameter_rounded_lut_2_layers() {
for accumulator_precision in 1..11 {
for precision in 1..accumulator_precision {
for manp in [1, 8, 16] {
eprintln!("CASE {accumulator_precision} {precision} {manp}");
let dag = v0_dag(0, precision, manp as f64);
if let Some(equiv) = equiv_single(&dag) {
assert!(equiv);
} else {
break;
}
}
}
}
}
fn equiv_2_single(
dag_multi: &unparametrized::OperationDag,
dag_1: &unparametrized::OperationDag,
dag_2: &unparametrized::OperationDag,
) -> Option<bool> {
let precision_max = dag_multi.out_precisions.iter().copied().max().unwrap();
let p_cut = Some(PrecisionCut {
p_cut: vec![precision_max - 1],
});
eprintln!("{dag_multi}");
let sol_single_1 = solo_key::optimize::tests::optimize(dag_1);
let sol_single_2 = solo_key::optimize::tests::optimize(dag_2);
let sol_multi = optimize(dag_multi, &p_cut, LOW_PARTITION);
let sol_multi_1 = optimize(dag_1, &p_cut, LOW_PARTITION);
let sol_multi_2 = optimize(dag_2, &p_cut, LOW_PARTITION);
let feasible_1 = sol_single_1.best_solution.is_some();
let feasible_2 = sol_single_2.best_solution.is_some();
let feasible_multi = sol_multi.is_some();
if (feasible_1 && feasible_2) != feasible_multi {
eprintln!(
"Not same feasibility {feasible_1} {feasible_2} {feasible_multi}"
);
return Some(false);
}
if sol_multi.is_none() {
return None;
}
let sol_multi = sol_multi.unwrap();
let sol_multi_1 = sol_multi_1.unwrap();
let sol_multi_2 = sol_multi_2.unwrap();
let cost_1 = sol_single_1.best_solution.unwrap().complexity;
let cost_2 = sol_single_2.best_solution.unwrap().complexity;
let cost_multi = sol_multi.complexity;
let equiv =
cost_1 + cost_2 == cost_multi
&& cost_1 == sol_multi_1.complexity
&& cost_2 == sol_multi_2.complexity
&& sol_multi.micro_params.ks[0][0].unwrap().decomp ==
sol_multi_1.micro_params.ks[0][0].unwrap().decomp
&& sol_multi.micro_params.ks[1][1].unwrap().decomp ==
sol_multi_2.micro_params.ks[0][0].unwrap().decomp
;
if !equiv {
eprintln!("Not same complexity");
eprintln!("Multi: {cost_multi:?}");
eprintln!("Added Single: {:?}", cost_1 + cost_2);
eprintln!("Single1: {:?}", sol_single_1.best_solution.unwrap());
eprintln!("Single2: {:?}", sol_single_2.best_solution.unwrap());
eprintln!("Multi: {sol_multi:?}");
eprintln!("Multi1: {sol_multi_1:?}");
eprintln!("Multi2: {sol_multi_2:?}");
}
Some(equiv)
}
#[test]
fn optimize_multi_independant_2_precisions() {
let sum_size = 0;
for precision1 in 1..11 {
for precision2 in (precision1 + 1)..11 {
for manp in [1, 8, 16] {
eprintln!("CASE {precision1} {precision2} {manp}");
let noise_factor = manp as f64;
let mut dag_multi = v0_dag(sum_size, precision1, noise_factor);
add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor);
let dag_1 = v0_dag(sum_size, precision1, noise_factor);
let dag_2 = v0_dag(sum_size, precision2, noise_factor);
if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) {
assert!(equiv, "FAILED ON {precision1} {precision2} {manp}");
} else {
break;
}
}
}
}
}
fn dag_lut_sum_of_2_partitions_2_layer(precision1: u8, precision2: u8, final_lut: bool) -> unparametrized::OperationDag {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision1, Shape::number());
let input2 = dag.add_input(precision2, Shape::number());
let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1);
let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, precision2);
let lut1 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision2);
let lut2 = dag.add_lut(lut2, FunctionTable::UNKWOWN, precision2);
let dot = dag.add_dot([lut1, lut2], [1, 1]);
if final_lut {
_ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1);
}
dag
}
#[test]
fn optimize_multi_independant_2_partitions_finally_added() {
let default_partition = 0;
let single_precision_sol : Vec<_> = (0..11).map(
|precision| {
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false);
optimize_single(&dag)
}
).collect();
for precision1 in 1..11 {
for precision2 in (precision1 + 1)..11 {
let p_cut = Some(PrecisionCut {
p_cut: vec![precision1],
});
let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, false);
let sol_1 = single_precision_sol[precision1 as usize].clone();
let sol_2 = single_precision_sol[precision2 as usize].clone();
let sol_multi = optimize(&dag_multi, &p_cut, LOW_PARTITION);
let feasible_multi = sol_multi.is_some();
let feasible_2 = sol_2.is_some();
assert!(feasible_multi);
assert!(feasible_2);
let sol_multi = sol_multi.unwrap();
let sol_1 = sol_1.unwrap();
let sol_2 = sol_2.unwrap();
assert!(sol_1.complexity < sol_multi.complexity);
assert!(sol_multi.complexity < sol_2.complexity);
eprintln!("{:?}", sol_multi.micro_params.fks);
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
assert!(sol_multi.macro_params[1] == sol_2.macro_params[0]);
// The smaller the precision, the more fks noise breaks partition independence
if precision1 < 4 {
assert!(
sol_multi_without_fks / perfect_complexity < 1.1,
"{precision1} {precision2}"
);
} else if precision1 <= 7 {
assert!(
sol_multi_without_fks / perfect_complexity < 1.03,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
);
} else {
assert!(
sol_multi_without_fks / perfect_complexity < 1.001,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
);
}
}
}
}
#[test]
fn optimize_multi_independant_2_partitions_finally_added_and_luted() {
let default_partition = 0;
let single_precision_sol : Vec<_> = (0..11).map(
|precision| {
let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true);
optimize_single(&dag)
}
).collect();
for precision1 in 1..11 {
for precision2 in (precision1 + 1)..11 {
let p_cut = Some(PrecisionCut {
p_cut: vec![precision1],
});
let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true);
let sol_1 = single_precision_sol[precision1 as usize].clone();
let sol_2 = single_precision_sol[precision2 as usize].clone();
let sol_multi = optimize(&dag_multi, &p_cut, 0);
let feasible_multi = sol_multi.is_some();
let feasible_2 = sol_2.is_some();
assert!(feasible_multi);
assert!(feasible_2);
let sol_multi = sol_multi.unwrap();
let sol_1 = sol_1.unwrap();
let sol_2 = sol_2.unwrap();
// The smaller the precision, the more fks noise dominates
assert!(sol_1.complexity < sol_multi.complexity);
assert!(sol_multi.complexity < sol_2.complexity);
let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2][default_partition].unwrap().complexity;
let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
let relative_degradation = sol_multi_without_fks / perfect_complexity;
if precision1 < 4 {
assert!(
relative_degradation < 1.2,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
);
} else if precision1 <= 7 {
assert!(
relative_degradation < 1.19,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
);
} else {
assert!(
relative_degradation < 1.15,
"{precision1} {precision2} {}", sol_multi_without_fks / perfect_complexity
);
}
}
}
}
fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option<Parameters> {
let p_cut = Some(PrecisionCut { p_cut: vec![1] });
let default_partition = 0;
optimize(dag, &p_cut, default_partition)
}
fn dag_rounded_lut_2_layers(
accumulator_precision: usize,
precision: usize,
) -> unparametrized::OperationDag {
let out_precision = accumulator_precision as u8;
let rounded_precision = precision as u8;
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, Shape::number());
let rounded1 = dag.add_expanded_rounded_lut(
input1,
FunctionTable::UNKWOWN,
rounded_precision,
out_precision,
);
let rounded2 = dag.add_expanded_rounded_lut(
rounded1,
FunctionTable::UNKWOWN,
rounded_precision,
out_precision,
);
let _rounded3 =
dag.add_expanded_rounded_lut(rounded2, FunctionTable::UNKWOWN, rounded_precision, 1);
dag
}
fn test_optimize_v3_expanded_round(precision_acc: usize, precision_tlu: usize, minimal_speedup: f64) {
let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu);
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let sol = optimize_rounded(&dag).unwrap();
let speedup = sol_mono.complexity / sol.complexity;
assert!(speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}"
);
let expected_ks = [
[true, true], // KS[0], KS[0->1]
[false, true],// KS[1]
];
let expected_fks = [
[false, false],
[true, false], // FKS[1->0]
];
for (src, dst) in cross_partition(2) {
assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]);
assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]);
}
}
#[test]
fn test_optimize_v3_expanded_round_16_8() {
test_optimize_v3_expanded_round(16, 8, 5.5);
}
#[test]
fn test_optimize_v3_expanded_round_16_6() {
test_optimize_v3_expanded_round(16, 6, 3.3);
}
#[test]
fn optimize_v3_direct_round() {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(16, Shape::number());
_ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16);
let sol = optimize_rounded(&dag).unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let minimal_speedup = 8.6;
let speedup = sol_mono.complexity / sol.complexity;
assert!(speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup}"
);
}
#[test]
fn optimize_sign_extract() {
let precision = 8;
let high_precision = 16;
let mut dag = unparametrized::OperationDag::new();
let complexity = LevelledComplexity::ZERO;
let free_small_input1 = dag.add_input(precision, Shape::number());
let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision);
let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision);
let input1 = dag.add_levelled_op(
[small_input1],
complexity,
1.0,
Shape::vector(1_000_000),
"comment",
);
let rounded1 = dag.add_expanded_round(input1, 1);
let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1);
let sol = optimize_rounded(&dag).unwrap();
let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap();
let speedup = sol_mono.complexity / sol.complexity;
let minimal_speedup = 80.0;
assert!(speedup >= minimal_speedup,
"Speedup {speedup} smaller than {minimal_speedup}"
);
}
fn test_partition_chain(decreasing: bool) {
// tlu chain with decreasing (or increasing) precision, i.e. decreasing (or increasing) partition index
// check that a finer partitioning gives faster solutions
// check solution has the right structure
let mut dag = unparametrized::OperationDag::new();
let min_precision = 6;
let max_precision = 8;
let mut input_precisions : Vec<_> = (min_precision..=max_precision).collect();
if decreasing {
input_precisions.reverse();
}
let mut lut_input = dag.add_input(input_precisions[0], Shape::number());
for &out_precision in &input_precisions {
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *input_precisions.last().unwrap());
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision);
let mut p_cut = PrecisionCut { p_cut:vec![] };
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
assert!(sol.macro_params.len() == 1);
let mut complexity = sol.complexity;
for &out_precision in &input_precisions {
if out_precision == max_precision {
// There is nothing to cut above max_precision
continue;
}
p_cut.p_cut.push(out_precision);
p_cut.p_cut.sort_unstable();
eprintln!("PCUT {p_cut}");
let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
let nb_partitions = sol.macro_params.len();
assert!(nb_partitions == (p_cut.p_cut.len() + 1),
"bad nb partitions {} {p_cut}", sol.macro_params.len());
assert!(sol.complexity < complexity,
"{} < {complexity} {out_precision} / {max_precision}", sol.complexity);
for (src, dst) in cross_partition(nb_partitions) {
let ks = sol.micro_params.ks[src][dst];
eprintln!("{} {src} {dst}", ks.is_some());
let expected_ks =
(!decreasing || src == dst + 1)
&& (decreasing || src + 1 == dst)
|| (src == dst && (src == 0 || src == nb_partitions - 1))
;
assert!(ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks);
let fks = sol.micro_params.fks[src][dst];
assert!(fks.is_none());
}
complexity = sol.complexity;
}
let sol = optimize(&dag, &None, 0);
assert!(sol.unwrap().complexity == complexity);
}
#[test]
fn test_partition_decreasing_chain() {
test_partition_chain(true);
}
#[test]
fn test_partition_increasing_chain() {
test_partition_chain(false);
}
const MAX_WEIGHT: &[u64] = &[
// max v0 weight for each precision
1_073_741_824,
1_073_741_824, // 2**30, 1b
536_870_912, // 2**29, 2b
268_435_456, // 2**28, 3b
67_108_864, // 2**26, 4b
16_777_216, // 2**24, 5b
4_194_304, // 2**22, 6b
1_048_576, // 2**20, 7b
262_144, // 2**18, 8b
65_536, // 2**16, 9b
16384, // 2**14, 10b
2048, // 2**11, 11b
];
#[test]
fn test_independant_partitions_non_feasible_single_params() {
// generate a hard circuit, not feasible with a single parameter set
// composed of independent partitions so we know the optimal result
let precisions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
let sum_size = 0;
let noise_factor = MAX_WEIGHT[precisions[0]] as f64;
let mut dag = v0_dag(sum_size, precisions[0] as u64, noise_factor);
let sol_single = optimize_single(&dag);
let mut optimal_complexity = sol_single.as_ref().unwrap().complexity;
let mut optimal_p_error = sol_single.unwrap().p_error;
for &out_precision in &precisions[1..] {
let noise_factor = MAX_WEIGHT[out_precision] as f64;
add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor);
let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor));
optimal_complexity += sol_single.as_ref().unwrap().complexity;
optimal_p_error += sol_single.as_ref().unwrap().p_error;
}
// check non feasible in single
let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
assert!(sol_single.is_none());
// solves in multi
let sol = optimize(&dag, &None, 0);
assert!(sol.is_some());
let sol = sol.unwrap();
// check optimality
assert!(sol.complexity / optimal_complexity < 1.0 + f64::EPSILON);
assert!(sol.p_error / optimal_p_error < 1.0 + f64::EPSILON);
}
#[test]
fn test_chained_partitions_non_feasible_single_params() {
// generate a hard circuit, not feasible with a single parameter set
let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
// Note: a reversed chain has issues connecting lower bits to 7 bits; there may be no feasible solution
let mut dag = unparametrized::OperationDag::new();
let mut lut_input = dag.add_input(precisions[0], Shape::number());
for out_precision in precisions {
let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64;
lut_input = dag.add_levelled_op([lut_input], LevelledComplexity::ZERO, noise_factor, Shape::number(), "");
lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
}
_ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap());
let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
assert!(sol_single.is_none());
let sol = optimize(&dag, &None, 0);
assert!(sol.is_some());
}
}

View File

@@ -504,7 +504,7 @@ fn peak_relative_variance(
(max_relative_var, safe_noise)
}
fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
pub fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
let sigma_scale = kappa / relative_variance.sqrt();
error::error_probability_of_sigma_scale(sigma_scale)
}

View File

@@ -3,6 +3,7 @@ use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate
use super::analyze;
use crate::dag::operator::{LevelledComplexity, Precision};
use crate::dag::unparametrized;
use crate::dag::unparametrized::OperationDag;
use crate::noise_estimator::error;
use crate::optimization::atomic_pattern::{
OptimizationDecompositionsConsts, OptimizationState, Solution,
@@ -386,6 +387,27 @@ pub fn optimize(
state
}
pub fn add_v0_dag(dag: &mut OperationDag, sum_size: u64, precision: u64, noise_factor: f64) {
use crate::dag::operator::{FunctionTable, Shape};
let same_scale_manp = 1.0;
let manp = noise_factor;
let out_shape = &Shape::number();
let complexity = LevelledComplexity::ADDITION * sum_size;
let comment = "dot";
let precision = precision as Precision;
let input1 = dag.add_input(precision, out_shape);
let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}
pub fn v0_dag(sum_size: u64, precision: u64, noise_factor: f64) -> OperationDag {
let mut dag = unparametrized::OperationDag::new();
add_v0_dag(&mut dag, sum_size, precision, noise_factor);
dag
}
pub fn optimize_v0(
sum_size: u64,
precision: u64,
@@ -394,19 +416,7 @@ pub fn optimize_v0(
search_space: &SearchSpace,
cache: &PersistDecompCaches,
) -> OptimizationState {
use crate::dag::operator::{FunctionTable, Shape};
let same_scale_manp = 0.0;
let manp = noise_factor;
let out_shape = &Shape::number();
let complexity = LevelledComplexity::ADDITION * sum_size;
let comment = "dot";
let mut dag = unparametrized::OperationDag::new();
let precision = precision as Precision;
let input1 = dag.add_input(precision, out_shape);
let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
let dag = v0_dag(sum_size, precision, noise_factor);
let mut state = optimize(&dag, config, search_space, cache);
if let Some(sol) = &mut state.best_solution {
sol.complexity /= 2.0;
@@ -415,7 +425,7 @@ pub fn optimize_v0(
}
#[cfg(test)]
mod tests {
pub(crate) mod tests {
use std::time::Instant;
use once_cell::sync::Lazy;
@@ -456,7 +466,7 @@ mod tests {
decomposition::cache(128, processing_unit, None, true)
});
fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
pub fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
let config = Config {
security_level: 128,
maximum_acceptable_error_probability: _4_SIGMA,

View File

@@ -102,8 +102,12 @@ pub fn pareto_quantities(
quantities
}
pub fn lowest_noise(quantities: &[CmuxComplexityNoise]) -> CmuxComplexityNoise {
quantities[quantities.len() - 1]
}
pub fn lowest_noise_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {
quantities[quantities.len() - 1].noise_br(in_lwe_dim)
lowest_noise(quantities).noise_br(in_lwe_dim)
}
pub fn lowest_complexity_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {

View File

@@ -17,7 +17,7 @@ use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::optimization::decomposition::pp_switch::PpSwitchComplexityNoise;
use crate::optimization::decomposition::PersistDecompCaches;
use crate::parameters::{BrDecompositionParameters, GlweParameters};
use crate::utils::max::f64_max;
use crate::utils::f64::f64_max;
use crate::utils::square;
pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {

View File

@@ -0,0 +1,11 @@
pub fn f64_max(values: &[f64], default: f64) -> f64 {
values.iter().copied().reduce(f64::max).unwrap_or(default)
}
pub fn f64_dot(a: &[f64], b: &[f64]) -> f64 {
debug_assert!(a.len() == b.len());
a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}

View File

@@ -1,6 +1,6 @@
pub mod cache;
pub mod f64;
pub mod hasher_builder;
pub mod max;
pub fn square<V>(v: V) -> V
where