From 3e05aa47a4921d0cc6569f732e6a04afb7418275 Mon Sep 17 00:00:00 2001
From: rudy
Date: Thu, 23 Mar 2023 11:55:18 +0100
Subject: [PATCH] feat(optimizer): multi-parameter optimization

---
 .../concrete-optimizer/src/lib.rs             |    1 +
 .../dag/multi_parameters/analyze.rs           |  164 ++-
 .../dag/multi_parameters/complexity.rs        |   77 ++
 .../dag/multi_parameters/fast_keyswitch.rs    |   87 ++
 .../dag/multi_parameters/feasible.rs          |  122 ++
 .../optimization/dag/multi_parameters/mod.rs  |   20 +-
 .../dag/multi_parameters/optimize.rs          | 1007 +++++++++++++++++
 .../dag/multi_parameters/precision_cut.rs     |    1 +
 .../multi_parameters/tests/test_optimize.rs   |  538 +++++++++
 .../src/optimization/dag/solo_key/analyze.rs  |    2 +-
 .../src/optimization/dag/solo_key/optimize.rs |   40 +-
 .../src/optimization/decomposition/cmux.rs    |    6 +-
 .../wop_atomic_pattern/optimize.rs            |    2 +-
 .../concrete-optimizer/src/utils/f64.rs       |   11 +
 .../concrete-optimizer/src/utils/mod.rs       |    2 +-
 15 files changed, 2042 insertions(+), 38 deletions(-)
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/fast_keyswitch.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs
 create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/utils/f64.rs

diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/lib.rs b/compilers/concrete-optimizer/concrete-optimizer/src/lib.rs
index f02073ffa..a86e3c0e8 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/lib.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/lib.rs
@@ -4,6 +4,7 @@
 #![allow(clippy::cast_lossless)]
 #![allow(clippy::cast_precision_loss)] // u64 to f64
 #![allow(clippy::cast_possible_truncation)] // u64 to usize
+#![allow(clippy::question_mark)]
 #![allow(clippy::match_wildcard_for_single_variants)]
 #![allow(clippy::manual_range_contains)]
 #![allow(clippy::missing_panics_doc)]
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
index 7f71b6b41..f45933064 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs
@@ -1,3 +1,5 @@
+use std::collections::HashSet;
+
 use crate::dag::operator::{
     dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
 };
@@ -14,6 +16,8 @@ use crate::optimization::dag::solo_key::analyze::{
     extra_final_values_to_check, first, safe_noise_bound,
 };
 
+use super::complexity::OperationsCount;
+use super::operations_value::OperationsValue;
 use super::variance_constraint::VarianceConstraint;
 use crate::utils::square;
@@ -35,31 +39,37 @@ pub struct AnalyzedDag {
     pub variance_constraints: Vec<VarianceConstraint>,
     // Undominated variance constraints
     pub undominated_variance_constraints: Vec<VarianceConstraint>,
+    pub operations_count_per_instrs: Vec<OperationsCount>,
+    pub operations_count: OperationsCount,
+    pub p_cut: PrecisionCut,
 }
 
 pub fn analyze(
     dag: &unparametrized::OperationDag,
     noise_config: &NoiseBoundConfig,
-    p_cut: &PrecisionCut,
+    p_cut: &Option<PrecisionCut>,
     default_partition: PartitionIndex,
 ) -> AnalyzedDag {
-    assert!(
-        p_cut.p_cut.len() <= 1,
-        "Multi-parameter can only be used 0 or 1 precision cut"
-    );
     let dag = expand_round(dag);
     let levelled_complexity = LevelledComplexity::ZERO;
     // The precision cut is chosen to work well with rounded pbs
     // Note: this is temporary
-    let partitions = partitionning_with_preferred(&dag, p_cut, default_partition);
+    #[allow(clippy::option_if_let_else)]
+    let p_cut = match p_cut {
+        Some(p_cut) => p_cut.clone(),
+        None => maximal_p_cut(&dag),
+    };
+    let partitions = partitionning_with_preferred(&dag, &p_cut, default_partition);
     let instrs_partition = partitions.instrs_partition;
     let nb_partitions = partitions.nb_partitions;
     let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);
-
     let variance_constraints =
         collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances);
     let undominated_variance_constraints =
         VarianceConstraint::remove_dominated(&variance_constraints);
+    let operations_count_per_instrs =
+        collect_operations_count(&dag, nb_partitions, &instrs_partition);
+    let operations_count = sum_operations_count(&operations_count_per_instrs);
     AnalyzedDag {
         operators: dag.operators,
         nb_partitions,
@@ -68,6 +78,9 @@ pub fn analyze(
         levelled_complexity,
         variance_constraints,
         undominated_variance_constraints,
+        operations_count_per_instrs,
+        operations_count,
+        p_cut,
     }
 }
 
@@ -240,8 +253,65 @@ fn collect_all_variance_constraints(
     constraints
 }
 
+#[allow(clippy::match_on_vec_items)]
+fn operations_counts(
+    dag: &unparametrized::OperationDag,
+    op: &unparametrized::UnparameterizedOperator,
+    nb_partitions: usize,
+    instr_partition: &InstructionPartition,
+) -> OperationsCount {
+    let mut counts = OperationsValue::zero(nb_partitions);
+    if let Op::Lut { input, .. } = op {
+        let partition = instr_partition.instruction_partition;
+        let nb_lut = dag.out_shapes[input.i].flat_size() as f64;
+        let src_partition = match instr_partition.inputs_transition[0] {
+            Some(Transition::Internal { src_partition }) => src_partition,
+            Some(Transition::Additional { .. }) | None => partition,
+        };
+        *counts.ks(src_partition, partition) += nb_lut;
+        *counts.pbs(partition) += nb_lut;
+        for &conv_partition in &instr_partition.alternative_output_representation {
+            *counts.fks(partition, conv_partition) += nb_lut;
+        }
+    }
+    OperationsCount { counts }
+}
+
+fn collect_operations_count(
+    dag: &unparametrized::OperationDag,
+    nb_partitions: usize,
+    instrs_partition: &[InstructionPartition],
+) -> Vec<OperationsCount> {
+    dag.operators
+        .iter()
+        .enumerate()
+        .map(|(i, op)| operations_counts(dag, op, nb_partitions, &instrs_partition[i]))
+        .collect()
+}
+
+fn sum_operations_count(all_counts: &[OperationsCount]) -> OperationsCount {
+    let mut sum_counts = OperationsValue::zero(all_counts[0].counts.nb_partitions());
+    for OperationsCount { counts } in all_counts {
+        sum_counts += counts;
+    }
+    OperationsCount { counts: sum_counts }
+}
+
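+// maximal_p_cut derives the default precision cut from the dag itself: every
+// distinct Lut input precision opens its own partition, except the largest
+// one, which becomes the implicit last partition (hence the final pop()).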
+fn maximal_p_cut(dag: &unparametrized::OperationDag) -> PrecisionCut {
+    let mut lut_in_precisions: HashSet<_> = HashSet::default();
+    for op in &dag.operators {
+        if let Op::Lut { input, .. } = op {
+            _ = lut_in_precisions.insert(dag.out_precisions[input.i]);
+        }
+    }
+    let mut p_cut: Vec<_> = lut_in_precisions.iter().copied().collect();
+    p_cut.sort_unstable();
+    _ = p_cut.pop();
+    PrecisionCut { p_cut }
+}
+
 #[cfg(test)]
-mod tests {
+pub mod tests {
     use super::*;
     use crate::dag::operator::{FunctionTable, Shape};
     use crate::dag::unparametrized;
     use crate::optimization::dag::multi_parameters::partitionning::tests::{
         LOW_PRECISION_PARTITION,
     };
     use crate::optimization::dag::solo_key::analyze::tests::CONFIG;
 
-    fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
+    pub fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
         analyze_with_preferred(dag, LOW_PRECISION_PARTITION)
     }
 
-    fn analyze_with_preferred(
+    pub fn analyze_with_preferred(
         dag: &unparametrized::OperationDag,
         default_partition: PartitionIndex,
     ) -> AnalyzedDag {
         let p_cut = PrecisionCut { p_cut: vec![2] };
-        super::analyze(dag, &CONFIG, &p_cut, default_partition)
+        super::analyze(dag, &CONFIG, &Some(p_cut), default_partition)
     }
 
     #[allow(clippy::float_cmp)]
@@ -658,4 +728,76 @@
             );
         }
     }
+
+    #[test]
+    fn test_rounded_v3_classic_first_layer_second_layer_complexity() {
+        let acc_precision = 7;
+        let precision = 4;
+        let mut dag = unparametrized::OperationDag::new();
+        let free_input1 = dag.add_input(precision, Shape::number());
+        let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
+        let rounded1 = dag.add_expanded_round(input1, precision);
+        let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
+        let old_dag = dag;
+        let dag = analyze(&old_dag);
+        // Partition 0
+        let instrs_counts: Vec<_> = dag
+            .operations_count_per_instrs
+            .iter()
+            .map(OperationsCount::to_string)
+            .collect();
+        #[rustfmt::skip] // nightly and stable are inconsistent here
+        let expected_counts = [
+            "ZERO x ¢",                     // free_input1
+            "1¢K[1] + 1¢Br[1] + 1¢FK[1→0]", // input1
+            "ZERO x ¢",                     // shift
+            "ZERO x ¢",                     // cast
+            "1¢K[0] + 1¢Br[0]",             // extract (lut)
+            "ZERO x ¢",                     // erase (dot)
+            "ZERO x ¢",                     // cast
+            "ZERO x ¢",                     // shift
+            "ZERO x ¢",                     // cast
+            "1¢K[0] + 1¢Br[0]",             // extract (lut)
+            "ZERO x ¢",                     // erase (dot)
+            "ZERO x ¢",                     // cast
+            "ZERO x ¢",                     // shift
+            "ZERO x ¢",                     // cast
+            "1¢K[0] + 1¢Br[0]",             // extract (lut)
+            "ZERO x ¢",                     // erase (dot)
+            "ZERO x ¢",                     // cast
+            "1¢K[0→1] + 1¢Br[1]",           // _lut1
+        ];
+        for ((c, ec), op) in instrs_counts.iter().zip(expected_counts).zip(dag.operators) {
+            assert!(
+                c == ec,
+                "\nBad count on {op}\nActual: {c}\nTruth : {ec} (expected)\n"
+            );
+        }
+        eprintln!("{}", dag.operations_count);
+        assert!(
+            format!("{}", dag.operations_count)
+                == "3¢K[0] + 1¢K[0→1] + 1¢K[1] + 3¢Br[0] + 2¢Br[1] + 1¢FK[1→0]"
+        );
+    }
+
+    #[test]
+    fn test_high_partition_number() {
+        let mut dag = unparametrized::OperationDag::new();
+        let max_precision = 10;
+        let mut lut_input = dag.add_input(max_precision, Shape::number());
+        let mut p_cut = vec![];
+        for out_precision in (1..=max_precision).rev() {
+            lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
+        }
+        _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, 1);
+        for out_precision in 1..max_precision {
+            p_cut.push(out_precision);
+        }
+        eprintln!("{}", dag.dump());
+        let p_cut = PrecisionCut { p_cut };
+        eprintln!("{p_cut}");
+        let p_cut = Some(p_cut);
+        let dag = super::analyze(&dag, &CONFIG, &p_cut, LOW_PRECISION_PARTITION);
+        assert!(dag.nb_partitions == p_cut.unwrap().p_cut.len() + 1);
+    }
 }
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
new file mode 100644
index 000000000..68490701c
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/complexity.rs
@@ -0,0 +1,77 @@
+use std::fmt;
+
+use crate::utils::f64::f64_dot;
+
+use super::operations_value::OperationsValue;
+
+#[derive(Clone, Debug)]
+pub struct OperationsCount {
+    pub counts: OperationsValue,
+}
+
+#[derive(Clone, Debug)]
+pub struct OperationsCost {
+    pub costs: OperationsValue,
+}
+
+#[derive(Clone, Debug)]
+pub struct Complexity {
+    pub counts: OperationsValue,
+}
+
+impl fmt::Display for OperationsCount {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        let mut add_plus = "";
+        let counts = &self.counts;
+        let nb_partitions = counts.nb_partitions();
+        let index = counts.index;
+        for src_partition in 0..nb_partitions {
+            for dst_partition in 0..nb_partitions {
+                let coeff = counts.values[index.keyswitch_to_small(src_partition, dst_partition)];
+                if coeff != 0.0 {
+                    if src_partition == dst_partition {
+                        write!(f, "{add_plus}{coeff}¢K[{src_partition}]")?;
+                    } else {
+                        write!(f, "{add_plus}{coeff}¢K[{src_partition}→{dst_partition}]")?;
+                    }
+                    add_plus = " + ";
+                }
+            }
+        }
+        for src_partition in 0..nb_partitions {
+            assert!(counts.values[index.input(src_partition)] == 0.0);
+            let coeff = counts.values[index.pbs(src_partition)];
+            if coeff != 0.0 {
+                write!(f, "{add_plus}{coeff}¢Br[{src_partition}]")?;
+                add_plus = " + ";
+            }
+            for dst_partition in 0..nb_partitions {
+                let coeff = counts.values[index.keyswitch_to_big(src_partition, dst_partition)];
+                if coeff != 0.0 {
+                    write!(f, "{add_plus}{coeff}¢FK[{src_partition}→{dst_partition}]")?;
+                    add_plus = " + ";
+                }
+            }
+        }
+
+        for partition in 0..nb_partitions {
+            assert!(counts.values[index.modulus_switching(partition)] == 0.0);
+        }
+        if add_plus.is_empty() {
+            write!(f, "ZERO x ¢")?;
+        }
+        Ok(())
+    }
+}
+
+impl Complexity {
+    pub fn of(counts: &OperationsCount) -> Self {
+        Self {
+            counts: counts.counts.clone(),
+        }
+    }
+
+    pub fn complexity(&self, costs: &OperationsValue) -> f64 {
+        f64_dot(&self.counts, costs)
+    }
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/fast_keyswitch.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/fast_keyswitch.rs
new file mode 100644
index 000000000..6cc923643
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/fast_keyswitch.rs
@@ -0,0 +1,87 @@
+// TODO: move to cache with pareto check
+
+use concrete_cpu_noise_model::gaussian_noise::conversion::modular_variance_to_variance;
+
+// TODO: move to concrete-cpu
+use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
+use crate::parameters::{GlweParameters, KsDecompositionParameters};
+
+use serde::{Deserialize, Serialize};
+
+// output glwe is incorporated in KsComplexityNoise
+#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
+pub struct FksComplexityNoise {
+    // 1 -> 0
+    // k1⋅N1>k0⋅N0, k0⋅N0≥k1⋅N1
+    pub decomp: KsDecompositionParameters,
+    pub noise: f64,
+    pub complexity: f64,
+}
+
+// Copy & paste from concrete-cpu
+const FFT_SCALING_WEIGHT: f64 = -2.577_224_94;
+fn fft_noise_variance_external_product_glwe(
+    glwe_dimension: u64,
+    polynomial_size: u64,
+    log2_base: u64,
+    level: u64,
+    ciphertext_modulus_log: u32,
+) -> f64 {
+    // https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/noise_formulas/bootstrap.py#L25
+    let b = 2_f64.powi(log2_base as i32);
+    let l = level as f64;
+    let big_n = polynomial_size as f64;
+    let k = glwe_dimension;
+    assert!(k > 0, "k = {k}");
+
+    // 22 = 2 x 11, 11 = 64 - 53
+    let scale_margin = (1_u64 << 22) as f64;
+    let res =
+        f64::exp2(FFT_SCALING_WEIGHT) * scale_margin * l * b * b * big_n.powi(2) * (k as f64 + 1.);
+    modular_variance_to_variance(res, ciphertext_modulus_log)
+}
+
+#[allow(non_snake_case)]
+fn upper_k0(input_glwe: &GlweParameters, output_glwe: &GlweParameters) -> u64 {
+    let k1 = input_glwe.glwe_dimension;
+    let N1 = input_glwe.polynomial_size();
+    let k0 = output_glwe.glwe_dimension;
+    let N0 = output_glwe.polynomial_size();
+    assert!(k1 * N1 >= k0 * N0);
+    // candidate * N0 >= k1 * N1
+    let f_upper_k0 = (k1 * N1) as f64 / N0 as f64;
+    #[allow(clippy::cast_sign_loss)]
+    let upper_k0 = f_upper_k0.ceil() as u64;
+    upper_k0
+}
+
+#[allow(non_snake_case)]
+pub fn complexity(input_glwe: &GlweParameters, output_glwe: &GlweParameters, level: u64) -> f64 {
+    let k0 = output_glwe.glwe_dimension;
+    let N0 = output_glwe.polynomial_size();
+    let upper_k0 = upper_k0(input_glwe, output_glwe);
+    #[allow(clippy::cast_sign_loss)]
+    let log2_N0 = (N0 as f64).log2().ceil() as u64;
+    let size0 = (k0 + 1) * N0 * log2_N0;
+    let mul_count = size0 * upper_k0 * level;
+    let add_count = size0 * (upper_k0 * level - 1);
+    (add_count + mul_count) as f64
+}
+
+#[allow(non_snake_case)]
+pub fn noise(
+    ks: &KsComplexityNoise,
+    input_glwe: &GlweParameters,
+    output_glwe: &GlweParameters,
+) -> f64 {
+    let N0 = output_glwe.polynomial_size();
+    let upper_k0 = upper_k0(input_glwe, output_glwe);
+    ks.noise(input_glwe.sample_extract_lwe_dimension())
+        + fft_noise_variance_external_product_glwe(
+            upper_k0,
+            N0,
+            ks.decomp.log2_base,
+            ks.decomp.level,
+            64,
+        )
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
new file mode 100644
index 000000000..fe67a1426
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/feasible.rs
@@ -0,0 +1,122 @@
+use crate::noise_estimator::p_error::{combine_errors, repeat_p_error};
+use crate::optimization::dag::multi_parameters::variance_constraint::VarianceConstraint;
+use crate::optimization::dag::solo_key::analyze::p_error_from_relative_variance;
+use crate::utils::f64::f64_dot;
+
+use super::operations_value::OperationsValue;
+use super::partitions::PartitionIndex;
+
+pub struct Feasible {
+    // TODO: move kappa here
+    pub constraints: Vec<VarianceConstraint>,
+    pub undominated_constraints: Vec<VarianceConstraint>,
+    pub kappa: f64, // to convert variance to local probabilities
+    pub global_p_error: Option<f64>,
+}
+
+impl Feasible {
+    pub fn of(constraints: &[VarianceConstraint], kappa: f64, global_p_error: Option<f64>) -> Self {
+        let undominated_constraints = VarianceConstraint::remove_dominated(constraints);
+        Self {
+            kappa,
+            constraints: constraints.into(),
+            undominated_constraints,
+            global_p_error,
+        }
+    }
+
+    pub fn feasible(&self, operations_variance: &OperationsValue) -> bool {
+        if self.global_p_error.is_none() {
+            self.local_feasible(operations_variance)
+        } else {
+            self.global_feasible(operations_variance)
+        }
+    }
+
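+    // A candidate is locally feasible when, for every undominated constraint,
+    // the dot product of the operation variances with the constraint's
+    // coefficients stays under that constraint's safe variance bound.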
+    fn local_feasible(&self, operations_variance: &OperationsValue) -> bool {
+        for constraint in &self.undominated_constraints {
+            if f64_dot(operations_variance, &constraint.variance.coeffs)
+                > constraint.safe_variance_bound
+            {
+                return false;
+            };
+        }
+        true
+    }
+
+    fn global_feasible(&self, operations_variance: &OperationsValue) -> bool {
+        self.global_p_error_with_cut(operations_variance, self.global_p_error.unwrap_or(1.0))
+            .is_some()
+    }
+
+    pub fn worst_constraint(
+        &self,
+        operations_variance: &OperationsValue,
+    ) -> (f64, f64, &VarianceConstraint) {
+        let mut worst_constraint = &self.undominated_constraints[0];
+        let mut worst_relative_variance = 0.0;
+        let mut worst_variance = 0.0;
+        for constraint in &self.undominated_constraints {
+            let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
+            let relative_variance = variance / constraint.safe_variance_bound;
+            if relative_variance > worst_relative_variance {
+                worst_relative_variance = relative_variance;
+                worst_variance = variance;
+                worst_constraint = constraint;
+            }
+        }
+        (worst_variance, worst_relative_variance, worst_constraint)
+    }
+
+    pub fn p_error(&self, operations_variance: &OperationsValue) -> f64 {
+        let (_, relative_variance, _) = self.worst_constraint(operations_variance);
+        p_error_from_relative_variance(relative_variance, self.kappa)
+    }
+
+    fn global_p_error_with_cut(
+        &self,
+        operations_variance: &OperationsValue,
+        cut: f64,
+    ) -> Option<f64> {
+        let mut global_p_error = 0.0;
+        for constraint in &self.constraints {
+            let variance = f64_dot(operations_variance, &constraint.variance.coeffs);
+            let relative_variance = variance / constraint.safe_variance_bound;
+            let p_error = p_error_from_relative_variance(relative_variance, self.kappa);
+            global_p_error = combine_errors(
+                global_p_error,
+                repeat_p_error(p_error, constraint.nb_constraints),
+            );
+            if global_p_error > cut {
+                return None;
+            }
+        }
+        Some(global_p_error)
+    }
+
+    pub fn global_p_error(&self, operations_variance: &OperationsValue) -> f64 {
+        self.global_p_error_with_cut(operations_variance, 1.0)
+            .unwrap_or(1.0)
+    }
+
+    pub fn filter_constraints(&self, partition: PartitionIndex) -> Self {
+        let nb_partitions = self.constraints[0].variance.nb_partitions();
+        let touch_any_ks = |constraint: &VarianceConstraint, i| {
+            let variance = &constraint.variance;
+            variance.coeff_keyswitch_to_small(partition, i) > 0.0
+                || variance.coeff_keyswitch_to_small(i, partition) > 0.0
+                || variance.coeff_partition_keyswitch_to_big(partition, i) > 0.0
+                || variance.coeff_partition_keyswitch_to_big(i, partition) > 0.0
+        };
+        let partition_constraints: Vec<_> = self
+            .constraints
+            .iter()
+            .filter(|constraint| {
+                constraint.partition == partition
+                    || (0..nb_partitions).any(|i| touch_any_ks(constraint, i))
+            })
+            .map(VarianceConstraint::clone)
+            .collect();
+        Self::of(&partition_constraints, self.kappa, self.global_p_error)
+    }
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
index 35f23044c..3d3ca3723 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs
@@ -1,9 +1,13 @@
-pub mod analyze;
+mod analyze;
+mod complexity;
+mod fast_keyswitch;
+mod feasible;
 pub mod keys_spec;
-pub(crate) mod operations_value;
-pub(crate) mod partitionning;
-pub(crate) mod partitions;
-pub(crate) mod precision_cut;
-pub(crate) mod symbolic_variance;
-pub(crate) mod union_find;
-pub(crate) mod variance_constraint;
+mod operations_value;
+pub mod optimize;
+mod partitionning;
+mod partitions;
+mod precision_cut;
+mod symbolic_variance;
+mod union_find;
+mod variance_constraint;
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs
new file mode 100644
index 000000000..f7f14dbe6
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/optimize.rs
@@ -0,0 +1,1007 @@
+// OPT: cache for fks and verified pareto
+use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate_modulus_switching_noise_with_binary_key;
+
+use crate::dag::unparametrized;
+use crate::noise_estimator::error;
+use crate::optimization::config::{Config, NoiseBoundConfig, SearchSpace};
+use crate::optimization::dag::multi_parameters::analyze::{analyze, AnalyzedDag};
+use crate::optimization::dag::multi_parameters::fast_keyswitch;
+use crate::optimization::dag::multi_parameters::fast_keyswitch::FksComplexityNoise;
+use crate::optimization::dag::multi_parameters::operations_value::OperationsValue;
+use crate::optimization::decomposition::cmux::CmuxComplexityNoise;
+use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
+use crate::optimization::decomposition::{cmux, keyswitch, DecompCaches, PersistDecompCaches};
+use crate::parameters::GlweParameters;
+
+use crate::optimization::dag::multi_parameters::complexity::Complexity;
+use crate::optimization::dag::multi_parameters::feasible::Feasible;
+use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
+use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
+
+const DEBUG: bool = false;
+
+#[derive(Debug, Clone)]
+pub struct MicroParameters {
+    pbs: Vec<Option<CmuxComplexityNoise>>,
+    ks: Vec<Vec<Option<KsComplexityNoise>>>,
+    fks: Vec<Vec<Option<FksComplexityNoise>>>,
+}
+
+// Parameters optimized for 1 partition:
+// the partition pbs, all used ks for all partitions, and as many fks as partitions
+pub struct PartialMicroParameters {
+    pbs: CmuxComplexityNoise,
+    ks: Vec<Vec<Option<KsComplexityNoise>>>,
+    fks: Vec<Vec<Option<FksComplexityNoise>>>,
+    p_error: f64,
+    global_p_error: f64,
+    complexity: f64,
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq)]
+pub struct MacroParameters {
+    glwe_params: GlweParameters,
+    internal_dim: u64,
+}
+
+#[derive(Debug, Clone)]
+pub struct Parameters {
+    micro_params: MicroParameters,
+    macro_params: Vec<Option<MacroParameters>>,
+    is_lower_bound: bool,
+    is_feasible: bool,
+    p_error: f64,
+    global_p_error: f64,
+    complexity: f64,
+}
+
+#[derive(Debug, Clone)]
+struct OperationsCV {
+    variance: OperationsValue,
+    cost: OperationsValue,
+}
+
+type KsSrc = usize;
+type KsDst = usize;
+type FksSrc = usize;
+
+#[inline(never)]
+fn optimize_1_ks(
+    ks_src: KsSrc,
+    ks_dst: KsDst,
+    ks_input_lwe_dim: u64,
+    ks_pareto: &[KsComplexityNoise],
+    operations: &mut OperationsCV,
+    feasible: &Feasible,
+    complexity: &Complexity,
+    cut_complexity: f64,
+) -> Option<KsComplexityNoise> {
+    // find the first feasible (and least complex) candidate
+    for &ks_quantity in ks_pareto {
+        // variance is decreasing, complexity is increasing
+        *operations.variance.ks(ks_src, ks_dst) = ks_quantity.noise(ks_input_lwe_dim);
+        *operations.cost.ks(ks_src, ks_dst) = ks_quantity.complexity(ks_input_lwe_dim);
+        if complexity.complexity(&operations.cost) > cut_complexity {
+            return None;
+        }
+        if feasible.feasible(&operations.variance) {
+            return Some(ks_quantity);
+        }
+    }
+    None
+}
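+
+// Since a ks pareto front is ordered with decreasing noise and increasing
+// complexity, the first feasible entry found by optimize_1_ks is also the
+// cheapest feasible one, so the scan can stop there.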
+fn optimize_many_independant_ks(
+    macro_parameters: &[MacroParameters],
+    ks_src: KsSrc,
+    ks_input_lwe_dim: u64,
+    ks_used: &[Vec<bool>],
+    operations: &OperationsCV,
+    feasible: &Feasible,
+    complexity: &Complexity,
+    caches: &mut keyswitch::Cache,
+    cut_complexity: f64,
+) -> Option<(Vec<(KsDst, KsComplexityNoise)>, OperationsCV)> {
+    // All ks are independent since they appear in mutually exclusive variance
+    // constraints: only one ks can appear in a given variance constraint.
+    // We can obtain the best feasible solution by optimizing them separately,
+    // since everything else has already been chosen.
+    // At this point feasibility and minimal complexity have already been
+    // checked on lower bounds: we know there is a feasible solution and a
+    // better-complexity solution; we only need to check whether both
+    // properties can hold at the same time.
+    debug_assert!(feasible.feasible(&operations.variance));
+    debug_assert!(complexity.complexity(&operations.cost) <= cut_complexity);
+    let mut operations = operations.clone();
+    let mut ks_bests = vec![];
+    ks_bests.reserve(macro_parameters.len());
+    for (ks_dst, macro_dst) in macro_parameters.iter().enumerate() {
+        if !ks_used[ks_src][ks_dst] {
+            continue;
+        }
+        let output_dim = macro_dst.internal_dim;
+        let ks_pareto = caches.pareto_quantities(output_dim);
+        let ks_best = optimize_1_ks(
+            ks_src,
+            ks_dst,
+            ks_input_lwe_dim,
+            ks_pareto,
+            &mut operations,
+            feasible,
+            complexity,
+            cut_complexity,
+        )?; // abort if feasible but not with the right complexity
+        ks_bests.push((ks_dst, ks_best));
+    }
+    Some((ks_bests, operations))
+}
+
+struct Best1FksAndManyKs {
+    fks: Option<(FksSrc, FksComplexityNoise)>,
+    many_ks: Vec<(KsDst, KsComplexityNoise)>,
+}
+
+#[allow(clippy::type_complexity)]
+fn optimize_1_fks_and_all_compatible_ks(
+    macro_parameters: &[MacroParameters],
+    ks_used: &[Vec<bool>],
+    fks_src: usize,
+    fks_dst: usize,
+    operations: &OperationsCV,
+    feasible: &Feasible,
+    complexity: &Complexity,
+    caches: &mut keyswitch::Cache,
+    cut_complexity: f64,
+) -> Option<(Best1FksAndManyKs, OperationsCV)> {
+    // At this point everything else is known apart from the fks and ks
+    let input_glwe = macro_parameters[fks_src].glwe_params;
+    let output_glwe = macro_parameters[fks_dst].glwe_params;
+    let output_lwe_dim = output_glwe.sample_extract_lwe_dimension();
+    // OPT: have a separate cache for fks
+    let ks_pareto = caches.pareto_quantities(output_lwe_dim).to_owned();
+    // TODO: fast ks in the other direction as well
+    let use_fast_ks = REAL_FAST_KS && input_glwe.sample_extract_lwe_dimension() >= output_lwe_dim;
+    let ks_src = fks_dst;
+    let ks_input_dim = macro_parameters[fks_dst]
+        .glwe_params
+        .sample_extract_lwe_dimension();
+    let mut operations = operations.clone();
+    let mut best_sol = None;
+    let same_dim = input_glwe == output_glwe;
+
+    for &ks_quantity in &ks_pareto {
+        // OPT: add a pareto cache for fks
+        let fks_quantity = if same_dim {
+            FksComplexityNoise {
+                decomp: ks_quantity.decomp,
+                noise: 0.0,
+                complexity: 0.0,
+            }
+        } else if use_fast_ks {
+            let noise = fast_keyswitch::noise(&ks_quantity, &input_glwe, &output_glwe);
+            let complexity =
+                fast_keyswitch::complexity(&input_glwe, &output_glwe, ks_quantity.decomp.level);
+            FksComplexityNoise {
+                decomp: ks_quantity.decomp,
+                noise,
+                complexity,
+            }
+        } else {
+            let noise = ks_quantity.noise(input_glwe.sample_extract_lwe_dimension());
+            let complexity = ks_quantity.complexity(input_glwe.sample_extract_lwe_dimension());
+            FksComplexityNoise {
+                decomp: ks_quantity.decomp,
+                noise,
+                complexity,
+            }
+        };
+        *operations.cost.fks(fks_src, fks_dst) = fks_quantity.complexity;
+        *operations.variance.fks(fks_src, fks_dst) = fks_quantity.noise;
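+
+        // The candidate fks is now priced into the cost/variance vectors, so
+        // the cheap global bound checks below can prune before running the
+        // more expensive per-destination ks optimization.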
+        if complexity.complexity(&operations.cost) > cut_complexity {
+            // complexity is strictly increasing by level
+            // next complexity will be worse
+            return None;
+        }
+        if !feasible.feasible(&operations.variance) {
+            continue;
+        }
+        let sol = optimize_many_independant_ks(
+            macro_parameters,
+            ks_src,
+            ks_input_dim,
+            ks_used,
+            &operations,
+            feasible,
+            complexity,
+            caches,
+            cut_complexity,
+        );
+        if sol.is_none() {
+            continue;
+        }
+        let (best_many_ks, operations) = sol.unwrap();
+        let cost = complexity.complexity(&operations.cost);
+        if cost > cut_complexity {
+            continue;
+        }
+        // COULD: handle complexity tie
+        let bests = Best1FksAndManyKs {
+            fks: Some((fks_src, fks_quantity)),
+            many_ks: best_many_ks,
+        };
+        best_sol = Some((bests, operations));
+        if same_dim {
+            break;
+        }
+    }
+    best_sol
+}
+
+fn optimize_dst_exclusive_fks_subset_and_all_ks(
+    macro_parameters: &[MacroParameters],
+    fks_paretos: &[Option<FksSrc>],
+    ks_used: &[Vec<bool>],
+    operations: &OperationsCV,
+    feasible: &Feasible,
+    complexity: &Complexity,
+    caches: &mut keyswitch::Cache,
+    cut_complexity: f64,
+) -> Option<(Vec<Best1FksAndManyKs>, OperationsCV)> {
+    // All fks subgroups can be optimized independently
+    let mut acc_operations = operations.clone();
+    let mut result = vec![];
+    result.reserve_exact(fks_paretos.len());
+    for (fks_dst, maybe_fks_pareto) in fks_paretos.iter().enumerate() {
+        let ks_src = fks_dst;
+        let ks_input_lwe_dim = macro_parameters[fks_dst]
+            .glwe_params
+            .sample_extract_lwe_dimension();
+        if let Some(fks_src) = maybe_fks_pareto {
+            let (bests, operations) = optimize_1_fks_and_all_compatible_ks(
+                macro_parameters,
+                ks_used,
+                *fks_src,
+                fks_dst,
+                &acc_operations,
+                feasible,
+                complexity,
+                caches,
+                cut_complexity,
+            )?;
+            result.push(bests);
+            let _ = std::mem::replace(&mut acc_operations, operations);
+        } else {
+            // There is no fks to optimize
+            let (many_ks, operations) = optimize_many_independant_ks(
+                macro_parameters,
+                ks_src,
+                ks_input_lwe_dim,
+                ks_used,
+                &acc_operations,
+                feasible,
+                complexity,
+                caches,
+                cut_complexity,
+            )?;
+            result.push(Best1FksAndManyKs { fks: None, many_ks });
+            let _ = std::mem::replace(&mut acc_operations, operations);
+        }
+    }
+    Some((result, acc_operations))
+}
+
+fn optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks(
+    partition: PartitionIndex,
+    macro_parameters: &[MacroParameters],
+    internal_dim: u64,
+    cmux_pareto: &[CmuxComplexityNoise],
+    fks_paretos: &[Option<FksSrc>],
+    ks_used: &[Vec<bool>],
+    operations: &OperationsCV,
+    feasible: &Feasible,
+    complexity: &Complexity,
+    caches: &mut keyswitch::Cache,
+    cut_complexity: f64,
+    best_p_error: f64,
+) -> Option<PartialMicroParameters> {
+    let mut operations = operations.clone();
+    let mut best_sol = None;
+    let mut best_sol_complexity = cut_complexity;
+    let mut best_sol_p_error = best_p_error;
+    let mut best_sol_global_p_error = 1.0;
+    for &cmux_quantity in cmux_pareto {
+        // increasing complexity, decreasing variance
+        let pbs_cost = cmux_quantity.complexity_br(internal_dim);
+        *operations.cost.pbs(partition) = pbs_cost;
+        // Lower bound cuts
+        let lower_cost = complexity.complexity(&operations.cost);
+        if lower_cost > best_sol_complexity {
+            continue;
+        }
+        let pbs_variance = cmux_quantity.noise_br(internal_dim);
+        *operations.variance.pbs(partition) = pbs_variance;
+        if !feasible.feasible(&operations.variance) {
+            continue;
+        }
+        let sol = optimize_dst_exclusive_fks_subset_and_all_ks(
+            macro_parameters,
+            fks_paretos,
+            ks_used,
+            &operations,
+            feasible,
+            complexity,
+            caches,
+            best_sol_complexity,
+        );
+        if sol.is_none() {
+            continue;
+        }
+
+        let (best_fks_ks, operations) = sol.unwrap();
+        let cost = complexity.complexity(&operations.cost);
+        if cost > best_sol_complexity {
+            continue;
+        };
+        let p_error = feasible.p_error(&operations.variance);
+        #[allow(clippy::float_cmp)]
+        if cost == best_sol_complexity && p_error >= best_sol_p_error {
+            continue;
+        }
+        best_sol_complexity = cost;
+        best_sol_p_error = p_error;
+        best_sol_global_p_error = feasible.global_p_error(&operations.variance);
+        best_sol = Some((cmux_quantity, best_fks_ks));
+    }
+    if best_sol.is_none() {
+        return None;
+    }
+    let nb_partitions = macro_parameters.len();
+    let (cmux_quantity, best_fks_ks) = best_sol.unwrap();
+    let mut fks = vec![vec![None; nb_partitions]; nb_partitions];
+    let mut ks = vec![vec![None; nb_partitions]; nb_partitions];
+    for (fks_dst, one_best_fks_ks) in best_fks_ks.iter().enumerate() {
+        if let Some((fks_src, sol_fks)) = one_best_fks_ks.fks {
+            fks[fks_src][fks_dst] = Some(sol_fks);
+        }
+        for (ks_dst, sol_ks) in &one_best_fks_ks.many_ks {
+            ks[fks_dst][*ks_dst] = Some(*sol_ks);
+        }
+    }
+    Some(PartialMicroParameters {
+        pbs: cmux_quantity,
+        fks,
+        ks,
+        p_error: best_sol_p_error,
+        global_p_error: best_sol_global_p_error,
+        complexity: best_sol_complexity,
+    })
+}
+
+fn apply_all_ks_lower_bound(
+    caches: &mut keyswitch::Cache,
+    nb_partitions: usize,
+    macro_parameters: &[MacroParameters],
+    used_tlu_keyswitch: &[Vec<bool>],
+    operations: &mut OperationsCV,
+) {
+    for (src, dst) in cross_partition(nb_partitions) {
+        if !used_tlu_keyswitch[src][dst] {
+            continue;
+        }
+        let in_glwe_params = macro_parameters[src].glwe_params;
+        let out_internal_dim = macro_parameters[dst].internal_dim;
+        let ks_pareto = caches.pareto_quantities(out_internal_dim);
+        let in_lwe_dim = in_glwe_params.sample_extract_lwe_dimension();
+        *operations.variance.ks(src, dst) = keyswitch::lowest_noise_ks(ks_pareto, in_lwe_dim);
+        *operations.cost.ks(src, dst) = keyswitch::lowest_complexity_ks(ks_pareto, in_lwe_dim);
+    }
+}
+
+fn apply_all_fks_lower_bound(
+    caches: &mut keyswitch::Cache,
+    nb_partitions: usize,
+    macro_parameters: &[MacroParameters],
+    used_conversion_keyswitch: &[Vec<bool>],
+    operations: &mut OperationsCV,
+) {
+    for (src, dst) in cross_partition(nb_partitions) {
+        if !used_conversion_keyswitch[src][dst] {
+            continue;
+        }
+        let input_glwe = &macro_parameters[src].glwe_params;
+        let output_glwe = &macro_parameters[dst].glwe_params;
+        if input_glwe == output_glwe {
+            *operations.variance.fks(src, dst) = 0.0;
+            *operations.cost.fks(src, dst) = 0.0;
+            continue;
+        }
+        let ks_pareto = caches.pareto_quantities(output_glwe.sample_extract_lwe_dimension());
+        let use_fast_ks = REAL_FAST_KS
+            && input_glwe.sample_extract_lwe_dimension()
+                >= output_glwe.sample_extract_lwe_dimension();
+        let cost = if use_fast_ks {
+            fast_keyswitch::complexity(input_glwe, output_glwe, ks_pareto[0].decomp.level)
+        } else {
+            keyswitch::lowest_complexity_ks(ks_pareto, input_glwe.sample_extract_lwe_dimension())
+        };
+        *operations.cost.fks(src, dst) = cost;
+        let mut variance_min = f64::INFINITY;
+        // TODO: use a pareto front to avoid that loop
+        if use_fast_ks {
+            for ks_q in ks_pareto {
+                let variance = fast_keyswitch::noise(ks_q, input_glwe, output_glwe);
+                variance_min = variance_min.min(variance);
+            }
+        } else {
+            variance_min =
+                keyswitch::lowest_noise_ks(ks_pareto, input_glwe.sample_extract_lwe_dimension());
+        }
+        *operations.variance.fks(src, dst) = variance_min;
+    }
+}
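+
+// The two lower-bound helpers above install optimistic per-operation minima:
+// any concrete parameter choice can only be noisier or costlier, so these
+// bounds are safe for pruning but are never returned as a final solution.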
+fn apply_partitions_input_and_modulus_variance_and_cost(
+    ciphertext_modulus_log: u32,
+    security_level: u64,
+    nb_partitions: usize,
+    macro_parameters: &[MacroParameters],
+    partition: PartitionIndex,
+    input_variance: f64,
+    variance_modulus_switching: f64,
+    operations: &mut OperationsCV,
+) {
+    for i in 0..nb_partitions {
+        let (input_variance, variance_modulus_switching) =
+            if macro_parameters[i] == macro_parameters[partition] {
+                (input_variance, variance_modulus_switching)
+            } else {
+                let input_variance = macro_parameters[i]
+                    .glwe_params
+                    .minimal_variance(ciphertext_modulus_log, security_level);
+                let variance_modulus_switching = estimate_modulus_switching_noise_with_binary_key(
+                    macro_parameters[i].internal_dim,
+                    macro_parameters[i].glwe_params.log2_polynomial_size,
+                    ciphertext_modulus_log,
+                );
+                (input_variance, variance_modulus_switching)
+            };
+        *operations.variance.input(i) = input_variance;
+        *operations.variance.modulus_switching(i) = variance_modulus_switching;
+    }
+}
+
+fn apply_pbs_variance_and_cost_or_lower_bounds(
+    caches: &mut cmux::Cache,
+    macro_parameters: &[MacroParameters],
+    initial_pbs: &[Option<CmuxComplexityNoise>],
+    partition: PartitionIndex,
+    operations: &mut OperationsCV,
+) {
+    // setting already chosen pbs and lower bounds
+    for (i, pbs) in initial_pbs.iter().enumerate() {
+        let pbs = if i == partition { &None } else { pbs };
+        if let Some(pbs) = pbs {
+            let internal_dim = macro_parameters[i].internal_dim;
+            *operations.variance.pbs(i) = pbs.noise_br(internal_dim);
+            *operations.cost.pbs(i) = pbs.complexity_br(internal_dim);
+        } else {
+            // OPT: Most values could be shared on first optimize_macro
+            let in_internal_dim = macro_parameters[i].internal_dim;
+            let out_glwe_params = macro_parameters[i].glwe_params;
+            let variance_min =
+                cmux::lowest_noise_br(caches.pareto_quantities(out_glwe_params), in_internal_dim);
+            *operations.variance.pbs(i) = variance_min;
+            *operations.cost.pbs(i) = 0.0;
+        }
+    }
+}
+
+fn fks_to_optimize(
+    nb_partitions: usize,
+    used_conversion_keyswitch: &[Vec<bool>],
+    optimized_partition: PartitionIndex,
+) -> Vec<Option<FksSrc>> {
+    // Prepare the subset of fks to optimize: real, lower bound, or unused (fake).
+    // We only take one fks[_->dst] per destination partition dst, since
+    // different destinations can be optimized independently: they appear only
+    // in constraints together with ks[fks_dst->_].
+    // When an fks is unused, a None keeps the same loop structure.
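+    // Example: with 2 partitions where only fks[1->0] is used, optimizing
+    // partition 0 selects Some(1) for dst 0 through the count_used fallback
+    // and None for dst 1, so every used conversion keyswitch is still
+    // revisited on each optimization round.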
+    let mut fks_paretos: Vec<Option<FksSrc>> = vec![];
+    fks_paretos.reserve_exact(nb_partitions);
+    for fks_dst in 0..nb_partitions {
+        // find the optimized_partition-th valid fks_src
+        let fks_src = if used_conversion_keyswitch[optimized_partition][fks_dst] {
+            Some(optimized_partition)
+        } else {
+            let mut count_used: usize = 0;
+            let mut fks_src = None;
+            #[allow(clippy::needless_range_loop)]
+            for src in 0..nb_partitions {
+                let used = used_conversion_keyswitch[src][fks_dst];
+                if used && count_used == optimized_partition {
+                    fks_src = Some(src);
+                    break;
+                }
+                count_used += used as usize;
+            }
+            if fks_src.is_none() && count_used > 0 {
+                let n_th = optimized_partition % count_used;
+                count_used = 0;
+                #[allow(clippy::needless_range_loop)]
+                for src in 0..nb_partitions {
+                    let used = used_conversion_keyswitch[src][fks_dst];
+                    if used && count_used == n_th {
+                        fks_src = Some(src);
+                        break;
+                    }
+                    count_used += used as usize;
+                }
+            }
+            fks_src
+        };
+        fks_paretos.push(fks_src);
+    }
+    fks_paretos
+}
+
+// Set to false in case fast ks are not used
+const REAL_FAST_KS: bool = true;
+
+#[allow(clippy::too_many_lines)]
+fn optimize_macro(
+    security_level: u64,
+    ciphertext_modulus_log: u32,
+    search_space: &SearchSpace,
+    partition: PartitionIndex,
+    used_tlu_keyswitch: &[Vec<bool>],
+    used_conversion_keyswitch: &[Vec<bool>],
+    feasible: &Feasible,
+    complexity: &Complexity,
+    caches: &mut DecompCaches,
+    init_parameters: &Parameters,
+    best_complexity: f64,
+    best_p_error: f64,
+) -> Parameters {
+    let nb_partitions = init_parameters.macro_params.len();
+    assert!(partition < nb_partitions);
+
+    let variance_modulus_switching_of = |glwe_log2_poly_size, internal_lwe_dimensions| {
+        estimate_modulus_switching_noise_with_binary_key(
+            internal_lwe_dimensions,
+            glwe_log2_poly_size,
+            ciphertext_modulus_log,
+        )
+    };
+
+    let mut best_parameters = init_parameters.clone();
+    let mut best_complexity = best_complexity;
+    let mut best_p_error = best_p_error;
+    let mut best_partition_p_error = f64::INFINITY;
+
+    let fks_to_optimize = fks_to_optimize(nb_partitions, used_conversion_keyswitch, partition);
+    let operations = OperationsCV {
+        variance: OperationsValue::zero(nb_partitions),
+        cost: OperationsValue::zero(nb_partitions),
+    };
+    let partition_feasible = feasible.filter_constraints(partition);
+
+    let glwe_params_domain = search_space.glwe_dimensions.iter().flat_map(|a| {
+        search_space
+            .glwe_log_polynomial_sizes
+            .iter()
+            .map(|b| (*a, *b))
+    });
+    for (glwe_dimension, log2_polynomial_size) in glwe_params_domain {
+        let glwe_params = GlweParameters {
+            log2_polynomial_size,
+            glwe_dimension,
+        };
+
+        let input_variance = glwe_params.minimal_variance(ciphertext_modulus_log, security_level);
+        if glwe_dimension == 1 && log2_polynomial_size == 8 {
+            // this is insecure, so the minimal variance will be above 1
+            assert!(input_variance > 1.0);
+            continue;
+        }
+
+        for &internal_dim in &search_space.internal_lwe_dimensions {
+            let mut operations = operations.clone();
+            // OPT: fast linear noise_modulus_switching
+            let variance_modulus_switching =
+                variance_modulus_switching_of(log2_polynomial_size, internal_dim);
+
+            let macro_param_partition = MacroParameters {
+                glwe_params,
+                internal_dim,
+            };
+
+            // Heuristic to fill missing macro parameters
+            let macros: Vec<_> = (0..nb_partitions)
+                .map(|i| {
+                    if i == partition {
+                        macro_param_partition
+                    } else {
+                        init_parameters.macro_params[i].unwrap_or(macro_param_partition)
+                    }
+                })
+                .collect();
+
+            // OPT: could be done once and then partially updated
+            apply_partitions_input_and_modulus_variance_and_cost(
+                ciphertext_modulus_log,
+                security_level,
+                nb_partitions,
+                &macros,
+                partition,
+                input_variance,
+                variance_modulus_switching,
+                &mut operations,
+            );
+
+            if best_parameters.is_feasible && !feasible.feasible(&operations.variance) {
+                // noise_modulus_switching is increasing with internal_dim so we can cut,
+                // but as long as nothing feasible has been found we don't break,
+                // to keep improving feasibility
+                break;
+            }
+
+            if complexity.complexity(&operations.cost) > best_complexity {
+                continue;
+            }
+
+            // setting already chosen pbs and lower bounds
+            // OPT: could be done once and then partially updated
+            apply_pbs_variance_and_cost_or_lower_bounds(
+                &mut caches.cmux,
+                &macros,
+                &init_parameters.micro_params.pbs,
+                partition,
+                &mut operations,
+            );
+
+            // OPT: could be done once and then partially updated
+            apply_all_ks_lower_bound(
+                &mut caches.keyswitch,
+                nb_partitions,
+                &macros,
+                used_tlu_keyswitch,
+                &mut operations,
+            );
+            // OPT: could be done once and then partially updated
+            apply_all_fks_lower_bound(
+                &mut caches.keyswitch,
+                nb_partitions,
+                &macros,
+                used_conversion_keyswitch,
+                &mut operations,
+            );
+
+            let non_feasible = !feasible.feasible(&operations.variance);
+            if best_parameters.is_feasible && non_feasible {
+                continue;
+            }
+
+            if complexity.complexity(&operations.cost) > best_complexity {
+                continue;
+            }
+
+            let cmux_pareto = caches.cmux.pareto_quantities(glwe_params);
+
+            if non_feasible {
+                // here we optimize for feasibility only:
+                // if nothing is feasible yet, this improves feasibility for later iterations
+                let mut macro_params = init_parameters.macro_params.clone();
+                macro_params[partition] = Some(MacroParameters {
+                    glwe_params,
+                    internal_dim,
+                });
+                // optimize feasibility only, taking all lower bounds on variance;
+                // this selects both macro parameters and pbs (lowest variance) for this partition
+                let complexity = f64::INFINITY;
+                let cmux_params = cmux::lowest_noise(cmux_pareto);
+                let partition_p_error = partition_feasible.p_error(&operations.variance);
+                if partition_p_error >= best_partition_p_error {
+                    continue;
+                }
+                best_partition_p_error = partition_p_error;
+                let p_error = feasible.p_error(&operations.variance);
+                let global_p_error = feasible.global_p_error(&operations.variance);
+                let mut pbs = init_parameters.micro_params.pbs.clone();
+                pbs[partition] = Some(cmux_params);
+                let micro_params = MicroParameters {
+                    pbs,
+                    ks: vec![vec![None; nb_partitions]; nb_partitions],
+                    fks: vec![vec![None; nb_partitions]; nb_partitions],
+                };
+                best_parameters = Parameters {
+                    p_error,
+                    global_p_error,
+                    complexity,
+                    micro_params,
+                    macro_params,
+                    is_lower_bound: true,
+                    is_feasible: false,
+                };
+                continue;
+            }
+
+            if complexity.complexity(&operations.cost) > best_complexity {
+                continue;
+            }
+
+            let micro_opt = optimize_1_cmux_and_dst_exclusive_fks_subset_and_all_ks(
+                partition,
+                &macros,
+                internal_dim,
+                cmux_pareto,
+                &fks_to_optimize,
+                used_tlu_keyswitch,
+                &operations,
+                feasible,
+                complexity,
+                &mut caches.keyswitch,
+                best_complexity,
+                best_p_error,
+            );
+            if let Some(some_micro_params) = micro_opt {
+                // erase macros and all fks and ks that can't be real;
+                // set the global is_lower_bound here: if any parameter is missing this is a lower bound
+                // optimize_micro has already checked for best-ness
+                let mut macro_params = init_parameters.macro_params.clone();
+                macro_params[partition] = Some(macro_param_partition);
+                let is_lower_bound = macro_params.iter().any(Option::is_none);
+                // copy back the pbs from the other partitions
+                let mut all_pbs = init_parameters.micro_params.pbs.clone();
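+                // Only this partition's pbs is re-decided here: the clone above
+                // keeps the pbs already fixed for every other partition.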
+                all_pbs[partition] = Some(some_micro_params.pbs);
+                let micro_params = MicroParameters {
+                    pbs: all_pbs,
+                    ks: some_micro_params.ks,
+                    fks: some_micro_params.fks,
+                };
+                // for (i, pbs) in init_parameters.micro_params.pbs.iter().enumerate() {
+                //     if i != partition {
+                //         micro_params.pbs[i] = *pbs;
+                //     }
+                // }
+                best_complexity = some_micro_params.complexity;
+                best_p_error = some_micro_params.p_error;
+                best_parameters = Parameters {
+                    p_error: best_p_error,
+                    global_p_error: some_micro_params.global_p_error,
+                    complexity: best_complexity,
+                    micro_params,
+                    macro_params,
+                    is_lower_bound,
+                    is_feasible: true,
+                };
+            } else {
+                // the macro parameters are feasible,
+                // but the complexity is not good enough due to a previous feasible solution
+                assert!(best_parameters.is_feasible);
+            }
+        }
+    }
+    best_parameters
+}
+
+fn cross_partition(nb_partitions: usize) -> impl Iterator<Item = (usize, usize)> {
+    (0..nb_partitions).flat_map(move |a: usize| (0..nb_partitions).map(move |b: usize| (a, b)))
+}
+
+#[allow(clippy::too_many_lines)]
+pub fn optimize(
+    dag: &unparametrized::OperationDag,
+    config: Config,
+    search_space: &SearchSpace,
+    persistent_caches: &PersistDecompCaches,
+    p_cut: &Option<PrecisionCut>,
+    default_partition: PartitionIndex,
+) -> Option<Parameters> {
+    let ciphertext_modulus_log = config.ciphertext_modulus_log;
+    let security_level = config.security_level;
+    let noise_config = NoiseBoundConfig {
+        security_level,
+        maximum_acceptable_error_probability: config.maximum_acceptable_error_probability,
+        ciphertext_modulus_log,
+    };
+
+    let dag = analyze(dag, &noise_config, p_cut, default_partition);
+
+    let kappa =
+        error::sigma_scale_of_error_probability(config.maximum_acceptable_error_probability);
+
+    let mut caches = persistent_caches.caches();
+
+    let feasible = Feasible::of(&dag.variance_constraints, kappa, None);
+    let complexity = Complexity::of(&dag.operations_count);
+    let used_tlu_keyswitch = used_tlu_keyswitch(&dag);
+    let used_conversion_keyswitch = used_conversion_keyswitch(&dag);
+
+    let nb_partitions = dag.nb_partitions;
+    let init_parameters = Parameters {
+        is_lower_bound: false,
+        is_feasible: false,
+        macro_params: vec![None; nb_partitions],
+        micro_params: MicroParameters {
+            pbs: vec![None; nb_partitions],
+            ks: vec![vec![None; nb_partitions]; nb_partitions],
+            fks: vec![vec![None; nb_partitions]; nb_partitions],
+        },
+        p_error: 1.0,
+        global_p_error: 1.0,
+        complexity: f64::INFINITY,
+    };
+
+    let mut params = init_parameters;
+    let mut best_complexity = f64::INFINITY;
+    let mut best_p_error = f64::INFINITY;
+
+    let mut fix_point = params.clone();
+    for iter in 0..=10 {
+        for partition in 0..nb_partitions {
+            let new_params = optimize_macro(
+                security_level,
+                ciphertext_modulus_log,
+                search_space,
+                partition,
+                &used_tlu_keyswitch,
+                &used_conversion_keyswitch,
+                &feasible,
+                &complexity,
+                &mut caches,
+                &params,
+                best_complexity,
+                best_p_error,
+            );
+            assert!(
+                new_params.is_feasible || !params.is_feasible,
+                "Cannot degrade feasibility"
+            );
+            params = new_params;
+            if !params.is_feasible {
+                if nb_partitions == 1 {
+                    return None;
+                }
+                if DEBUG {
+                    eprintln!(
+                        "Intermediate non feasible solution {iter} : {partition} : {}",
+                        params.p_error
+                    );
+                }
+                continue;
+            }
+            if DEBUG {
+                eprintln!(
+                    "Feasible solution {iter} : {partition} : {} {} {}",
+                    params.p_error, params.complexity, params.is_lower_bound
+                );
+            }
+            if !params.is_lower_bound {
+                best_complexity = params.complexity;
+                best_p_error = params.p_error;
+            }
+        }
+        if nb_partitions == 1 {
+            break;
+        }
+        // OPT: could be detected sooner
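+        // A full round over all partitions that changes neither complexity,
+        // p_error nor macro parameters would repeat the same choices forever,
+        // so it is treated as a fix point and the search stops.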
+        #[allow(clippy::float_cmp)]
+        if fix_point.complexity == params.complexity
+            && fix_point.p_error == params.p_error
+            && fix_point.macro_params == params.macro_params
+        {
+            if DEBUG {
+                eprintln!("Fix point reached at {iter}");
+            }
+            if !params.is_feasible {
+                eprintln!("{:?}", params.macro_params);
+                return None;
+            }
+            break;
+        }
+        fix_point = params.clone();
+    }
+    sanity_check(
+        &params,
+        ciphertext_modulus_log,
+        security_level,
+        &feasible,
+        &complexity,
+    );
+    Some(params)
+}
+
+fn used_tlu_keyswitch(dag: &AnalyzedDag) -> Vec<Vec<bool>> {
+    let mut result = vec![vec![false; dag.nb_partitions]; dag.nb_partitions];
+    for (src_partition, dst_partition) in cross_partition(dag.nb_partitions) {
+        for constraint in &dag.variance_constraints {
+            if constraint
+                .variance
+                .coeff_keyswitch_to_small(src_partition, dst_partition)
+                != 0.0
+            {
+                result[src_partition][dst_partition] = true;
+                break;
+            }
+        }
+    }
+    result
+}
+
+fn used_conversion_keyswitch(dag: &AnalyzedDag) -> Vec<Vec<bool>> {
+    let mut result = vec![vec![false; dag.nb_partitions]; dag.nb_partitions];
+    for (src_partition, dst_partition) in cross_partition(dag.nb_partitions) {
+        for constraint in &dag.variance_constraints {
+            if constraint
+                .variance
+                .coeff_partition_keyswitch_to_big(src_partition, dst_partition)
+                != 0.0
+            {
+                result[src_partition][dst_partition] = true;
+                break;
+            }
+        }
+    }
+    result
+}
+
+#[allow(clippy::float_cmp)]
+fn sanity_check(
+    params: &Parameters,
+    ciphertext_modulus_log: u32,
+    security_level: u64,
+    feasible: &Feasible,
+    complexity: &Complexity,
+) {
+    let nb_partitions = params.macro_params.len();
+    let mut operations = OperationsCV {
+        variance: OperationsValue::zero(nb_partitions),
+        cost: OperationsValue::zero(nb_partitions),
+    };
+    let micro_params = &params.micro_params;
+    for partition in 0..nb_partitions {
+        let partition_macro = params.macro_params[partition].unwrap();
+        let glwe_param = partition_macro.glwe_params;
+        let internal_dim = partition_macro.internal_dim;
+        let input_variance = glwe_param.minimal_variance(ciphertext_modulus_log, security_level);
+        let variance_modulus_switching = estimate_modulus_switching_noise_with_binary_key(
+            internal_dim,
+            glwe_param.log2_polynomial_size,
+            ciphertext_modulus_log,
+        );
+        *operations.variance.input(partition) = input_variance;
+        *operations.variance.modulus_switching(partition) = variance_modulus_switching;
+        if let Some(pbs) = micro_params.pbs[partition] {
+            *operations.variance.pbs(partition) = pbs.noise_br(internal_dim);
+            *operations.cost.pbs(partition) = pbs.complexity_br(internal_dim);
+        } else {
+            *operations.variance.pbs(partition) = f64::MAX;
+            *operations.cost.pbs(partition) = f64::MAX;
+        }
+        for src_partition in 0..nb_partitions {
+            let src_partition_macro = params.macro_params[src_partition].unwrap();
+            let src_glwe_param = src_partition_macro.glwe_params;
+            let src_lwe_dim = src_glwe_param.sample_extract_lwe_dimension();
+            if let Some(ks) = micro_params.ks[src_partition][partition] {
+                *operations.variance.ks(src_partition, partition) = ks.noise(src_lwe_dim);
+                *operations.cost.ks(src_partition, partition) = ks.complexity(src_lwe_dim);
+            } else {
+                *operations.variance.ks(src_partition, partition) = f64::MAX;
+                *operations.cost.ks(src_partition, partition) = f64::MAX;
+            }
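+            // Unused ks/fks slots are poisoned with f64::MAX so that any
+            // accidental use of them trips the assertions below.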
+            if let Some(fks) = micro_params.fks[src_partition][partition] {
+                *operations.variance.fks(src_partition, partition) = fks.noise;
+                *operations.cost.fks(src_partition, partition) = fks.complexity;
+            } else {
+                *operations.variance.fks(src_partition, partition) = f64::MAX;
+                *operations.cost.fks(src_partition, partition) = f64::MAX;
+            }
+        }
+    }
+    #[allow(clippy::float_cmp)]
+    {
+        assert!(feasible.feasible(&operations.variance));
+        assert!(params.p_error == feasible.p_error(&operations.variance));
+        assert!(params.complexity == complexity.complexity(&operations.cost));
+        assert!(params.global_p_error == feasible.global_p_error(&operations.variance));
+    }
+}
+
+#[cfg(test)]
+include!("tests/test_optimize.rs");
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
index 75aebaeb1..a856bc36d 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/precision_cut.rs
@@ -2,6 +2,7 @@ use crate::dag::operator::{Operator, Precision};
 use crate::dag::unparametrized;
 use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
 
+#[derive(Clone, Debug)]
 pub struct PrecisionCut {
     // partition0 precision <= p_cut[0] < partition 1 precision <= p_cut[1] ...
     // precision are in the sens of Lut input precision and are sorted
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs
new file mode 100644
index 000000000..31d63aae7
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/tests/test_optimize.rs
@@ -0,0 +1,538 @@
+// part of optimize.rs
+#[cfg(test)]
+mod tests {
+    #![allow(clippy::float_cmp)]
+
+    use once_cell::sync::Lazy;
+
+    use super::*;
+    use crate::computing_cost::cpu::CpuComplexity;
+    use crate::config;
+    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
+    use crate::dag::unparametrized;
+    use crate::optimization::dag::solo_key;
+    use crate::optimization::dag::solo_key::optimize::{add_v0_dag, v0_dag};
+    use crate::optimization::decomposition;
+
+    static SHARED_CACHES: Lazy<PersistDecompCaches> = Lazy::new(|| {
+        let processing_unit = config::ProcessingUnit::Cpu;
+        decomposition::cache(128, processing_unit, None, true)
+    });
+
+    const _4_SIGMA: f64 = 0.000_063_342_483_999_973;
+
+    const LOW_PARTITION: PartitionIndex = 0;
+
+    fn optimize(
+        dag: &unparametrized::OperationDag,
+        p_cut: &Option<PrecisionCut>,
+        default_partition: usize,
+    ) -> Option<Parameters> {
+        let config = Config {
+            security_level: 128,
+            maximum_acceptable_error_probability: _4_SIGMA,
+            ciphertext_modulus_log: 64,
+            complexity_model: &CpuComplexity::default(),
+        };
+
+        let search_space = SearchSpace::default_cpu();
+        super::optimize(
+            dag,
+            config,
+            &search_space,
+            &SHARED_CACHES,
+            p_cut,
+            default_partition,
+        )
+    }
+
+    fn optimize_single(dag: &unparametrized::OperationDag) -> Option<Parameters> {
+        optimize(dag, &Some(PrecisionCut { p_cut: vec![] }), LOW_PARTITION)
+    }
+
+    fn equiv_single(dag: &unparametrized::OperationDag) -> Option<bool> {
+        let sol_mono = solo_key::optimize::tests::optimize(dag);
+        let sol_multi = optimize_single(dag);
+        if sol_mono.best_solution.is_none() != sol_multi.is_none() {
+            eprintln!("Not same feasibility");
+            return Some(false);
+        };
+        if sol_multi.is_none() {
+            return None;
+        }
+        let equiv = sol_mono.best_solution.unwrap().complexity
+            == sol_multi.as_ref().unwrap().complexity;
+        if !equiv {
+            eprintln!("Not same complexity");
+            eprintln!("Single: {:?}", sol_mono.best_solution.unwrap());
+            eprintln!("Multi: {:?}", sol_multi.clone().unwrap().complexity);
+            eprintln!("Multi: {:?}", sol_multi.unwrap());
+        }
+        Some(equiv)
+    }
+
+    #[test]
+    fn optimize_simple_parameter_v0_dag() {
+        for precision in 1..11 {
+            for manp in 1..25 {
+                eprintln!("P M {precision} {manp}");
+                let dag = v0_dag(0, precision, manp as f64);
+                if let Some(equiv) = equiv_single(&dag) {
+                    assert!(equiv);
+                } else {
+                    break;
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn optimize_simple_parameter_rounded_lut_2_layers() {
+        for accumulator_precision in 1..11 {
+            for precision in 1..accumulator_precision {
+                for manp in [1, 8, 16] {
+                    eprintln!("CASE {accumulator_precision} {precision} {manp}");
+                    let dag = v0_dag(0, precision, manp as f64);
+                    if let Some(equiv) = equiv_single(&dag) {
+                        assert!(equiv);
+                    } else {
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    fn equiv_2_single(
+        dag_multi: &unparametrized::OperationDag,
+        dag_1: &unparametrized::OperationDag,
+        dag_2: &unparametrized::OperationDag,
+    ) -> Option<bool> {
+        let precision_max = dag_multi.out_precisions.iter().copied().max().unwrap();
+        let p_cut = Some(PrecisionCut {
+            p_cut: vec![precision_max - 1],
+        });
+        eprintln!("{dag_multi}");
+        let sol_single_1 = solo_key::optimize::tests::optimize(dag_1);
+        let sol_single_2 = solo_key::optimize::tests::optimize(dag_2);
+        let sol_multi = optimize(dag_multi, &p_cut, LOW_PARTITION);
+        let sol_multi_1 = optimize(dag_1, &p_cut, LOW_PARTITION);
+        let sol_multi_2 = optimize(dag_2, &p_cut, LOW_PARTITION);
+        let feasible_1 = sol_single_1.best_solution.is_some();
+        let feasible_2 = sol_single_2.best_solution.is_some();
+        let feasible_multi = sol_multi.is_some();
+        if (feasible_1 && feasible_2) != feasible_multi {
+            eprintln!("Not same feasibility {feasible_1} {feasible_2} {feasible_multi}");
+            return Some(false);
+        }
+        if sol_multi.is_none() {
+            return None;
+        }
+        let sol_multi = sol_multi.unwrap();
+        let sol_multi_1 = sol_multi_1.unwrap();
+        let sol_multi_2 = sol_multi_2.unwrap();
+        let cost_1 = sol_single_1.best_solution.unwrap().complexity;
+        let cost_2 = sol_single_2.best_solution.unwrap().complexity;
+        let cost_multi = sol_multi.complexity;
+        let equiv = cost_1 + cost_2 == cost_multi
+            && cost_1 == sol_multi_1.complexity
+            && cost_2 == sol_multi_2.complexity
+            && sol_multi.micro_params.ks[0][0].unwrap().decomp
+                == sol_multi_1.micro_params.ks[0][0].unwrap().decomp
+            && sol_multi.micro_params.ks[1][1].unwrap().decomp
+                == sol_multi_2.micro_params.ks[0][0].unwrap().decomp;
+        if !equiv {
+            eprintln!("Not same complexity");
+            eprintln!("Multi: {cost_multi:?}");
+            eprintln!("Added Single: {:?}", cost_1 + cost_2);
+            eprintln!("Single1: {:?}", sol_single_1.best_solution.unwrap());
+            eprintln!("Single2: {:?}", sol_single_2.best_solution.unwrap());
+            eprintln!("Multi: {sol_multi:?}");
+            eprintln!("Multi1: {sol_multi_1:?}");
+            eprintln!("Multi2: {sol_multi_2:?}");
+        }
+        Some(equiv)
+    }
+
+    #[test]
+    fn optimize_multi_independant_2_precisions() {
+        let sum_size = 0;
+        for precision1 in 1..11 {
+            for precision2 in (precision1 + 1)..11 {
+                for manp in [1, 8, 16] {
+                    eprintln!("CASE {precision1} {precision2} {manp}");
+                    let noise_factor = manp as f64;
+                    let mut dag_multi = v0_dag(sum_size, precision1, noise_factor);
+                    add_v0_dag(&mut dag_multi, sum_size, precision2, noise_factor);
+                    let dag_1 = v0_dag(sum_size, precision1, noise_factor);
+                    let dag_2 = v0_dag(sum_size, precision2, noise_factor);
+                    if let Some(equiv) = equiv_2_single(&dag_multi, &dag_1, &dag_2) {
+                    } else {
+                        break;
+                    }
+                }
+            }
+        }
+    }
+
+    fn dag_lut_sum_of_2_partitions_2_layer(
+        precision1: u8,
+        precision2: u8,
+        final_lut: bool,
+    ) -> unparametrized::OperationDag {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(precision1, Shape::number());
+        let input2 = dag.add_input(precision2, Shape::number());
+        let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, precision1);
+        let lut2 = dag.add_lut(input2, FunctionTable::UNKWOWN, precision2);
+        let lut1 = dag.add_lut(lut1, FunctionTable::UNKWOWN, precision2);
+        let lut2 = dag.add_lut(lut2, FunctionTable::UNKWOWN, precision2);
+        let dot = dag.add_dot([lut1, lut2], [1, 1]);
+        if final_lut {
+            _ = dag.add_lut(dot, FunctionTable::UNKWOWN, precision1);
+        }
+        dag
+    }
+
+    #[test]
+    fn optimize_multi_independant_2_partitions_finally_added() {
+        let default_partition = 0;
+        let single_precision_sol: Vec<_> = (0..11)
+            .map(|precision| {
+                let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, false);
+                optimize_single(&dag)
+            })
+            .collect();
+
+        for precision1 in 1..11 {
+            for precision2 in (precision1 + 1)..11 {
+                let p_cut = Some(PrecisionCut {
+                    p_cut: vec![precision1],
+                });
+                let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, false);
+                let sol_1 = single_precision_sol[precision1 as usize].clone();
+                let sol_2 = single_precision_sol[precision2 as usize].clone();
+                let sol_multi = optimize(&dag_multi, &p_cut, LOW_PARTITION);
+                let feasible_multi = sol_multi.is_some();
+                let feasible_2 = sol_2.is_some();
+                assert!(feasible_multi);
+                assert!(feasible_2);
+                let sol_multi = sol_multi.unwrap();
+                let sol_1 = sol_1.unwrap();
+                let sol_2 = sol_2.unwrap();
+                assert!(sol_1.complexity < sol_multi.complexity);
+                assert!(sol_multi.complexity < sol_2.complexity);
+                eprintln!("{:?}", sol_multi.micro_params.fks);
+                let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
+                    [default_partition]
+                    .unwrap()
+                    .complexity;
+                let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
+                let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
+                assert!(sol_multi.macro_params[1] == sol_2.macro_params[0]);
+                // The smaller the precision, the more the fks noise breaks partition independence
+                if precision1 < 4 {
+                    assert!(
+                        sol_multi_without_fks / perfect_complexity < 1.1,
+                        "{precision1} {precision2}"
+                    );
+                } else if precision1 <= 7 {
+                    assert!(
+                        sol_multi_without_fks / perfect_complexity < 1.03,
+                        "{precision1} {precision2} {}",
+                        sol_multi_without_fks / perfect_complexity
+                    );
+                } else {
+                    assert!(
+                        sol_multi_without_fks / perfect_complexity < 1.001,
+                        "{precision1} {precision2} {}",
+                        sol_multi_without_fks / perfect_complexity
+                    );
+                }
+            }
+        }
+    }
+
+    #[test]
+    fn optimize_multi_independant_2_partitions_finally_added_and_luted() {
+        let default_partition = 0;
+        let single_precision_sol: Vec<_> = (0..11)
+            .map(|precision| {
+                let dag = dag_lut_sum_of_2_partitions_2_layer(precision, precision, true);
+                optimize_single(&dag)
+            })
+            .collect();
+        for precision1 in 1..11 {
+            for precision2 in (precision1 + 1)..11 {
+                let p_cut = Some(PrecisionCut {
+                    p_cut: vec![precision1],
+                });
+                let dag_multi = dag_lut_sum_of_2_partitions_2_layer(precision1, precision2, true);
+                let sol_1 = single_precision_sol[precision1 as usize].clone();
+                let sol_2 = single_precision_sol[precision2 as usize].clone();
+                let sol_multi = optimize(&dag_multi, &p_cut, 0);
+                let feasible_multi = sol_multi.is_some();
+                let feasible_2 = sol_2.is_some();
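+                // Both the multi-parameter solution and the single high-precision
+                // solution must be feasible before the cost comparisons below.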
+                assert!(feasible_multi);
+                assert!(feasible_2);
+                let sol_multi = sol_multi.unwrap();
+                let sol_1 = sol_1.unwrap();
+                let sol_2 = sol_2.unwrap();
+                // The smaller the precision, the more the fks noise dominates
+                assert!(sol_1.complexity < sol_multi.complexity);
+                assert!(sol_multi.complexity < sol_2.complexity);
+                let fks_complexity = sol_multi.micro_params.fks[(default_partition + 1) % 2]
+                    [default_partition]
+                    .unwrap()
+                    .complexity;
+                let sol_multi_without_fks = sol_multi.complexity - fks_complexity;
+                let perfect_complexity = (sol_1.complexity + sol_2.complexity) / 2.0;
+                let relative_degradation = sol_multi_without_fks / perfect_complexity;
+                if precision1 < 4 {
+                    assert!(
+                        relative_degradation < 1.2,
+                        "{precision1} {precision2} {relative_degradation}"
+                    );
+                } else if precision1 <= 7 {
+                    assert!(
+                        relative_degradation < 1.19,
+                        "{precision1} {precision2} {relative_degradation}"
+                    );
+                } else {
+                    assert!(
+                        relative_degradation < 1.15,
+                        "{precision1} {precision2} {relative_degradation}"
+                    );
+                }
+            }
+        }
+    }
+
+    fn optimize_rounded(dag: &unparametrized::OperationDag) -> Option<Parameters> {
+        let p_cut = Some(PrecisionCut { p_cut: vec![1] });
+        let default_partition = 0;
+        optimize(dag, &p_cut, default_partition)
+    }
+
+    fn dag_rounded_lut_2_layers(
+        accumulator_precision: usize,
+        precision: usize,
+    ) -> unparametrized::OperationDag {
+        let out_precision = accumulator_precision as u8;
+        let rounded_precision = precision as u8;
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(precision as u8, Shape::number());
+        let rounded1 = dag.add_expanded_rounded_lut(
+            input1,
+            FunctionTable::UNKWOWN,
+            rounded_precision,
+            out_precision,
+        );
+        let rounded2 = dag.add_expanded_rounded_lut(
+            rounded1,
+            FunctionTable::UNKWOWN,
+            rounded_precision,
+            out_precision,
+        );
+        let _rounded3 =
+            dag.add_expanded_rounded_lut(rounded2, FunctionTable::UNKWOWN, rounded_precision, 1);
+        dag
+    }
+
+    fn test_optimize_v3_expanded_round(
+        precision_acc: usize,
+        precision_tlu: usize,
+        minimal_speedup: f64,
+    ) {
+        let dag = dag_rounded_lut_2_layers(precision_acc, precision_tlu);
+        let sol_mono = solo_key::optimize::tests::optimize(&dag)
+            .best_solution
+            .unwrap();
+        let sol = optimize_rounded(&dag).unwrap();
+        let speedup = sol_mono.complexity / sol.complexity;
+        assert!(
+            speedup >= minimal_speedup,
+            "Speedup {speedup} smaller than {minimal_speedup} for {precision_acc}/{precision_tlu}"
+        );
+        let expected_ks = [
+            [true, true],  // KS[0], KS[0->1]
+            [false, true], // KS[1]
+        ];
+        let expected_fks = [
+            [false, false],
+            [true, false], // FKS[1->0]
+        ];
+        for (src, dst) in cross_partition(2) {
+            assert!(sol.micro_params.ks[src][dst].is_some() == expected_ks[src][dst]);
+            assert!(sol.micro_params.fks[src][dst].is_some() == expected_fks[src][dst]);
+        }
+    }
+
+    #[test]
+    fn test_optimize_v3_expanded_round_16_8() {
+        test_optimize_v3_expanded_round(16, 8, 5.5);
+    }
+
+    #[test]
+    fn test_optimize_v3_expanded_round_16_6() {
+        test_optimize_v3_expanded_round(16, 6, 3.3);
+    }
+
+    #[test]
+    fn optimize_v3_direct_round() {
+        let mut dag = unparametrized::OperationDag::new();
+        let input1 = dag.add_input(16, Shape::number());
+        _ = dag.add_expanded_rounded_lut(input1, FunctionTable::UNKWOWN, 8, 16);
+        let sol = optimize_rounded(&dag).unwrap();
+        let sol_mono = solo_key::optimize::tests::optimize(&dag)
+            .best_solution
+            .unwrap();
+        let minimal_speedup = 8.6;
+        let speedup = sol_mono.complexity / sol.complexity;
+        assert!(
+            speedup >= minimal_speedup,
+            "Speedup {speedup} smaller than {minimal_speedup}"
+        );
+    }
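+
+    // The next test builds a sign-extraction pattern: an 8-bit input lifted to
+    // 16 bits, a wide levelled op (1_000_000-element shape), then an expanded
+    // round down to a 1-bit LUT, and checks the multi-parameter speedup on it.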
"Speedup {speedup} smaller than {minimal_speedup}" + ); + } + + #[test] + fn optimize_sign_extract() { + let precision = 8; + let high_precision = 16; + let mut dag = unparametrized::OperationDag::new(); + let complexity = LevelledComplexity::ZERO; + let free_small_input1 = dag.add_input(precision, Shape::number()); + let small_input1 = dag.add_lut(free_small_input1, FunctionTable::UNKWOWN, precision); + let small_input1 = dag.add_lut(small_input1, FunctionTable::UNKWOWN, high_precision); + let input1 = dag.add_levelled_op( + [small_input1], + complexity, + 1.0, + Shape::vector(1_000_000), + "comment", + ); + let rounded1 = dag.add_expanded_round(input1, 1); + let _rounded2 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, 1); + let sol = optimize_rounded(&dag).unwrap(); + let sol_mono = solo_key::optimize::tests::optimize(&dag).best_solution.unwrap(); + let speedup = sol_mono.complexity / sol.complexity; + let minimal_speedup = 80.0; + assert!(speedup >= minimal_speedup, + "Speedup {speedup} smaller than {minimal_speedup}" + ); + } + + fn test_partition_chain(decreasing: bool) { + // tlu chain with decreasing precision (decreasing partition index) + // check that increasing partitionning gaves faster solutions + // check solution has the right structure + let mut dag = unparametrized::OperationDag::new(); + let min_precision = 6; + let max_precision = 8; + let mut input_precisions : Vec<_> = (min_precision..=max_precision).collect(); + if decreasing { + input_precisions.reverse(); + } + let mut lut_input = dag.add_input(input_precisions[0], Shape::number()); + for &out_precision in &input_precisions { + lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision); + } + lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *input_precisions.last().unwrap()); + _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, min_precision); + let mut p_cut = PrecisionCut { p_cut:vec![] }; + let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + assert!(sol.macro_params.len() == 1); + let mut complexity = sol.complexity; + for &out_precision in &input_precisions { + if out_precision == max_precision { + // There is nothing to cut above max_precision + continue; + } + p_cut.p_cut.push(out_precision); + p_cut.p_cut.sort_unstable(); + eprintln!("PCUT {p_cut}"); + let sol = optimize(&dag, &Some(p_cut.clone()), 0).unwrap(); + let nb_partitions = sol.macro_params.len(); + assert!(nb_partitions == (p_cut.p_cut.len() + 1), + "bad nb partitions {} {p_cut}", sol.macro_params.len()); + assert!(sol.complexity < complexity, + "{} < {complexity} {out_precision} / {max_precision}", sol.complexity); + for (src, dst) in cross_partition(nb_partitions) { + let ks = sol.micro_params.ks[src][dst]; + eprintln!("{} {src} {dst}", ks.is_some()); + let expected_ks = + (!decreasing || src == dst + 1) + && (decreasing || src + 1 == dst) + || (src == dst && (src == 0 || src == nb_partitions - 1)) + ; + assert!(ks.is_some() == expected_ks, "{:?} {:?}", ks.is_some(), expected_ks); + let fks = sol.micro_params.fks[src][dst]; + assert!(fks.is_none()); + } + complexity = sol.complexity; + } + let sol = optimize(&dag, &None, 0); + assert!(sol.unwrap().complexity == complexity); + } + + #[test] + fn test_partition_decreasing_chain() { + test_partition_chain(true); + } + + #[test] + fn test_partition_increasing_chain() { + test_partition_chain(true); + } + + const MAX_WEIGHT: &[u64] = &[ + // max v0 weight for each precision + 1_073_741_824, + 1_073_741_824, // 2**30, 1b + 536_870_912, // 2**29, 2b + 268_435_456, // 
+
+    const MAX_WEIGHT: &[u64] = &[
+        // max v0 weight for each precision
+        1_073_741_824, // 2**30, 0b
+        1_073_741_824, // 2**30, 1b
+        536_870_912,   // 2**29, 2b
+        268_435_456,   // 2**28, 3b
+        67_108_864,    // 2**26, 4b
+        16_777_216,    // 2**24, 5b
+        4_194_304,     // 2**22, 6b
+        1_048_576,     // 2**20, 7b
+        262_144,       // 2**18, 8b
+        65_536,        // 2**16, 9b
+        16_384,        // 2**14, 10b
+        2_048,         // 2**11, 11b
+    ];
+
+    #[test]
+    fn test_independant_partitions_non_feasible_single_params() {
+        // generate a hard circuit that is not feasible with a single parameter set,
+        // composed of independent partitions so we know the optimal result
+        let precisions = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
+        let sum_size = 0;
+        let noise_factor = MAX_WEIGHT[precisions[0]] as f64;
+        let mut dag = v0_dag(sum_size, precisions[0] as u64, noise_factor);
+        let sol_single = optimize_single(&dag);
+        let mut optimal_complexity = sol_single.as_ref().unwrap().complexity;
+        let mut optimal_p_error = sol_single.unwrap().p_error;
+        for &out_precision in &precisions[1..] {
+            let noise_factor = MAX_WEIGHT[out_precision] as f64;
+            add_v0_dag(&mut dag, sum_size, out_precision as u64, noise_factor);
+            let sol_single = optimize_single(&v0_dag(sum_size, out_precision as u64, noise_factor));
+            optimal_complexity += sol_single.as_ref().unwrap().complexity;
+            optimal_p_error += sol_single.as_ref().unwrap().p_error;
+        }
+        // check that it is not feasible with a single parameter set
+        let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
+        assert!(sol_single.is_none());
+        // but it is solvable with multi-parameters
+        let sol = optimize(&dag, &None, 0);
+        assert!(sol.is_some());
+        let sol = sol.unwrap();
+        // check optimality
+        assert!(sol.complexity / optimal_complexity < 1.0 + f64::EPSILON);
+        assert!(sol.p_error / optimal_p_error < 1.0 + f64::EPSILON);
+    }
+
+    #[test]
+    fn test_chained_partitions_non_feasible_single_params() {
+        // generate a hard circuit that is not feasible with a single parameter set
+        let precisions = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11];
+        // Note: reversing the chain has issues connecting the lower bits to 7 bits;
+        // there may be no feasible solution in that case
+        let mut dag = unparametrized::OperationDag::new();
+        let mut lut_input = dag.add_input(precisions[0], Shape::number());
+        for out_precision in precisions {
+            let noise_factor = MAX_WEIGHT[*dag.out_precisions.last().unwrap() as usize] as f64;
+            lut_input = dag.add_levelled_op(
+                [lut_input],
+                LevelledComplexity::ZERO,
+                noise_factor,
+                Shape::number(),
+                "",
+            );
+            lut_input = dag.add_lut(lut_input, FunctionTable::UNKWOWN, out_precision);
+        }
+        _ = dag.add_lut(lut_input, FunctionTable::UNKWOWN, *precisions.last().unwrap());
+        let sol_single = solo_key::optimize::tests::optimize(&dag).best_solution;
+        assert!(sol_single.is_none());
+        let sol = optimize(&dag, &None, 0);
+        assert!(sol.is_some());
+    }
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs
index db4cb3f16..2bd345479 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs
@@ -504,7 +504,7 @@ fn peak_relative_variance(
     (max_relative_var, safe_noise)
 }
 
-fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
+pub fn p_error_from_relative_variance(relative_variance: f64, kappa: f64) -> f64 {
     let sigma_scale = kappa / relative_variance.sqrt();
     error::error_probability_of_sigma_scale(sigma_scale)
 }
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs
index c27aaf0e7..f2716e43a 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs
@@ -3,6 +3,7 @@ use concrete_cpu_noise_model::gaussian_noise::noise::modulus_switching::estimate
 use super::analyze;
 use crate::dag::operator::{LevelledComplexity, Precision};
 use crate::dag::unparametrized;
+use crate::dag::unparametrized::OperationDag;
 use crate::noise_estimator::error;
 use crate::optimization::atomic_pattern::{
     OptimizationDecompositionsConsts, OptimizationState, Solution,
@@ -386,6 +387,27 @@ pub fn optimize(
     state
 }
 
+pub fn add_v0_dag(dag: &mut OperationDag, sum_size: u64, precision: u64, noise_factor: f64) {
+    use crate::dag::operator::{FunctionTable, Shape};
+    let same_scale_manp = 1.0;
+    let manp = noise_factor;
+    let out_shape = &Shape::number();
+    let complexity = LevelledComplexity::ADDITION * sum_size;
+    let comment = "dot";
+    let precision = precision as Precision;
+    let input1 = dag.add_input(precision, out_shape);
+    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
+    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
+    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
+    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
+}
+
+pub fn v0_dag(sum_size: u64, precision: u64, noise_factor: f64) -> OperationDag {
+    let mut dag = unparametrized::OperationDag::new();
+    add_v0_dag(&mut dag, sum_size, precision, noise_factor);
+    dag
+}
+
 pub fn optimize_v0(
     sum_size: u64,
     precision: u64,
@@ -394,19 +416,7 @@ pub fn optimize_v0(
     search_space: &SearchSpace,
     cache: &PersistDecompCaches,
 ) -> OptimizationState {
-    use crate::dag::operator::{FunctionTable, Shape};
-    let same_scale_manp = 0.0;
-    let manp = noise_factor;
-    let out_shape = &Shape::number();
-    let complexity = LevelledComplexity::ADDITION * sum_size;
-    let comment = "dot";
-    let mut dag = unparametrized::OperationDag::new();
-    let precision = precision as Precision;
-    let input1 = dag.add_input(precision, out_shape);
-    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
-    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision);
-    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
-    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision);
+    let dag = v0_dag(sum_size, precision, noise_factor);
     let mut state = optimize(&dag, config, search_space, cache);
     if let Some(sol) = &mut state.best_solution {
         sol.complexity /= 2.0;
@@ -415,7 +425,7 @@
 }
 
 #[cfg(test)]
-mod tests {
+pub(crate) mod tests {
     use std::time::Instant;
 
     use once_cell::sync::Lazy;
@@ -456,7 +466,7 @@
         decomposition::cache(128, processing_unit, None, true)
     });
 
-    fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
+    pub fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState {
         let config = Config {
             security_level: 128,
             maximum_acceptable_error_probability: _4_SIGMA,
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs
index e843f2939..f8f526087 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/decomposition/cmux.rs
@@ -102,8 +102,12 @@ pub fn pareto_quantities(
     quantities
 }
 
+pub fn lowest_noise(quantities: &[CmuxComplexityNoise]) -> CmuxComplexityNoise {
+    quantities[quantities.len() - 1]
+}
+
 pub fn lowest_noise_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {
-    quantities[quantities.len() - 1].noise_br(in_lwe_dim)
+    lowest_noise(quantities).noise_br(in_lwe_dim)
 }
 
 pub fn lowest_complexity_br(quantities: &[CmuxComplexityNoise], in_lwe_dim: u64) -> f64 {
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs
index d8d8c5510..267b4e9e3 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/wop_atomic_pattern/optimize.rs
@@ -17,7 +17,7 @@ use crate::optimization::decomposition::keyswitch::KsComplexityNoise;
 use crate::optimization::decomposition::pp_switch::PpSwitchComplexityNoise;
 use crate::optimization::decomposition::PersistDecompCaches;
 use crate::parameters::{BrDecompositionParameters, GlweParameters};
-use crate::utils::max::f64_max;
+use crate::utils::f64::f64_max;
 use crate::utils::square;
 
 pub fn find_p_error(kappa: f64, variance_bound: f64, current_maximum_noise: f64) -> f64 {
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/f64.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/f64.rs
new file mode 100644
index 000000000..91b016ed9
--- /dev/null
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/f64.rs
@@ -0,0 +1,13 @@
+/// Maximum of a slice, or `default` when the slice is empty.
+pub fn f64_max(values: &[f64], default: f64) -> f64 {
+    values.iter().copied().reduce(f64::max).unwrap_or(default)
+}
+
+/// Dot product of two equal-length slices.
+pub fn f64_dot(a: &[f64], b: &[f64]) -> f64 {
+    let mut sum = 0.0;
+    for i in 0..a.len() {
+        sum += a[i] * b[i];
+    }
+    sum
+}
diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs
index 688a637f8..6d8ec019b 100644
--- a/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs
+++ b/compilers/concrete-optimizer/concrete-optimizer/src/utils/mod.rs
@@ -1,6 +1,6 @@
 pub mod cache;
+pub mod f64;
 pub mod hasher_builder;
-pub mod max;
 
 pub fn square<V>(v: V) -> V where