From 361244abd0023b5b91dbd79f1c8c510d51999fb8 Mon Sep 17 00:00:00 2001 From: rudy Date: Thu, 23 Feb 2023 17:37:42 +0100 Subject: [PATCH] feat(optimizer): symbolic variance constraints for multiparameters --- .../dag/multi_parameters/analyze.rs | 222 +++++++++++++++++- .../optimization/dag/multi_parameters/mod.rs | 1 + .../multi_parameters/variance_constraint.rs | 111 +++++++++ 3 files changed, 330 insertions(+), 4 deletions(-) create mode 100644 compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs index eb8a67e7a..7f71b6b41 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -1,14 +1,21 @@ -use crate::dag::operator::{dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Shape}; +use crate::dag::operator::{ + dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, +}; use crate::dag::rewrite::round::expand_round; use crate::dag::unparametrized; use crate::optimization::config::NoiseBoundConfig; use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred; use crate::optimization::dag::multi_parameters::partitions::{ - InstructionPartition, PartitionIndex, + InstructionPartition, PartitionIndex, Transition, }; use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut; use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; -use crate::optimization::dag::solo_key::analyze::first; +use crate::optimization::dag::solo_key::analyze::{ + extra_final_values_to_check, first, safe_noise_bound, +}; + +use super::variance_constraint::VarianceConstraint; + use crate::utils::square; // private short convention @@ -24,11 +31,15 @@ pub struct AnalyzedDag { pub out_variances: Vec>, // The full dag levelled complexity pub levelled_complexity: LevelledComplexity, + // All variance constraints including dominated ones + pub variance_constraints: Vec, + // Undominated variance constraints + pub undominated_variance_constraints: Vec, } pub fn analyze( dag: &unparametrized::OperationDag, - _noise_config: &NoiseBoundConfig, + noise_config: &NoiseBoundConfig, p_cut: &PrecisionCut, default_partition: PartitionIndex, ) -> AnalyzedDag { @@ -45,12 +56,18 @@ pub fn analyze( let nb_partitions = partitions.nb_partitions; let out_variances = out_variances(&dag, nb_partitions, &instrs_partition); + let variance_constraints = + collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances); + let undominated_variance_constraints = + VarianceConstraint::remove_dominated(&variance_constraints); AnalyzedDag { operators: dag.operators, nb_partitions, instrs_partition, out_variances, levelled_complexity, + variance_constraints, + undominated_variance_constraints, } } @@ -144,6 +161,85 @@ fn out_variances( out_variances } +fn variance_constraint( + dag: &unparametrized::OperationDag, + noise_config: &NoiseBoundConfig, + partition: PartitionIndex, + op_i: usize, + precision: Precision, + variance: SymbolicVariance, +) -> VarianceConstraint { + let nb_constraints = dag.out_shapes[op_i].flat_size(); + let safe_variance_bound = safe_noise_bound(precision, noise_config); + VarianceConstraint { + precision, + partition, + nb_constraints, + safe_variance_bound, + variance, + } +} + +#[allow(clippy::float_cmp)] +#[allow(clippy::match_on_vec_items)] +fn collect_all_variance_constraints( + dag: &unparametrized::OperationDag, + noise_config: &NoiseBoundConfig, + instrs_partition: &[InstructionPartition], + out_variances: &[Vec], +) -> Vec { + let decryption_points = extra_final_values_to_check(dag); + let mut constraints = vec![]; + for (op_i, op) in dag.operators.iter().enumerate() { + let partition = instrs_partition[op_i].instruction_partition; + if let Op::Lut { input, .. } = op { + let precision = dag.out_precisions[input.i]; + let dst_partition = partition; + let src_partition = match instrs_partition[op_i].inputs_transition[0] { + None => dst_partition, + Some(Transition::Internal { src_partition }) => { + assert!(src_partition != dst_partition); + src_partition + } + Some(Transition::Additional { src_partition }) => { + assert!(src_partition != dst_partition); + let variance = &out_variances[input.i][dst_partition]; + assert!( + variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition) + == 1.0 + ); + dst_partition + } + }; + let variance = &out_variances[input.i][src_partition].clone(); + let variance = variance + .after_partition_keyswitch_to_small(src_partition, dst_partition) + .after_modulus_switching(partition); + constraints.push(variance_constraint( + dag, + noise_config, + partition, + op_i, + precision, + variance, + )); + } + if decryption_points[op_i] { + let precision = dag.out_precisions[op_i]; + let variance = out_variances[op_i][partition].clone(); + constraints.push(variance_constraint( + dag, + noise_config, + partition, + op_i, + precision, + variance, + )); + } + } + constraints +} + #[cfg(test)] mod tests { use super::*; @@ -444,4 +540,122 @@ mod tests { assert!(first_bit_extract_verified); assert!(first_bit_erase_verified); } + + #[test] + fn test_rounded_v3_classic_first_layer_second_layer_constraints() { + let acc_precision = 7; + let precision = 4; + let mut dag = unparametrized::OperationDag::new(); + let free_input1 = dag.add_input(precision, Shape::number()); + let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); + let rounded1 = dag.add_expanded_round(input1, precision); + let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); + let old_dag = dag; + let dag = analyze(&old_dag); + show_partitionning(&old_dag, &dag.instrs_partition); + let constraints: Vec<_> = dag + .variance_constraints + .iter() + .map(VarianceConstraint::to_string) + .collect(); + let expected_constraints = [ + // First lut to force partition HIGH_PRECISION_PARTITION + "1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)", + // 16384(shift) = (2**7)², for Br[1] + "16384σ²Br[1] + 16384σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)", + // 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1] + "4096σ²Br[0] + 4096σ²Br[1] + 4096σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", + // 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1] + "2048σ²Br[0] + 1024σ²Br[1] + 1024σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", + // 3(erase bit) Br[0] and 1 initial Br[1] + "3σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)", + // Last lut to close the cycle + "1σ²Br[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)", + ]; + for (c, ec) in constraints.iter().zip(expected_constraints) { + assert!( + c == ec, + "\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n" + ); + } + let simplified_constraints: Vec<_> = dag + .undominated_variance_constraints + .iter() + .map(VarianceConstraint::to_string) + .collect(); + let expected_simplified_constraints = [ + expected_constraints[1], // biggest weights on Br[1] + expected_constraints[2], // biggest weights on Br[0] + expected_constraints[4], // only one to have K[0→1] + expected_constraints[0], // only one to have K[1] + // 3 is dominated by 2 + ]; + for (c, ec) in simplified_constraints + .iter() + .zip(expected_simplified_constraints) + { + assert!( + c == ec, + "\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n" + ); + } + } + + #[test] + fn test_rounded_v1_classic_first_layer_second_layer_constraints() { + let acc_precision = 7; + let precision = 4; + let mut dag = unparametrized::OperationDag::new(); + let free_input1 = dag.add_input(precision, Shape::number()); + let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); + // let input1 = dag.add_input(acc_precision, Shape::number()); + let rounded1 = dag.add_expanded_round(input1, precision); + let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision); + let old_dag = dag; + let dag = analyze_with_preferred(&old_dag, HIGH_PRECISION_PARTITION); + show_partitionning(&old_dag, &dag.instrs_partition); + let constraints: Vec<_> = dag + .variance_constraints + .iter() + .map(VarianceConstraint::to_string) + .collect(); + let expected_constraints = [ + // First lut to force partition HIGH_PRECISION_PARTITION + "1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)", + // 16384(shift) = (2**7)², for Br[1] + "16384σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)", + // 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1] + "4096σ²Br[0] + 4096σ²FK[0→1] + 4096σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)", + // 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1] + "2048σ²Br[0] + 2048σ²FK[0→1] + 1024σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)", + "3σ²Br[0] + 3σ²FK[0→1] + 1σ²Br[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)", + ]; + for (c, ec) in constraints.iter().zip(expected_constraints) { + assert!( + c == ec, + "\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n" + ); + } + let simplified_constraints: Vec<_> = dag + .undominated_variance_constraints + .iter() + .map(VarianceConstraint::to_string) + .collect(); + let expected_simplified_constraints = [ + expected_constraints[1], // biggest weights on Br[1] + expected_constraints[2], // biggest weights on Br[0] + expected_constraints[4], // only one to have K[0→1] + expected_constraints[0], // only one to have K[1] + // 3 is dominated by 2 + ]; + for (c, ec) in simplified_constraints + .iter() + .zip(expected_simplified_constraints) + { + assert!( + c == ec, + "\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n" + ); + } + } } diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 4b6b658e2..35f23044c 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -6,3 +6,4 @@ pub(crate) mod partitions; pub(crate) mod precision_cut; pub(crate) mod symbolic_variance; pub(crate) mod union_find; +pub(crate) mod variance_constraint; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs new file mode 100644 index 000000000..adbe3bec5 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/variance_constraint.rs @@ -0,0 +1,111 @@ +use std::fmt; + +use crate::dag::operator::Precision; +use crate::optimization::dag::multi_parameters::partitions::PartitionIndex; +use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; + +#[derive(Clone, Debug)] +pub struct VarianceConstraint { + pub precision: Precision, + pub partition: PartitionIndex, + pub nb_constraints: u64, + pub safe_variance_bound: f64, + pub variance: SymbolicVariance, +} + +impl fmt::Display for VarianceConstraint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{} < (2²)**{} ({}bits partition:{} count:{}, dom={})", + self.variance, + self.safe_variance_bound.log2().round() / 2.0, + self.precision, + self.partition, + self.nb_constraints, + self.dominance_index() + )?; + Ok(()) + } +} + +impl VarianceConstraint { + #[allow(clippy::cast_sign_loss)] + fn dominance_index(&self) -> u64 { + let max_coeff = self + .variance + .coeffs + .iter() + .copied() + .reduce(f64::max) + .unwrap(); + (max_coeff / self.safe_variance_bound).log2().ceil() as u64 + } + + fn dominate_or_equal(&self, other: &Self) -> bool { + // With BR > Fresh + let self_var = &self.variance; + let other_var = &other.variance; + let self_renorm = other.safe_variance_bound / self.safe_variance_bound; + let rel_diff = + |f: &dyn Fn(&SymbolicVariance) -> f64| self_renorm * f(self_var) - f(other_var); + for partition in 0..self.variance.nb_partitions() { + let diffs = [ + rel_diff(&|var| var.coeff_pbs(partition)), + rel_diff(&|var| var.coeff_pbs(partition) + var.coeff_input(partition)), + rel_diff(&|var| var.coeff_modulus_switching(partition)), + ]; + for diff in diffs { + if diff < 0.0 { + return false; + } + } + } + for src_partition in 0..self.variance.nb_partitions() { + for dst_partition in 0..self.variance.nb_partitions() { + let diffs = [ + rel_diff(&|var| var.coeff_keyswitch_to_small(src_partition, dst_partition)), + rel_diff(&|var| { + var.coeff_partition_keyswitch_to_big(src_partition, dst_partition) + }), + ]; + for diff in diffs { + if diff < 0.0 { + return false; + } + } + } + } + true + } + + pub fn remove_dominated(constraints: &[Self]) -> Vec { + let mut constraints = constraints.to_vec(); + constraints.sort_by_cached_key(Self::dominance_index); + constraints.reverse(); + let mut dominated = vec![false; constraints.len()]; + for (i, constraint) in constraints.iter().enumerate() { + if dominated[i] { + continue; + } + for (j, other_constraint) in constraints.iter().enumerate() { + if j <= i { + continue; + } + if constraint.dominate_or_equal(other_constraint) { + dominated[j] = true; + } else if other_constraint.dominate_or_equal(constraint) { + dominated[i] = true; + break; + } + } + } + let mut result = vec![]; + for (i, c) in constraints.iter().enumerate() { + if !dominated[i] { + result.push(c.clone()); + } + } + result + } +}