Mirror of https://github.com/zama-ai/concrete.git
feat(optimizer): multiparameters optimization
@@ -4,6 +4,7 @@
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_precision_loss)] // u64 to f64
#![allow(clippy::cast_possible_truncation)] // u64 to usize
#![allow(clippy::question_mark)]
#![allow(clippy::match_wildcard_for_single_variants)]
#![allow(clippy::manual_range_contains)]
#![allow(clippy::missing_panics_doc)]
@@ -1,3 +1,5 @@
use std::collections::HashSet;

use crate::dag::operator::{
    dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
};
@@ -14,6 +16,8 @@ use crate::optimization::dag::solo_key::analyze::{
    extra_final_values_to_check, first, safe_noise_bound,
};

use super::complexity::OperationsCount;
use super::operations_value::OperationsValue;
use super::variance_constraint::VarianceConstraint;

use crate::utils::square;
@@ -35,31 +39,37 @@ pub struct AnalyzedDag {
    pub variance_constraints: Vec<VarianceConstraint>,
    // Undominated variance constraints
    pub undominated_variance_constraints: Vec<VarianceConstraint>,
    pub operations_count_per_instrs: Vec<OperationsCount>,
    pub operations_count: OperationsCount,
    pub p_cut: PrecisionCut,
}

pub fn analyze(
    dag: &unparametrized::OperationDag,
    noise_config: &NoiseBoundConfig,
    p_cut: &PrecisionCut,
    p_cut: &Option<PrecisionCut>,
    default_partition: PartitionIndex,
) -> AnalyzedDag {
    assert!(
        p_cut.p_cut.len() <= 1,
        "Multi-parameter can only be used with 0 or 1 precision cut"
    );
    let dag = expand_round(dag);
    let levelled_complexity = LevelledComplexity::ZERO;
    // The precision cut is chosen to work well with rounded pbs
    // Note: this is temporary
    let partitions = partitionning_with_preferred(&dag, p_cut, default_partition);
    #[allow(clippy::option_if_let_else)]
    let p_cut = match p_cut {
        Some(p_cut) => p_cut.clone(),
        None => maximal_p_cut(&dag),
    };
    let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
    let instrs_partition = partitions.instrs_partition;
    let nb_partitions = partitions.nb_partitions;
    let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);

    let variance_constraints =
        collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances);
    let undominated_variance_constraints =
        VarianceConstraint::remove_dominated(&variance_constraints);
    let operations_count_per_instrs =
        collect_operations_count(&dag, nb_partitions, &instrs_partition);
    let operations_count = sum_operations_count(&operations_count_per_instrs);
    AnalyzedDag {
        operators: dag.operators,
        nb_partitions,
@@ -68,6 +78,9 @@ pub fn analyze(
        levelled_complexity,
        variance_constraints,
        undominated_variance_constraints,
        operations_count_per_instrs,
        operations_count,
        p_cut,
    }
}
@@ -240,8 +253,65 @@ fn collect_all_variance_constraints(
    constraints
}

#[allow(clippy::match_on_vec_items)]
fn operations_counts(
    dag: &unparametrized::OperationDag,
    op: &unparametrized::UnparameterizedOperator,
    nb_partitions: usize,
    instr_partition: &InstructionPartition,
) -> OperationsCount {
    let mut counts = OperationsValue::zero(nb_partitions);
    if let Op::Lut { input, .. } = op {
        let partition = instr_partition.instruction_partition;
        let nb_lut = dag.out_shapes[input.i].flat_size() as f64;
        let src_partition = match instr_partition.inputs_transition[0] {
            Some(Transition::Internal { src_partition }) => src_partition,
            Some(Transition::Additional { .. }) | None => partition,
        };
        *counts.ks(src_partition, partition) += nb_lut;
        *counts.pbs(partition) += nb_lut;
        for &conv_partition in &instr_partition.alternative_output_representation {
            *counts.fks(partition, conv_partition) += nb_lut;
        }
    }
    OperationsCount { counts }
}

fn collect_operations_count(
    dag: &unparametrized::OperationDag,
    nb_partitions: usize,
    instrs_partition: &[InstructionPartition],
) -> Vec<OperationsCount> {
    dag.operators
        .iter()
        .enumerate()
        .map(|(i, op)| operations_counts(dag, op, nb_partitions, &instrs_partition[i]))
        .collect()
}

fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount {
    let mut sum_counts = OperationsValue::zero(all_counts[0].counts.nb_partitions());
    for OperationsCount { counts } in all_counts {
        sum_counts += counts;
    }
    OperationsCount { counts: sum_counts }
}

fn maximal_p_cut(dag: &unparametrized::OperationDag) -> PrecisionCut {
    let mut lut_in_precisions: HashSet<_> = HashSet::default();
    for op in &dag.operators {
        if let Op::Lut { input, .. } = op {
            _ = lut_in_precisions.insert(dag.out_precisions[input.i]);
        }
    }
    let mut p_cut: Vec<_> = lut_in_precisions.iter().copied().collect();
    p_cut.sort_unstable();
    _ = p_cut.pop();
    PrecisionCut { p_cut }
}
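Worth spelling out: `maximal_p_cut` places one cut at every distinct LUT input precision except the largest, which needs no upper bound. A minimal, self-contained sketch of that selection logic on bare precision values (hypothetical helper, not part of the commit):

use std::collections::HashSet;

// Hypothetical re-statement of the maximal_p_cut logic on plain precisions.
fn maximal_cuts(lut_input_precisions: &[u8]) -> Vec<u8> {
    let distinct: HashSet<u8> = lut_input_precisions.iter().copied().collect();
    let mut cuts: Vec<u8> = distinct.into_iter().collect();
    cuts.sort_unstable();
    _ = cuts.pop(); // the highest precision is not an upper bound of any lower partition
    cuts
}

fn main() {
    // LUT input precisions {3, 5, 8} give cuts [3, 5]:
    // partition 0 handles p <= 3, partition 1 handles 3 < p <= 5, partition 2 the rest.
    assert_eq!(maximal_cuts(&[3, 8, 5, 3]), vec![3, 5]);
}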
#[cfg(test)]
mod tests {
pub mod tests {
    use super::*;
    use crate::dag::operator::{FunctionTable, Shape};
    use crate::dag::unparametrized;
@@ -250,16 +320,16 @@ mod tests {
    };
    use crate::optimization::dag::solo_key::analyze::tests::CONFIG;

    fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
    pub fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
        analyze_with_preferred(dag, LOW_PRECISION_PARTITION)
    }

    fn analyze_with_preferred(
    pub fn analyze_with_preferred(
        dag: &unparametrized::OperationDag,
        default_partition: PartitionIndex,
    ) -> AnalyzedDag {
        let p_cut = PrecisionCut { p_cut: vec![2] };
        super::analyze(dag, &CONFIG, &p_cut, default_partition)
        super::analyze(dag, &CONFIG, &Some(p_cut), default_partition)
    }

    #[allow(clippy::float_cmp)]
@@ -658,4 +728,76 @@ mod tests {
        );
    }
}

    #[test]
    fn test_rounded_v3_classic_first_layer_second_layer_complexity() {
        let acc_precision = 7;
        let precision = 4;
        let mut dag = unparametrized::OperationDag::new();
        let free_input1 = dag.add_input(precision, Shape::number());
        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
        let rounded1 = dag.add_expanded_round(input1, precision);
        let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
        let old_dag = dag;
        let dag = analyze(&old_dag);
        // Partition 0
        let instrs_counts: Vec<_> = dag
            .operations_count_per_instrs
            .iter()
            .map(OperationsCount::to_string)
            .collect();
        #[rustfmt::skip] // nightly and stable are inconsistent here
        let expected_counts = [
            "ZERO x ¢",                     // free_input1
            "1¢K[1] + 1¢Br[1] + 1¢FK[1→0]", // input1
            "ZERO x ¢",                     // shift
            "ZERO x ¢",                     // cast
            "1¢K[0] + 1¢Br[0]",             // extract (lut)
            "ZERO x ¢",                     // erase (dot)
            "ZERO x ¢",                     // cast
            "ZERO x ¢",                     // shift
            "ZERO x ¢",                     // cast
            "1¢K[0] + 1¢Br[0]",             // extract (lut)
            "ZERO x ¢",                     // erase (dot)
            "ZERO x ¢",                     // cast
            "ZERO x ¢",                     // shift
            "ZERO x ¢",                     // cast
            "1¢K[0] + 1¢Br[0]",             // extract (lut)
            "ZERO x ¢",                     // erase (dot)
            "ZERO x ¢",                     // cast
            "1¢K[0→1] + 1¢Br[1]",           // _lut1
        ];
        for ((c, ec), op) in instrs_counts.iter().zip(expected_counts).zip(dag.operators) {
            assert!(
                c == ec,
                "\nBad count on {op}\nActual: {c}\nTruth : {ec} (expected)\n"
            );
        }
        eprintln!("{}", dag.operations_count);
        assert!(
            format!("{}", dag.operations_count)
                == "3¢K[0] + 1¢K[0→1] + 1¢K[1] + 3¢Br[0] + 2¢Br[1] + 1¢FK[1→0]"
        );
    }

    #[test]
    fn test_high_partition_number() {
        let mut dag = unparametrized::OperationDag::new();
        let max_precision = 10;
        let mut lut_input = dag.add_input(max_precision, Shape::number());
        let mut p_cut = vec![];
        for out_precision in (1..=max_precision).rev() {
            lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
        }
        _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1);
        for out_precision in 1..max_precision {
            p_cut.push(out_precision);
        }
        eprintln!("{}", dag.dump());
        let p_cut = PrecisionCut { p_cut };
        eprintln!("{p_cut}");
        let p_cut = Some(p_cut);
        let dag = super::analyze(&dag, &CONFIG, &p_cut, LOW_PRECISION_PARTITION);
        assert!(dag.nb_partitions == p_cut.unwrap().p_cut.len() + 1);
    }
}
@@ -0,0 +1,77 @@
use std::fmt;

use crate::utils::f64::f64_dot;

use super::operations_value::OperationsValue;

#[derive(Clone, Debug)]
pub struct OperationsCount {
    pub counts: OperationsValue,
}

#[derive(Clone, Debug)]
pub struct OperationsCost {
    pub costs: OperationsValue,
}

#[derive(Clone, Debug)]
pub struct Complexity {
    pub counts: OperationsValue,
}

impl fmt::Display for OperationsCount {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut add_plus = "";
        let counts = &self.counts;
        let nb_partitions = counts.nb_partitions();
        let index = counts.index;
        for src_partition in 0..nb_partitions {
            for dst_partition in 0..nb_partitions {
                let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)];
                if coeff != 0.0 {
                    if src_partition == dst_partition {
                        write!(f, "{add_plus}{coeff}¢K[{src_partition}]")?;
                    } else {
                        write!(f, "{add_plus}{coeff}¢K[{src_partition}→{dst_partition}]")?;
                    }
                    add_plus = " + ";
                }
            }
        }
        for src_partition in 0..nb_partitions {
            assert!(counts.values[index.input(src_partition)] == 0.0);
            let coeff = counts.values[index.pbs(src_partition)];
            if coeff != 0.0 {
                write!(f, "{add_plus}{coeff}¢Br[{src_partition}]")?;
                add_plus = " + ";
            }
            for dst_partition in 0..nb_partitions {
                let coeff = counts.values[index.keyswitch_to_big(src_partition, dst_partition)];
                if coeff != 0.0 {
                    write!(f, "{add_plus}{coeff}¢FK[{src_partition}→{dst_partition}]")?;
                    add_plus = " + ";
                }
            }
        }

        for partition in 0..nb_partitions {
            assert!(counts.values[index.modulus_switching(partition)] == 0.0);
        }
        if add_plus.is_empty() {
            write!(f, "ZERO x ¢")?;
        }
        Ok(())
    }
}

impl Complexity {
    pub fn of(counts: &OperationsCount) -> Self {
        Self {
            counts: counts.counts.clone(),
        }
    }

    pub fn complexity(&self, costs: &OperationsValue) -> f64 {
        f64_dot(&self.counts, costs)
    }
}
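Reading the rendered counts: `¢K[i]` is a keyswitch to the small key inside partition `i`, and `¢K[i→j]` is one that crosses from partition `i` to `j`; `¢Br[i]` counts PBS (blind rotations) in partition `i`; `¢FK[i→j]` is a keyswitch to the big key (the fast keyswitch) from partition `i` to `j`; `ZERO x ¢` marks an instruction with no keyswitch or PBS cost at all. For instance, the expected test string `1¢K[0→1] + 1¢Br[1]` above denotes one cross-partition keyswitch followed by one PBS in partition 1.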
@@ -0,0 +1,87 @@
// TODO: move to cache with pareto check

use concrete_cpu_noise_model::gaussian_noise::conversion::modular_variance_to_variance;

// TODO: move to concrete-cpu
use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::parameters::{GlweParameters, KsDecompositionParameters};

use serde::{Deserialize, Serialize};

// output glwe is incorporated in KsComplexityNoise
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
pub struct FksComplexityNoise {
    // 1 -> 0
    // k1⋅N1>k0⋅N0, k0⋅N0≥k1⋅N1
    pub decomp: KsDecompositionParameters,
    pub noise: f64,
    pub complexity: f64,
}

// Copy & paste from concrete-cpu
const FFT_SCALING_WEIGHT: f64 = -2.577_224_94;
fn fft_noise_variance_external_product_glwe(
    glwe_dimension: u64,
    polynomial_size: u64,
    log2_base: u64,
    level: u64,
    ciphertext_modulus_log: u32,
) -> f64 {
    // https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/noise_formulas/bootstrap.py#L25
    let b = 2_f64.powi(log2_base as i32);
    let l = level as f64;
    let big_n = polynomial_size as f64;
    let k = glwe_dimension;
    assert!(k > 0, "k = {k}");

    // 22 = 2 x 11, 11 = 64 - 53
    let scale_margin = (1_u64 << 22) as f64;
    let res =
        f64::exp2(FFT_SCALING_WEIGHT) * scale_margin * l * b * b * big_n.powi(2) * (k as f64 + 1.);
    modular_variance_to_variance(res, ciphertext_modulus_log)
}
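In formula form, the modular variance computed above is, with $w = \mathrm{FFT\_SCALING\_WEIGHT}$, decomposition base $b = 2^{\log_2 b}$, level $\ell$, polynomial size $N$ and GLWE dimension $k$:

$$\sigma^2_{\mathrm{fft}} \;=\; 2^{w} \cdot 2^{22} \cdot \ell \, b^{2} N^{2} (k+1), \qquad w = -2.577\,224\,94,$$

which the final call converts from the modular scale to a plain variance for the given ciphertext modulus.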
#[allow(non_snake_case)]
fn upper_k0(input_glwe: &GlweParameters, output_glwe: &GlweParameters) -> u64 {
    let k1 = input_glwe.glwe_dimension;
    let N1 = input_glwe.polynomial_size();
    let k0 = output_glwe.glwe_dimension;
    let N0 = output_glwe.polynomial_size();
    assert!(k1 * N1 >= k0 * N0);
    // candidate * N0 >= k1 * N1
    let f_upper_k0 = (k1 * N1) as f64 / N0 as f64;
    #[allow(clippy::cast_sign_loss)]
    let upper_k0 = f_upper_k0.ceil() as u64;
    upper_k0
}
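That is, `upper_k0` is the smallest GLWE dimension that, at the output polynomial size $N_0$, still covers the input mask: $k_0^{\uparrow} = \lceil k_1 N_1 / N_0 \rceil$, so that $k_0^{\uparrow} N_0 \ge k_1 N_1$.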
#[allow(non_snake_case)]
pub fn complexity(input_glwe: &GlweParameters, output_glwe: &GlweParameters, level: u64) -> f64 {
    let k0 = output_glwe.glwe_dimension;
    let N0 = output_glwe.polynomial_size();
    let upper_k0 = upper_k0(input_glwe, output_glwe);
    #[allow(clippy::cast_sign_loss)]
    let log2_N0 = (N0 as f64).log2().ceil() as u64;
    let size0 = (k0 + 1) * N0 * log2_N0;
    let mul_count = size0 * upper_k0 * level;
    let add_count = size0 * (upper_k0 * level - 1);
    (add_count + mul_count) as f64
}
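Equivalently, with $\mathrm{size}_0 = (k_0+1)\,N_0\,\lceil \log_2 N_0 \rceil$, the cost counted above is

$$\mathrm{cost} \;=\; \underbrace{\mathrm{size}_0\, k_0^{\uparrow} \ell}_{\text{multiplications}} \;+\; \underbrace{\mathrm{size}_0\,(k_0^{\uparrow} \ell - 1)}_{\text{additions}} \;=\; \mathrm{size}_0\,(2\,k_0^{\uparrow} \ell - 1).$$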
#[allow(non_snake_case)]
pub fn noise(
    ks: &KsComplexityNoise,
    input_glwe: &GlweParameters,
    output_glwe: &GlweParameters,
) -> f64 {
    let N0 = output_glwe.polynomial_size();
    let upper_k0 = upper_k0(input_glwe, output_glwe);
    ks.noise(input_glwe.sample_extract_lwe_dimension())
        + fft_noise_variance_external_product_glwe(
            upper_k0,
            N0,
            ks.decomp.log2_base,
            ks.decomp.level,
            64,
        )
}
@@ -0,0 +1,122 @@
use crate::noise_estimator::p_error::{combine_errors, repeat_p_error};
use crate::optimization::dag::multi_parameters::variance_constraint::VarianceConstraint;
use crate::optimization::dag::solo_key::analyze::p_error_from_relative_variance;
use crate::utils::f64::f64_dot;

use super::operations_value::OperationsValue;
use super::partitions::PartitionIndex;

pub struct Feasible {
    // TODO: move kappa here
    pub constraints: Vec<VarianceConstraint>,
    pub undominated_constraints: Vec<VarianceConstraint>,
    pub kappa: f64, // to convert variance to local probabilities
    pub global_p_error: Option<f64>,
}

impl Feasible {
    pub fn of(constraints: &[VarianceConstraint], kappa: f64, global_p_error: Option<f64>) -> Self {
        let undominated_constraints = VarianceConstraint::remove_dominated(constraints);
        Self {
            kappa,
            constraints: constraints.into(),
            undominated_constraints,
            global_p_error,
        }
    }

    pub fn feasible(&self, operations_variance: &OperationsValue) -> bool {
        if self.global_p_error.is_none() {
            self.local_feasible(operations_variance)
        } else {
            self.global_feasible(operations_variance)
        }
    }

    fn local_feasible(&self, operations_variance: &OperationsValue) -> bool {
        for constraint in &self.undominated_constraints {
            if f64_dot(operations_variance, &constraint.variance.coeffs)
                > constraint.safe_variance_bound
            {
                return false;
            };
        }
        true
    }

    fn global_feasible(&self, operations_variance: &OperationsValue) -> bool {
        self.global_p_error_with_cut(operations_variance, self.global_p_error.unwrap_or(1.0))
            .is_some()
    }

    pub fn worst_constraint(
        &self,
        operations_variance: &OperationsValue,
    ) -> (f64, f64, &VarianceConstraint) {
        let mut worst_constraint = &self.undominated_constraints[0];
        let mut worst_relative_variance = 0.0;
        let mut worst_variance = 0.0;
        for constraint in &self.undominated_constraints {
            let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
            let relative_variance = variance / constraint.safe_variance_bound;
            if relative_variance > worst_relative_variance {
                worst_relative_variance = relative_variance;
                worst_variance = variance;
                worst_constraint = constraint;
            }
        }
        (worst_variance, worst_relative_variance, worst_constraint)
    }

    pub fn p_error(&self, operations_variance: &OperationsValue) -> f64 {
        let (_, relative_variance, _) = self.worst_constraint(operations_variance);
        p_error_from_relative_variance(relative_variance, self.kappa)
    }

    fn global_p_error_with_cut(
        &self,
        operations_variance: &OperationsValue,
        cut: f64,
    ) -> Option<f64> {
        let mut global_p_error = 0.0;
        for constraint in &self.constraints {
            let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
            let relative_variance = variance / constraint.safe_variance_bound;
            let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
            global_p_error = combine_errors(
                global_p_error,
                repeat_p_error(p_error, constraint.nb_constraints),
            );
            if global_p_error > cut {
                return None;
            }
        }
        Some(global_p_error)
    }
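Assuming `repeat_p_error` and `combine_errors` implement the usual independent-events composition (an assumption about helpers defined elsewhere, not shown in this diff), the accumulated value is

$$p_{\mathrm{global}} \;=\; 1 - \prod_{i} (1 - p_i)^{n_i},$$

with $p_i$ the per-use error probability of constraint $i$ and $n_i$ its `nb_constraints`. Note that the local check above only needs the undominated constraints, while the global accumulation must visit every constraint; the early `return None` prunes a candidate as soon as the running total already exceeds the cut.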
    pub fn global_p_error(&self, operations_variance: &OperationsValue) -> f64 {
        self.global_p_error_with_cut(operations_variance, 1.0)
            .unwrap_or(1.0)
    }

    pub fn filter_constraints(&self, partition: PartitionIndex) -> Self {
        let nb_partitions = self.constraints[0].variance.nb_partitions();
        let touch_any_ks = |constraint: &VarianceConstraint, i| {
            let variance = &constraint.variance;
            variance.coeff_keyswitch_to_small(partition, i) > 0.0
                || variance.coeff_keyswitch_to_small(i, partition) > 0.0
                || variance.coeff_partition_keyswitch_to_big(partition, i) > 0.0
                || variance.coeff_partition_keyswitch_to_big(i, partition) > 0.0
        };
        let partition_constraints: Vec<_> = self
            .constraints
            .iter()
            .filter(|constraint| {
                constraint.partition == partition
                    || (0..nb_partitions).any(|i| touch_any_ks(constraint, i))
            })
            .map(VarianceConstraint::clone)
            .collect();
        Self::of(&partition_constraints, self.kappa, self.global_p_error)
    }
}
@@ -1,9 +1,13 @@
pub mod analyze;
mod analyze;
mod complexity;
mod fast_keyswitch;
mod feasible;
pub mod keys_spec;
pub(crate) mod operations_value;
pub(crate) mod partitionning;
pub(crate) mod partitions;
pub(crate) mod precision_cut;
pub(crate) mod symbolic_variance;
pub(crate) mod union_find;
pub(crate) mod variance_constraint;
mod operations_value;
pub mod optimize;
mod partitionning;
mod partitions;
mod precision_cut;
mod symbolic_variance;
mod union_find;
mod variance_constraint;
File diff suppressed because it is too large
@@ -2,6 +2,7 @@ use crate::dag::operator::{Operator, Precision};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;

#[derive(Clone, Debug)]
pub struct PrecisionCut {
    // partition 0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ...
    // precisions are in the sense of LUT input precision and are sorted
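A hypothetical lookup matching the comment above (not part of the commit): with `p_cut = [2, 5]`, precisions 1–2 land in partition 0, precisions 3–5 in partition 1, and everything above in partition 2.

// Hypothetical helper illustrating the p_cut ordering invariant.
fn partition_of(p_cut: &[u8], lut_input_precision: u8) -> usize {
    p_cut
        .iter()
        .position(|&cut| lut_input_precision <= cut)
        .unwrap_or(p_cut.len()) // above every cut: the last partition
}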
@@ -0,0 +1,538 @@
// part of optimize.rs
#[cfg(test)]
mod tests {
    #![allow(clippy::float_cmp)]

    use once_cell::sync::Lazy;

    use super::*;
    use crate::computing_cost::cpu::CpuComplexity;
    use crate::config;
    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
    use crate::dag::unparametrized;
    use crate::optimization::dag::solo_key;
    use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag};
    use crate::optimization::decomposition;

    static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
        let processing_unit = config::ProcessingUnit::Cpu;
        decomposition::cache(128, processing_unit, None, true)
    });

    const _4_SIGMA: f64 = 0.000_063_342_483_999_973;

    const LOW_PARTITION: PartitionIndex = 0;

    fn optimize(
        dag: &unparametrized::OperationDag,
        p_cut: &Option<PrecisionCut>,
        default_partition: usize,
    ) -> Option<Parameters> {
        let config = Config {
            security_level: 128,
            maximum_acceptable_error_probability: _4_SIGMA,
            ciphertext_modulus_log: 64,
            complexity_model: &CpuComplexity::default(),
        };

        let search_space = SearchSpace::default_cpu();
        super::optimize(
            dag,
            config,
            &search_space,
            &SHARED_CACHES,
            p_cut,
            default_partition,
        )
    }

    fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
        optimize(dag, &Some(PrecisionCut { p_cut: vec![] }), LOW_PARTITION)
    }
    fn equiv_single(dag: &unparametrized::OperationDag) -> Option<bool> {
        let sol_mono = solo_key::optimize::tests::optimize(dag);
        let sol_multi = optimize_single(dag);
        if sol_mono.best_solution.is_none() != sol_multi.is_none() {
            eprintln!("Not same feasibility");
            return Some(false);
        };
        if sol_multi.is_none() {
            return None;
        }
        let equiv = sol_mono.best_solution.unwrap().complexity
            == sol_multi.as_ref().unwrap().complexity;
        if !equiv {
            eprintln!("Not same complexity");
            eprintln!("Single: {:?}", sol_mono.best_solution.unwrap());
            eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity);
            eprintln!("Multi: {:?}", sol_multi.unwrap());
        }
        Some(equiv)
    }

    #[test]
    fn optimize_simple_parameter_v0_dag() {
        for precision in 1..11 {
            for manp in 1..25 {
                eprintln!("P M {precision} {manp}");
                let dag = v0_dag(0, precision, manp as f64);
                if let Some(equiv) = equiv_single(&dag) {
                    assert!(equiv);
                } else {
                    break;
                }
            }
        }
    }
    #[test]
    fn optimize_simple_parameter_rounded_lut_2_layers() {
        for accumulator_precision in 1..11 {
            for precision in 1..accumulator_precision {
                for manp in [1, 8, 16] {
                    eprintln!("CASE {accumulator_precision} {precision} {manp}");
                    let dag = v0_dag(0, precision, manp as f64);
                    if let Some(equiv) = equiv_single(&dag) {
                        assert!(equiv);
                    } else {
                        break;
                    }
                }
            }
        }
    }
    fn equiv_2_single(
        dag_multi: &unparametrized::OperationDag,
        dag_1: &unparametrized::OperationDag,
        dag_2: &unparametrized::OperationDag,
    ) -> Option<bool> {
        let precision_max = dag_multi.out_precisions.iter().copied().max().unwrap();
        let p_cut = Some(PrecisionCut {
            p_cut: vec![precision_max - 1],
        });
        eprintln!("{dag_multi}");
        let sol_single_1 = solo_key::optimize::tests::optimize(dag_1);
        let sol_single_2 = solo_key::optimize::tests::optimize(dag_2);
        let sol_multi = optimize(dag_multi, &p_cut, LOW_PARTITION);
        let sol_multi_1 = optimize(dag_1, &p_cut, LOW_PARTITION);
        let sol_multi_2 = optimize(dag_2, &p_cut, LOW_PARTITION);
        let feasible_1 = sol_single_1.best_solution.is_some();
        let feasible_2 = sol_single_2.best_solution.is_some();
        let feasible_multi = sol_multi.is_some();
        if (feasible_1 && feasible_2) != feasible_multi {
            eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}");
            return Some(false);
        }
        if sol_multi.is_none() {
            return None;
        }
        let sol_multi = sol_multi.unwrap();
        let sol_multi_1 = sol_multi_1.unwrap();
        let sol_multi_2 = sol_multi_2.unwrap();
        let cost_1 = sol_single_1.best_solution.unwrap().complexity;
        let cost_2 = sol_single_2.best_solution.unwrap().complexity;
        let cost_multi = sol_multi.complexity;
        let equiv = cost_1 + cost_2 == cost_multi
            && cost_1 == sol_multi_1.complexity
            && cost_2 == sol_multi_2.complexity
            && sol_multi.micro_params.ks[0][0].unwrap().decomp
                == sol_multi_1.micro_params.ks[0][0].unwrap().decomp
            && sol_multi.micro_params.ks[1][1].unwrap().decomp
                == sol_multi_2.micro_params.ks[0][0].unwrap().decomp;
        if !equiv {
            eprintln!("Not same complexity");
            eprintln!("Multi: {cost_multi:?}");
            eprintln!("Added Single: {:?}", cost_1 + cost_2);
            eprintln!("Single1: {:?}", sol_single_1.best_solution.unwrap());
            eprintln!("Single2: {:?}", sol_single_2.best_solution.unwrap());
            eprintln!("Multi: {sol_multi:?}");
            eprintln!("Multi1: {sol_multi_1:?}");
            eprintln!("Multi2: {sol_multi_2:?}");
        }
        Some(equiv)
    }

    #[test]
    fn optimize_multi_independant_2_precisions() {
        let sum_size = 0;
        for precision1 in 1..11 {
            for precision2 in (precision1 + 1)..11 {
                for manp in [1, 8, 16] {
                    eprintln!("CASE {precision1} {precision2} {manp}");
                    let noise_factor = manp as f64;
                    let mut dag_multi = v0_dag(sum_size, precision1, noise_factor);
                    add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor);
                    let dag_1 = v0_dag(sum_size, precision1, noise_factor);
                    let dag_2 = v0_dag(sum_size, precision2, noise_factor);
                    if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) {
                        assert!(equiv, "FAILED ON {precision1} {precision2} {manp}");
                    } else {
                        break;
                    }
                }
            }
        }
    }
    fn dag_lut_sum_of_2_partitions_2_layer(
        precision1: u8,
        precision2: u8,
        final_lut: bool,
    ) -> unparametrized::OperationDag {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(precision1, Shape::number());
        let input2 = dag.add_input(precision2, Shape::number());
        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1);
        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, precision2);
        let lut1 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision2);
        let lut2 = dag.add_lut(lut2, FunctionTable::UNKWOWN, precision2);
        let dot = dag.add_dot([lut1, lut2], [1, 1]);
        if final_lut {
            _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1);
        }
        dag
    }

    #[test]
    fn optimize_multi_independant_2_partitions_finally_added() {
        let default_partition = 0;
        let single_precision_sol: Vec<_> = (0..11)
            .map(|precision| {
                let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false);
                optimize_single(&dag)
            })
            .collect();

        for precision1 in 1..11 {
            for precision2 in (precision1 + 1)..11 {
                let p_cut = Some(PrecisionCut {
                    p_cut: vec![precision1],
                });
                let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, false);
                let sol_1 = single_precision_sol[precision1 as usize].clone();
                let sol_2 = single_precision_sol[precision2 as usize].clone();
                let sol_multi = optimize(&dag_multi, &p_cut, LOW_PARTITION);
                let feasible_multi = sol_multi.is_some();
                let feasible_2 = sol_2.is_some();
                assert!(feasible_multi);
                assert!(feasible_2);
                let sol_multi = sol_multi.unwrap();
                let sol_1 = sol_1.unwrap();
                let sol_2 = sol_2.unwrap();
                assert!(sol_1.complexity < sol_multi.complexity);
                assert!(sol_multi.complexity < sol_2.complexity);
                eprintln!("{:?}", sol_multi.micro_params.fks);
                let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
                    [default_partition]
                    .unwrap()
                    .complexity;
                let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
                let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
                assert!(sol_multi.macro_params[1] == sol_2.macro_params[0]);
                // The smaller the precision, the more the fks noise breaks partition independence
                if precision1 < 4 {
                    assert!(
                        sol_multi_without_fks / perfect_complexity < 1.1,
                        "{precision1} {precision2}"
                    );
                } else if precision1 <= 7 {
                    assert!(
                        sol_multi_without_fks / perfect_complexity < 1.03,
                        "{precision1} {precision2} {}",
                        sol_multi_without_fks / perfect_complexity
                    );
                } else {
                    assert!(
                        sol_multi_without_fks / perfect_complexity < 1.001,
                        "{precision1} {precision2} {}",
                        sol_multi_without_fks / perfect_complexity
                    );
                }
            }
        }
    }
    #[test]
    fn optimize_multi_independant_2_partitions_finally_added_and_luted() {
        let default_partition = 0;
        let single_precision_sol: Vec<_> = (0..11)
            .map(|precision| {
                let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true);
                optimize_single(&dag)
            })
            .collect();
        for precision1 in 1..11 {
            for precision2 in (precision1 + 1)..11 {
                let p_cut = Some(PrecisionCut {
                    p_cut: vec![precision1],
                });
                let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true);
                let sol_1 = single_precision_sol[precision1 as usize].clone();
                let sol_2 = single_precision_sol[precision2 as usize].clone();
                let sol_multi = optimize(&dag_multi, &p_cut, 0);
                let feasible_multi = sol_multi.is_some();
                let feasible_2 = sol_2.is_some();
                assert!(feasible_multi);
                assert!(feasible_2);
                let sol_multi = sol_multi.unwrap();
                let sol_1 = sol_1.unwrap();
                let sol_2 = sol_2.unwrap();
                // The smaller the precision, the more the fks noise dominates
                assert!(sol_1.complexity < sol_multi.complexity);
                assert!(sol_multi.complexity < sol_2.complexity);
                let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
                    [default_partition]
                    .unwrap()
                    .complexity;
                let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
                let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
                let relative_degradation = sol_multi_without_fks / perfect_complexity;
                if precision1 < 4 {
                    assert!(
                        relative_degradation < 1.2,
                        "{precision1} {precision2} {relative_degradation}"
                    );
                } else if precision1 <= 7 {
                    assert!(
                        relative_degradation < 1.19,
                        "{precision1} {precision2} {relative_degradation}"
                    );
                } else {
                    assert!(
                        relative_degradation < 1.15,
                        "{precision1} {precision2} {relative_degradation}"
                    );
                }
            }
        }
    }
    fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option<Parameters> {
        let p_cut = Some(PrecisionCut { p_cut: vec![1] });
        let default_partition = 0;
        optimize(dag, &p_cut, default_partition)
    }

    fn dag_rounded_lut_2_layers(
        accumulator_precision: usize,
        precision: usize,
    ) -> unparametrized::OperationDag {
        let out_precision = accumulator_precision as u8;
        let rounded_precision = precision as u8;
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(precision as u8, Shape::number());
        let rounded1 = dag.add_expanded_rounded_lut(
            input1,
            FunctionTable::UNKWOWN,
            rounded_precision,
            out_precision,
        );
        let rounded2 = dag.add_expanded_rounded_lut(
            rounded1,
            FunctionTable::UNKWOWN,
            rounded_precision,
            out_precision,
        );
        let _rounded3 =
            dag.add_expanded_rounded_lut(rounded2, FunctionTable::UNKWOWN, rounded_precision, 1);
        dag
    }

    fn test_optimize_v3_expanded_round(
        precision_acc: usize,
        precision_tlu: usize,
        minimal_speedup: f64,
    ) {
        let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu);
        let sol_mono = solo_key::optimize::tests::optimize(&dag)
            .best_solution
            .unwrap();
        let sol = optimize_rounded(&dag).unwrap();
        let speedup = sol_mono.complexity / sol.complexity;
        assert!(
            speedup >= minimal_speedup,
            "Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}"
        );
        let expected_ks = [
            [true, true],  // KS[0], KS[0->1]
            [false, true], // KS[1]
        ];
        let expected_fks = [
            [false, false],
            [true, false], // FKS[1->0]
        ];
        for (src, dst) in cross_partition(2) {
            assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]);
            assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]);
        }
    }
    #[test]
    fn test_optimize_v3_expanded_round_16_8() {
        test_optimize_v3_expanded_round(16, 8, 5.5);
    }

    #[test]
    fn test_optimize_v3_expanded_round_16_6() {
        test_optimize_v3_expanded_round(16, 6, 3.3);
    }

    #[test]
    fn optimize_v3_direct_round() {
        let mut dag = unparametrized::OperationDag::new();
        let input1 = dag.add_input(16, Shape::number());
        _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16);
        let sol = optimize_rounded(&dag).unwrap();
        let sol_mono = solo_key::optimize::tests::optimize(&dag)
            .best_solution
            .unwrap();
        let minimal_speedup = 8.6;
        let speedup = sol_mono.complexity / sol.complexity;
        assert!(
            speedup >= minimal_speedup,
            "Speedup {speedup} smaller than {minimal_speedup}"
        );
    }
    #[test]
    fn optimize_sign_extract() {
        let precision = 8;
        let high_precision = 16;
        let mut dag = unparametrized::OperationDag::new();
        let complexity = LevelledComplexity::ZERO;
        let free_small_input1 = dag.add_input(precision, Shape::number());
        let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision);
        let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision);
        let input1 = dag.add_levelled_op(
            [small_input1],
            complexity,
            1.0,
            Shape::vector(1_000_000),
            "comment",
        );
        let rounded1 = dag.add_expanded_round(input1, 1);
        let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1);
        let sol = optimize_rounded(&dag).unwrap();
        let sol_mono = solo_key::optimize::tests::optimize(&dag)
            .best_solution
            .unwrap();
        let speedup = sol_mono.complexity / sol.complexity;
        let minimal_speedup = 80.0;
        assert!(
            speedup >= minimal_speedup,
            "Speedup {speedup} smaller than {minimal_speedup}"
        );
    }
    fn test_partition_chain(decreasing: bool) {
        // A TLU chain with decreasing precision (decreasing partition index).
        // Check that increasing the partitioning gives faster solutions.
        // Check that the solution has the right structure.
        let mut dag = unparametrized::OperationDag::new();
        let min_precision = 6;
        let max_precision = 8;
        let mut input_precisions: Vec<_> = (min_precision..=max_precision).collect();
        if decreasing {
            input_precisions.reverse();
        }
        let mut lut_input = dag.add_input(input_precisions[0], Shape::number());
        for &out_precision in &input_precisions {
            lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
        }
        lut_input = dag.add_lut(
            lut_input,
            FunctionTable::UNKWOWN,
            *input_precisions.last().unwrap(),
        );
        _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision);
        let mut p_cut = PrecisionCut { p_cut: vec![] };
        let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
        assert!(sol.macro_params.len() == 1);
        let mut complexity = sol.complexity;
        for &out_precision in &input_precisions {
            if out_precision == max_precision {
                // There is nothing to cut above max_precision
                continue;
            }
            p_cut.p_cut.push(out_precision);
            p_cut.p_cut.sort_unstable();
            eprintln!("PCUT {p_cut}");
            let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap();
            let nb_partitions = sol.macro_params.len();
            assert!(
                nb_partitions == (p_cut.p_cut.len() + 1),
                "bad nb partitions {} {p_cut}",
                sol.macro_params.len()
            );
            assert!(
                sol.complexity < complexity,
                "{} < {complexity} {out_precision} / {max_precision}",
                sol.complexity
            );
            for (src, dst) in cross_partition(nb_partitions) {
                let ks = sol.micro_params.ks[src][dst];
                eprintln!("{} {src} {dst}", ks.is_some());
                let expected_ks = (!decreasing || src == dst + 1) && (decreasing || src + 1 == dst)
                    || (src == dst && (src == 0 || src == nb_partitions - 1));
                assert!(
                    ks.is_some() == expected_ks,
                    "{:?} {:?}",
                    ks.is_some(),
                    expected_ks
                );
                let fks = sol.micro_params.fks[src][dst];
                assert!(fks.is_none());
            }
            complexity = sol.complexity;
        }
        let sol = optimize(&dag, &None, 0);
        assert!(sol.unwrap().complexity == complexity);
    }
    #[test]
    fn test_partition_decreasing_chain() {
        test_partition_chain(true);
    }

    #[test]
    fn test_partition_increasing_chain() {
        test_partition_chain(false);
    }
    const MAX_WEIGHT: &[u64] = &[
        // max v0 weight for each precision
        1_073_741_824,
        1_073_741_824, // 2**30, 1b
        536_870_912,   // 2**29, 2b
        268_435_456,   // 2**28, 3b
        67_108_864,    // 2**26, 4b
        16_777_216,    // 2**24, 5b
        4_194_304,     // 2**22, 6b
        1_048_576,     // 2**20, 7b
        262_144,       // 2**18, 8b
        65_536,        // 2**16, 9b
        16384,         // 2**14, 10b
        2048,          // 2**11, 11b
    ];
    #[test]
    fn test_independant_partitions_non_feasible_single_params() {
        // Generate a hard circuit that is not feasible with a single parameter set.
        // It is composed of independent partitions, so we know the optimal result.
        let precisions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        let sum_size = 0;
        let noise_factor = MAX_WEIGHT[precisions[0]] as f64;
        let mut dag = v0_dag(sum_size, precisions[0] as u64, noise_factor);
        let sol_single = optimize_single(&dag);
        let mut optimal_complexity = sol_single.as_ref().unwrap().complexity;
        let mut optimal_p_error = sol_single.unwrap().p_error;
        for &out_precision in &precisions[1..] {
            let noise_factor = MAX_WEIGHT[out_precision] as f64;
            add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor);
            let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor));
            optimal_complexity += sol_single.as_ref().unwrap().complexity;
            optimal_p_error += sol_single.as_ref().unwrap().p_error;
        }
        // check non-feasibility with a single parameter set
        let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
        assert!(sol_single.is_none());
        // solvable with multiple parameter sets
        let sol = optimize(&dag, &None, 0);
        assert!(sol.is_some());
        let sol = sol.unwrap();
        // check optimality
        assert!(sol.complexity / optimal_complexity < 1.0 + f64::EPSILON);
        assert!(sol.p_error / optimal_p_error < 1.0 + f64::EPSILON);
    }
    #[test]
    fn test_chained_partitions_non_feasible_single_params() {
        // Generate a hard circuit that is not feasible with a single parameter set.
        let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
        // Note: reversing the chain has issues connecting the lower bits to 7 bits; there may be no feasible solution.
        let mut dag = unparametrized::OperationDag::new();
        let mut lut_input = dag.add_input(precisions[0], Shape::number());
        for out_precision in precisions {
            let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64;
            lut_input = dag.add_levelled_op(
                [lut_input],
                LevelledComplexity::ZERO,
                noise_factor,
                Shape::number(),
                "",
            );
            lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
        }
        _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap());
        let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
        assert!(sol_single.is_none());
        let sol = optimize(&dag, &None, 0);
        assert!(sol.is_some());
    }
}
@@ -504,7 +504,7 @@ fn peak_relative_variance(
    (max_relative_var, safe_noise)
}

fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
pub fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
    let sigma_scale = kappa / relative_variance.sqrt();
    error::error_probability_of_sigma_scale(sigma_scale)
}
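For reference: with relative variance $t = \sigma^2 / \sigma^2_{\max}$ and safety factor $\kappa$ (the number of standard deviations that the variance bound represents), the scale is $s = \kappa / \sqrt{t}$. Assuming `error_probability_of_sigma_scale` returns the two-sided Gaussian tail probability (an interpretation of a helper not shown in this diff), this amounts to

$$p_{\mathrm{err}}(t, \kappa) \;=\; \operatorname{erfc}\!\left(\frac{\kappa}{\sqrt{2\,t}}\right).$$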
@@ -3,6 +3,7 @@ use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate
use super::analyze;
use crate::dag::operator::{LevelledComplexity, Precision};
use crate::dag::unparametrized;
use crate::dag::unparametrized::OperationDag;
use crate::noise_estimator::error;
use crate::optimization::atomic_pattern::{
    OptimizationDecompositionsConsts, OptimizationState, Solution,
@@ -386,6 +387,27 @@ pub fn optimize(
    state
}

pub fn add_v0_dag(dag: &mut OperationDag, sum_size: u64, precision: u64, noise_factor: f64) {
    use crate::dag::operator::{FunctionTable, Shape};
    let same_scale_manp = 1.0;
    let manp = noise_factor;
    let out_shape = &Shape::number();
    let complexity = LevelledComplexity::ADDITION * sum_size;
    let comment = "dot";
    let precision = precision as Precision;
    let input1 = dag.add_input(precision, out_shape);
    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
}

pub fn v0_dag(sum_size: u64, precision: u64, noise_factor: f64) -> OperationDag {
    let mut dag = unparametrized::OperationDag::new();
    add_v0_dag(&mut dag, sum_size, precision, noise_factor);
    dag
}
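Taken together, `add_v0_dag` appends the canonical v0 benchmark pattern (input → dot → LUT → dot scaled by `noise_factor` → LUT) to an existing DAG, and `v0_dag` wraps it for the single-pattern case; for example, `v0_dag(0, 6, 8.0)` builds a fresh one-input DAG at 6-bit precision, while the multi-precision tests above call `add_v0_dag` repeatedly on the same DAG to obtain independent partitions.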
pub fn optimize_v0(
    sum_size: u64,
    precision: u64,
@@ -394,19 +416,7 @@ pub fn optimize_v0(
    search_space: &SearchSpace,
    cache: &PersistDecompCaches,
) -> OptimizationState {
    use crate::dag::operator::{FunctionTable, Shape};
    let same_scale_manp = 0.0;
    let manp = noise_factor;
    let out_shape = &Shape::number();
    let complexity = LevelledComplexity::ADDITION * sum_size;
    let comment = "dot";
    let mut dag = unparametrized::OperationDag::new();
    let precision = precision as Precision;
    let input1 = dag.add_input(precision, out_shape);
    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
    let dag = v0_dag(sum_size, precision, noise_factor);
    let mut state = optimize(&dag, config, search_space, cache);
    if let Some(sol) = &mut state.best_solution {
        sol.complexity /= 2.0;
@@ -415,7 +425,7 @@
}
#[cfg(test)]
mod tests {
pub(crate) mod tests {
    use std::time::Instant;

    use once_cell::sync::Lazy;
@@ -456,7 +466,7 @@ mod tests {
        decomposition::cache(128, processing_unit, None, true)
    });

    fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
    pub fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
        let config = Config {
            security_level: 128,
            maximum_acceptable_error_probability: _4_SIGMA,
@@ -102,8 +102,12 @@ pub fn pareto_quantities(
    quantities
}

pub fn lowest_noise(quantities: &[CmuxComplexityNoise]) -> CmuxComplexityNoise {
    quantities[quantities.len() - 1]
}

pub fn lowest_noise_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {
    quantities[quantities.len() - 1].noise_br(in_lwe_dim)
    lowest_noise(quantities).noise_br(in_lwe_dim)
}

pub fn lowest_complexity_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {
@@ -17,7 +17,7 @@ use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
use crate::optimization::decomposition::pp_switch::PpSwitchComplexityNoise;
use crate::optimization::decomposition::PersistDecompCaches;
use crate::parameters::{BrDecompositionParameters, GlweParameters};
use crate::utils::max::f64_max;
use crate::utils::f64::f64_max;
use crate::utils::square;

pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {
@@ -0,0 +1,11 @@
pub fn f64_max(values: &[f64], default: f64) -> f64 {
    values.iter().copied().reduce(f64::max).unwrap_or(default)
}

pub fn f64_dot(a: &[f64], b: &[f64]) -> f64 {
    let mut sum = 0.0;
    for i in 0..a.len() {
        sum += a[i] * b[i];
    }
    sum
}
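As an aside (not part of the commit), the same dot product can be written with iterators; note the behavioral difference on mismatched lengths: the indexed loop panics if `b` is shorter than `a`, while `zip` silently stops at the shorter slice.

// Equivalent iterator formulation of f64_dot, for comparison only.
pub fn f64_dot_zip(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}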
@@ -1,6 +1,6 @@
pub mod cache;
pub mod f64;
pub mod hasher_builder;
pub mod max;

pub fn square<V>(v: V) -> V
where