diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs new file mode 100644 index 000000000..eb8a67e7a --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/analyze.rs @@ -0,0 +1,447 @@ +use crate::dag::operator::{dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Shape}; +use crate::dag::rewrite::round::expand_round; +use crate::dag::unparametrized; +use crate::optimization::config::NoiseBoundConfig; +use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred; +use crate::optimization::dag::multi_parameters::partitions::{ + InstructionPartition, PartitionIndex, +}; +use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut; +use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance; +use crate::optimization::dag::solo_key::analyze::first; +use crate::utils::square; + +// private short convention +use DotKind as DK; + +type Op = Operator; + +pub struct AnalyzedDag { + pub operators: Vec, + // Collect all operators ouput variances + pub nb_partitions: usize, + pub instrs_partition: Vec, + pub out_variances: Vec>, + // The full dag levelled complexity + pub levelled_complexity: LevelledComplexity, +} + +pub fn analyze( + dag: &unparametrized::OperationDag, + _noise_config: &NoiseBoundConfig, + p_cut: &PrecisionCut, + default_partition: PartitionIndex, +) -> AnalyzedDag { + assert!( + p_cut.p_cut.len() <= 1, + "Multi-parameter can only be used 0 or 1 precision cut" + ); + let dag = expand_round(dag); + let levelled_complexity = LevelledComplexity::ZERO; + // The precision cut is chosen to work well with rounded pbs + // Note: this is temporary + let partitions = partitionning_with_preferred(&dag, p_cut, default_partition); + let instrs_partition = partitions.instrs_partition; + let nb_partitions = partitions.nb_partitions; + let out_variances = out_variances(&dag, nb_partitions, &instrs_partition); + + AnalyzedDag { + operators: dag.operators, + nb_partitions, + instrs_partition, + out_variances, + levelled_complexity, + } +} + +fn out_variance( + op: &unparametrized::UnparameterizedOperator, + out_shapes: &[Shape], + out_variances: &mut Vec>, + nb_partitions: usize, + instr_partition: &InstructionPartition, +) -> Vec { + // one variance per partition, in case the result is converted + let partition = instr_partition.instruction_partition; + let out_variance_of = |input: &OperatorIndex| { + assert!(input.i < out_variances.len()); + assert!(partition < out_variances[input.i].len()); + assert!(out_variances[input.i][partition] != SymbolicVariance::ZERO); + assert!(!out_variances[input.i][partition].coeffs.values[0].is_nan()); + assert!(out_variances[input.i][partition].partition != usize::MAX); + out_variances[input.i][partition].clone() + }; + let max_variance = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input); + let variance = match op { + Op::Input { .. } => SymbolicVariance::input(nb_partitions, partition), + Op::Lut { .. } => SymbolicVariance::after_pbs(nb_partitions, partition), + Op::LevelledOp { inputs, manp, .. } => { + let inputs_variance = inputs.iter().map(out_variance_of); + let max_variance = inputs_variance.reduce(max_variance).unwrap(); + max_variance.after_levelled_op(*manp) + } + Op::Dot { + inputs, weights, .. + } => { + let input_shape = first(inputs, out_shapes); + let kind = dot_kind(inputs.len() as u64, input_shape, weights); + match kind { + DK::Simple | DK::Tensor | DK::Broadcast => { + let inputs_variance = (0..weights.values.len()).map(|j| { + let input = if inputs.len() > 1 { + inputs[j] + } else { + inputs[0] + }; + out_variance_of(&input) + }); + let mut out_variance = SymbolicVariance::ZERO; + for (input_variance, &weight) in inputs_variance.zip(&weights.values) { + assert!(input_variance != SymbolicVariance::ZERO); + out_variance += input_variance * square(weight); + } + out_variance + } + DK::CompatibleTensor { .. } => todo!("TODO"), + DK::Unsupported { .. } => panic!("Unsupported"), + } + } + Op::UnsafeCast { input, .. } => out_variance_of(input), + Op::Round { .. } => { + unreachable!("Round should have been either expanded or integrated to a lut") + } + }; + // Injecting NAN in unused symbolic variance to detect bad use + let unused = SymbolicVariance::nan(nb_partitions); + let mut result = vec![unused; nb_partitions]; + for &dst_partition in &instr_partition.alternative_output_representation { + let src_partition = partition; + // make converted variance available in dst_partition + result[dst_partition] = + variance.after_partition_keyswitch_to_big(src_partition, dst_partition); + } + result[partition] = variance; + result +} + +fn out_variances( + dag: &unparametrized::OperationDag, + nb_partitions: usize, + instrs_partition: &[InstructionPartition], +) -> Vec> { + let nb_ops = dag.operators.len(); + let mut out_variances = Vec::with_capacity(nb_ops); + for (op, instr_partition) in dag.operators.iter().zip(instrs_partition) { + let vf = out_variance( + op, + &dag.out_shapes, + &mut out_variances, + nb_partitions, + instr_partition, + ); + out_variances.push(vf); + } + out_variances +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dag::operator::{FunctionTable, Shape}; + use crate::dag::unparametrized; + use crate::optimization::dag::multi_parameters::partitionning::tests::{ + show_partitionning, HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION, + }; + use crate::optimization::dag::solo_key::analyze::tests::CONFIG; + + fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag { + analyze_with_preferred(dag, LOW_PRECISION_PARTITION) + } + + fn analyze_with_preferred( + dag: &unparametrized::OperationDag, + default_partition: PartitionIndex, + ) -> AnalyzedDag { + let p_cut = PrecisionCut { p_cut: vec![2] }; + super::analyze(dag, &CONFIG, &p_cut, default_partition) + } + + #[allow(clippy::float_cmp)] + fn assert_input_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) { + for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] { + let sb = dag.out_variances[op_i][partition].clone(); + let coeff = if sb == SymbolicVariance::ZERO { + 0.0 + } else { + sb.coeff_input(symbolic_variance_partition) + }; + if symbolic_variance_partition == partition { + assert!( + coeff == expected_coeff, + "INCORRECT INPUT COEFF ON GOOD PARTITION {:?} {:?} {} {}", + dag.out_variances[op_i], + partition, + coeff, + expected_coeff + ); + } else { + assert!( + coeff == 0.0, + "INCORRECT INPUT COEFF ON WRONG PARTITION {:?} {:?} {} {}", + dag.out_variances[op_i], + partition, + coeff, + expected_coeff + ); + } + } + } + + #[allow(clippy::float_cmp)] + fn assert_pbs_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) { + for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] { + let sb = dag.out_variances[op_i][partition].clone(); + eprintln!("{:?}", dag.out_variances[op_i]); + eprintln!("{:?}", dag.out_variances[op_i][partition]); + let coeff = if sb == SymbolicVariance::ZERO { + 0.0 + } else { + sb.coeff_pbs(symbolic_variance_partition) + }; + if symbolic_variance_partition == partition { + assert!( + coeff == expected_coeff, + "INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}", + dag.out_variances[op_i], + partition, + coeff, + expected_coeff + ); + } else { + assert!( + coeff == 0.0, + "INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}", + dag.out_variances[op_i], + partition, + coeff, + expected_coeff + ); + } + } + } + + #[allow(clippy::needless_range_loop)] + #[test] + fn test_lut_sequence() { + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(8, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); + let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 1); + let lut3 = dag.add_lut(lut2, FunctionTable::UNKWOWN, 1); + let lut4 = dag.add_lut(lut3, FunctionTable::UNKWOWN, 8); + let lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8); + let partitions = [ + HIGH_PRECISION_PARTITION, + HIGH_PRECISION_PARTITION, + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + HIGH_PRECISION_PARTITION, + ]; + let dag = analyze(&dag); + assert!(dag.nb_partitions == 2); + for op_i in input1.i..=lut5.i { + let p = &dag.instrs_partition[op_i]; + let is_input = op_i == input1.i; + assert!(p.instruction_partition == partitions[op_i]); + if is_input { + assert_input_on(&dag, p.instruction_partition, op_i, 1.0); + assert_pbs_on(&dag, p.instruction_partition, op_i, 0.0); + } else { + assert_pbs_on(&dag, p.instruction_partition, op_i, 1.0); + assert_input_on(&dag, p.instruction_partition, op_i, 0.0); + } + } + } + + #[test] + fn test_levelled_op() { + let mut dag = unparametrized::OperationDag::new(); + let out_shape = Shape::number(); + let manp = 8.0; + let input1 = dag.add_input(8, Shape::number()); + let input2 = dag.add_input(8, Shape::number()); + let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8); + let _levelled = dag.add_levelled_op( + [lut1, input2], + LevelledComplexity::ZERO, + manp, + &out_shape, + "comment", + ); + let dag = analyze(&dag); + assert!(dag.nb_partitions == 1); + } + + fn nan_symbolic_variance(sb: &SymbolicVariance) -> bool { + sb.coeffs[0].is_nan() + } + + #[allow(clippy::float_cmp)] + #[test] + fn test_rounded_v3_first_layer_and_second_layer() { + let acc_precision = 16; + let precision = 8; + let mut dag = unparametrized::OperationDag::new(); + let input1 = dag.add_input(acc_precision, Shape::number()); + let rounded1 = dag.add_expanded_round(input1, precision); + let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); + let rounded2 = dag.add_expanded_round(lut1, precision); + let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision); + let old_dag = dag; + let dag = analyze(&old_dag); + show_partitionning(&old_dag, &dag.instrs_partition); + // First layer is fully LOW_PRECISION_PARTITION + for op_i in input1.i..lut1.i { + let p = LOW_PRECISION_PARTITION; + let sb = &dag.out_variances[op_i][p]; + assert!(sb.coeff_input(p) >= 1.0 || sb.coeff_pbs(p) >= 1.0); + assert!(nan_symbolic_variance( + &dag.out_variances[op_i][HIGH_PRECISION_PARTITION] + )); + } + // First lut is HIGH_PRECISION_PARTITION and immedialtely converted to LOW_PRECISION_PARTITION + let p = HIGH_PRECISION_PARTITION; + let sb = &dag.out_variances[lut1.i][p]; + assert!(sb.coeff_input(p) == 0.0); + assert!(sb.coeff_pbs(p) == 1.0); + let sb_after_fast_ks = &dag.out_variances[lut1.i][LOW_PRECISION_PARTITION]; + assert!( + sb_after_fast_ks.coeff_partition_keyswitch_to_big( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION + ) == 1.0 + ); + // The next rounded is on LOW_PRECISION_PARTITION but base noise can comes from HIGH_PRECISION_PARTITION + FKS + for op_i in (lut1.i + 1)..lut2.i { + assert!(LOW_PRECISION_PARTITION == dag.instrs_partition[op_i].instruction_partition); + let p = LOW_PRECISION_PARTITION; + let sb = &dag.out_variances[op_i][p]; + // The base noise is either from the other partition and shifted or from the current partition and 1 + assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); + assert!(sb.coeff_input(HIGH_PRECISION_PARTITION) == 0.0); + if sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0 { + assert!( + sb.coeff_pbs(HIGH_PRECISION_PARTITION) + == sb.coeff_partition_keyswitch_to_big( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION + ) + ); + } else { + assert!(sb.coeff_pbs(LOW_PRECISION_PARTITION) == 1.0); + assert!( + sb.coeff_partition_keyswitch_to_big( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION + ) == 0.0 + ); + } + } + assert!(nan_symbolic_variance( + &dag.out_variances[lut2.i][LOW_PRECISION_PARTITION] + )); + let sb = &dag.out_variances[lut2.i][HIGH_PRECISION_PARTITION]; + assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0); + } + + #[allow(clippy::float_cmp, clippy::cognitive_complexity)] + #[test] + fn test_rounded_v3_classic_first_layer_second_layer() { + let acc_precision = 16; + let precision = 8; + let mut dag = unparametrized::OperationDag::new(); + let free_input1 = dag.add_input(precision, Shape::number()); + let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision); + let rounded1 = dag.add_expanded_round(input1, precision); + let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision); + let old_dag = dag; + let dag = analyze(&old_dag); + show_partitionning(&old_dag, &dag.instrs_partition); + // First layer is fully HIGH_PRECISION_PARTITION + assert!( + dag.out_variances[free_input1.i][HIGH_PRECISION_PARTITION] + .coeff_input(HIGH_PRECISION_PARTITION) + == 1.0 + ); + // First layer tlu + let sb = &dag.out_variances[input1.i][HIGH_PRECISION_PARTITION]; + assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); + assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); + assert!( + sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION) + == 0.0 + ); + // The same cyphertext exists in another partition with additional noise due to fast keyswitch + let sb = &dag.out_variances[input1.i][LOW_PRECISION_PARTITION]; + assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0); + assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0); + assert!( + sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION) + == 1.0 + ); + + // Second layer + let mut first_bit_extract_verified = false; + let mut first_bit_erase_verified = false; + for op_i in (input1.i + 1)..rounded1.i { + if let Op::Dot { + weights, inputs, .. + } = &dag.operators[op_i] + { + let bit_extract = weights.values.len() == 1; + let first_bit_extract = bit_extract && !first_bit_extract_verified; + let bit_erase = weights.values == [1, -1]; + let first_bit_erase = bit_erase && !first_bit_erase_verified; + let input0_sb = &dag.out_variances[inputs[0].i][LOW_PRECISION_PARTITION]; + let input0_coeff_pbs_high = input0_sb.coeff_pbs(HIGH_PRECISION_PARTITION); + let input0_coeff_pbs_low = input0_sb.coeff_pbs(LOW_PRECISION_PARTITION); + let input0_coeff_fks = input0_sb.coeff_partition_keyswitch_to_big( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + ); + if bit_extract { + first_bit_extract_verified |= first_bit_extract; + assert!(input0_coeff_pbs_high >= 1.0); + if first_bit_extract { + assert!(input0_coeff_pbs_low == 0.0); + } else { + assert!(input0_coeff_pbs_low >= 1.0); + } + assert!(input0_coeff_fks == 1.0); + } else if bit_erase { + first_bit_erase_verified |= first_bit_erase; + let input1_sb = &dag.out_variances[inputs[1].i][LOW_PRECISION_PARTITION]; + let input1_coeff_pbs_high = input1_sb.coeff_pbs(HIGH_PRECISION_PARTITION); + let input1_coeff_pbs_low = input1_sb.coeff_pbs(LOW_PRECISION_PARTITION); + let input1_coeff_fks = input1_sb.coeff_partition_keyswitch_to_big( + HIGH_PRECISION_PARTITION, + LOW_PRECISION_PARTITION, + ); + if first_bit_erase { + assert!(input0_coeff_pbs_low == 0.0); + } else { + assert!(input0_coeff_pbs_low >= 1.0); + } + assert!(input0_coeff_pbs_high == 1.0); + assert!(input0_coeff_fks == 1.0); + assert!(input1_coeff_pbs_low == 1.0); + assert!(input1_coeff_pbs_high == 0.0); + assert!(input1_coeff_fks == 0.0); + } + } + } + assert!(first_bit_extract_verified); + assert!(first_bit_erase_verified); + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs index 8885311e3..4b6b658e2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/mod.rs @@ -1,5 +1,8 @@ +pub mod analyze; pub mod keys_spec; -pub mod partitionning; +pub(crate) mod operations_value; +pub(crate) mod partitionning; pub(crate) mod partitions; pub(crate) mod precision_cut; +pub(crate) mod symbolic_variance; pub(crate) mod union_find; diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs new file mode 100644 index 000000000..9251cd8b4 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/operations_value.rs @@ -0,0 +1,179 @@ +use std::ops::{Deref, DerefMut}; + +/** + * Index actual operations (input, ks, pbs, fks, modulus switching, etc). + */ +#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)] +pub struct Indexing { + /* Values order + [ + // Partition 1 + // related only to the partition + fresh, pbs, modulus, + // Keyswitchs to small, from any partition to 1 + ks from 1, ks from 2, ... + // Keyswitch to big, from any partition to 1 + ks from 1, ks from 2, ... + + // Partition 2 + // same + ] + */ + pub nb_partitions: usize, +} + +pub const VALUE_INDEX_FRESH: usize = 0; +pub const VALUE_INDEX_PBS: usize = 1; +pub const VALUE_INDEX_MODULUS: usize = 2; +// number of value always present for a partition +pub const STABLE_NB_VALUES_BY_PARTITION: usize = 3; + +impl Indexing { + fn nb_keyswitchs_per_partition(self) -> usize { + self.nb_partitions + } + + pub fn nb_coeff_per_partition(self) -> usize { + STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions + } + + pub fn nb_coeff(self) -> usize { + self.nb_partitions * (STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions) + } + + pub fn input(self, partition: usize) -> usize { + partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH + } + + pub fn pbs(self, partition: usize) -> usize { + partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS + } + + pub fn modulus_switching(self, partition: usize) -> usize { + partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS + } + + pub fn keyswitch_to_small(self, src_partition: usize, dst_partition: usize) -> usize { + // Skip other partition + dst_partition * self.nb_coeff_per_partition() + // Skip non keyswitchs + + STABLE_NB_VALUES_BY_PARTITION + // Select the right keyswicth to small + + src_partition + } + + pub fn keyswitch_to_big(self, src_partition: usize, dst_partition: usize) -> usize { + // Skip other partition + dst_partition * self.nb_coeff_per_partition() + // Skip non keyswitchs + + STABLE_NB_VALUES_BY_PARTITION + // Skip keyswitch to small + + self.nb_keyswitchs_per_partition() + // Select the right keyswicth to big + + src_partition + } +} + +/** + * Represent any values indexed by actual operations (input, pbs, modulus switching, ks, fks, , etc) variance, + */ +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub struct OperationsValue { + pub index: Indexing, + pub values: Vec, +} + +impl OperationsValue { + pub const ZERO: Self = Self { + index: Indexing { nb_partitions: 0 }, + values: vec![], + }; + + pub fn zero(nb_partitions: usize) -> Self { + let index = Indexing { nb_partitions }; + Self { + index, + values: vec![0.0; index.nb_coeff()], + } + } + + pub fn nan(nb_partitions: usize) -> Self { + let index = Indexing { nb_partitions }; + Self { + index, + values: vec![f64::NAN; index.nb_coeff()], + } + } + + pub fn input(&mut self, partition: usize) -> &mut f64 { + &mut self.values[self.index.input(partition)] + } + + pub fn pbs(&mut self, partition: usize) -> &mut f64 { + &mut self.values[self.index.pbs(partition)] + } + + pub fn ks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 { + &mut self.values[self.index.keyswitch_to_small(src_partition, dst_partition)] + } + + pub fn fks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 { + &mut self.values[self.index.keyswitch_to_big(src_partition, dst_partition)] + } + + pub fn modulus_switching(&mut self, partition: usize) -> &mut f64 { + &mut self.values[self.index.modulus_switching(partition)] + } + + pub fn nb_partitions(&self) -> usize { + self.index.nb_partitions + } +} + +impl Deref for OperationsValue { + type Target = [f64]; + + fn deref(&self) -> &Self::Target { + &self.values + } +} + +impl DerefMut for OperationsValue { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.values + } +} + +impl std::ops::AddAssign for OperationsValue { + fn add_assign(&mut self, rhs: Self) { + if self.values.is_empty() { + *self = rhs; + } else { + for i in 0..self.values.len() { + self.values[i] += rhs.values[i]; + } + } + } +} + +impl std::ops::AddAssign<&Self> for OperationsValue { + fn add_assign(&mut self, rhs: &Self) { + if self.values.is_empty() { + *self = rhs.clone(); + } else { + for i in 0..self.values.len() { + self.values[i] += rhs.values[i]; + } + } + } +} + +impl std::ops::Mul for OperationsValue { + type Output = Self; + fn mul(self, sq_weight: f64) -> Self { + Self { + values: self.values.iter().map(|v| v * sq_weight).collect(), + ..self + } + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs new file mode 100644 index 000000000..67fc59639 --- /dev/null +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/multi_parameters/symbolic_variance.rs @@ -0,0 +1,261 @@ +use std::fmt; + +use crate::optimization::dag::multi_parameters::operations_value::{ + OperationsValue, VALUE_INDEX_FRESH, VALUE_INDEX_PBS, +}; + +/** + * A variance that is represented as a linear combination of base variances. + * Only the linear coefficient are known. + * The base variances are unknown. + * + * Possible base variances: + * - fresh, + * - lut output, + * - keyswitch, + * - partition keyswitch, + * - modulus switching + * + * We only kown that the fresh <= lut ouput in the same partition. + * Each linear coefficient is a variance factor. + * There are homogenious to squared weight (or summed square weights or squared norm2). + */ +#[derive(Clone, Debug, PartialEq, PartialOrd)] +pub struct SymbolicVariance { + pub partition: usize, + pub coeffs: OperationsValue, +} + +impl SymbolicVariance { + // To be used as a initial accumulator + pub const ZERO: Self = Self { + partition: 0, + coeffs: OperationsValue::ZERO, + }; + + pub fn nb_partitions(&self) -> usize { + self.coeffs.nb_partitions() + } + + pub fn nan(nb_partitions: usize) -> Self { + Self { + partition: usize::MAX, + coeffs: OperationsValue::nan(nb_partitions), + } + } + + pub fn input(nb_partitions: usize, partition: usize) -> Self { + let mut r = Self { + partition, + coeffs: OperationsValue::zero(nb_partitions), + }; + // rust ..., offset cannot be inlined + *r.coeffs.input(partition) = 1.0; + r + } + + pub fn coeff_input(&self, partition: usize) -> f64 { + self.coeffs[self.coeffs.index.input(partition)] + } + + pub fn after_pbs(nb_partitions: usize, partition: usize) -> Self { + let mut r = Self { + partition, + coeffs: OperationsValue::zero(nb_partitions), + }; + *r.coeffs.pbs(partition) = 1.0; + r + } + + pub fn coeff_pbs(&self, partition: usize) -> f64 { + self.coeffs[self.coeffs.index.pbs(partition)] + } + + pub fn coeff_modulus_switching(&self, partition: usize) -> f64 { + self.coeffs[self.coeffs.index.modulus_switching(partition)] + } + + pub fn after_modulus_switching(&self, partition: usize) -> Self { + let mut new = self.clone(); + let index = self.coeffs.index.modulus_switching(partition); + assert!(new.coeffs[index] == 0.0); + new.coeffs[index] = 1.0; + new + } + + pub fn coeff_keyswitch_to_small(&self, src_partition: usize, dst_partition: usize) -> f64 { + self.coeffs[self + .coeffs + .index + .keyswitch_to_small(src_partition, dst_partition)] + } + + pub fn after_partition_keyswitch_to_small( + &self, + src_partition: usize, + dst_partition: usize, + ) -> Self { + let index = self + .coeffs + .index + .keyswitch_to_small(src_partition, dst_partition); + self.after_partition_keyswitch(src_partition, dst_partition, index) + } + + pub fn coeff_partition_keyswitch_to_big( + &self, + src_partition: usize, + dst_partition: usize, + ) -> f64 { + self.coeffs[self + .coeffs + .index + .keyswitch_to_big(src_partition, dst_partition)] + } + + pub fn after_partition_keyswitch_to_big( + &self, + src_partition: usize, + dst_partition: usize, + ) -> Self { + let index = self + .coeffs + .index + .keyswitch_to_big(src_partition, dst_partition); + self.after_partition_keyswitch(src_partition, dst_partition, index) + } + + pub fn after_partition_keyswitch( + &self, + src_partition: usize, + dst_partition: usize, + index: usize, + ) -> Self { + assert!(src_partition < self.nb_partitions()); + assert!(dst_partition < self.nb_partitions()); + assert!(src_partition == self.partition); + let mut new = self.clone(); + new.partition = dst_partition; + new.coeffs[index] = 1.0; + new + } + + #[allow(clippy::float_cmp)] + pub fn after_levelled_op(&self, manp: f64) -> Self { + let new_coeff = manp * manp; + // detect the previous base manp level + // this is the maximum value of fresh base noise and pbs base noise + let mut current_max: f64 = 0.0; + for partition in 0..self.nb_partitions() { + let partition_offset = partition * self.coeffs.index.nb_coeff_per_partition(); + let fresh_coeff = self.coeffs[partition_offset + VALUE_INDEX_FRESH]; + let pbs_noise_coeff = self.coeffs[partition_offset + VALUE_INDEX_PBS]; + current_max = current_max.max(fresh_coeff).max(pbs_noise_coeff); + } + assert!(1.0 <= current_max); + assert!( + current_max <= new_coeff, + "Non monotonious levelled op: {current_max} <= {new_coeff}" + ); + // replace all current_max by new_coeff + // multiply everything else by new_coeff / current_max + let mut new = self.clone(); + for cell in &mut new.coeffs.values { + if *cell == current_max { + *cell = new_coeff; + } else { + *cell *= new_coeff / current_max; + } + } + new + } + + pub fn max(&self, other: &Self) -> Self { + let mut coeffs = self.coeffs.clone(); + for (i, coeff) in coeffs.iter_mut().enumerate() { + *coeff = coeff.max(other.coeffs[i]); + } + Self { coeffs, ..*self } + } +} + +impl fmt::Display for SymbolicVariance { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if self == &Self::ZERO { + write!(f, "ZERO x σ²")?; + } + if self.coeffs[0].is_nan() { + write!(f, "NAN x σ²")?; + } + let mut add_plus = ""; + for src_partition in 0..self.nb_partitions() { + let coeff = self.coeff_input(src_partition); + if coeff != 0.0 { + write!(f, "{add_plus}{coeff}σ²In[{src_partition}]")?; + add_plus = " + "; + } + let coeff = self.coeff_pbs(src_partition); + if coeff != 0.0 { + write!(f, "{add_plus}{coeff}σ²Br[{src_partition}]")?; + add_plus = " + "; + } + for dst_partition in 0..self.nb_partitions() { + let coeff = self.coeff_partition_keyswitch_to_big(src_partition, dst_partition); + if coeff != 0.0 { + write!(f, "{add_plus}{coeff}σ²FK[{src_partition}→{dst_partition}]")?; + add_plus = " + "; + } + } + } + for src_partition in 0..self.nb_partitions() { + for dst_partition in 0..self.nb_partitions() { + let coeff = self.coeff_keyswitch_to_small(src_partition, dst_partition); + if coeff != 0.0 { + if src_partition == dst_partition { + write!(f, "{add_plus}{coeff}σ²K[{src_partition}]")?; + } else { + write!(f, "{add_plus}{coeff}σ²K[{src_partition}→{dst_partition}]")?; + } + add_plus = " + "; + } + } + } + for partition in 0..self.nb_partitions() { + let coeff = self.coeff_modulus_switching(partition); + if coeff != 0.0 { + write!(f, "{add_plus}{coeff}σ²M[{partition}]")?; + add_plus = " + "; + } + } + Ok(()) + } +} + +impl std::ops::AddAssign for SymbolicVariance { + fn add_assign(&mut self, rhs: Self) { + if self.coeffs.is_empty() { + *self = rhs; + } else { + for i in 0..self.coeffs.len() { + self.coeffs[i] += rhs.coeffs[i]; + } + } + } +} + +impl std::ops::Mul for SymbolicVariance { + type Output = Self; + fn mul(self, sq_weight: f64) -> Self { + Self { + coeffs: self.coeffs * sq_weight, + ..self + } + } +} + +impl std::ops::Mul for SymbolicVariance { + type Output = Self; + fn mul(self, sq_weight: i64) -> Self { + self * sq_weight as f64 + } +} diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 02f8e14c9..db4cb3f16 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -14,7 +14,7 @@ use std::collections::{HashMap, HashSet}; use {DotKind as DK, VarianceOrigin as VO}; type Op = unparametrized::UnparameterizedOperator; -fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property { +pub fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property { &properties[inputs[0].i] } @@ -83,7 +83,7 @@ pub fn has_round(dag: &unparametrized::OperationDag) -> bool { false } -fn assert_no_round(dag: &unparametrized::OperationDag) { +pub fn assert_no_round(dag: &unparametrized::OperationDag) { assert!(!has_round(dag)); } @@ -197,7 +197,7 @@ fn out_variances(dag: &unparametrized::OperationDag) -> Vec { out_variances } -fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec { +pub fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec { let nb_ops = dag.operators.len(); let mut extra_values_to_check = vec![true; nb_ops]; for op in &dag.operators { @@ -283,7 +283,7 @@ fn op_levelled_complexity( } } -fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity { +pub fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity { let mut levelled_complexity = LevelledComplexity::ZERO; for op in &dag.operators { levelled_complexity += op_levelled_complexity(op, &dag.out_shapes); @@ -301,7 +301,7 @@ pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 { count } -fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f64 { +pub fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f64 { error::safe_variance_bound_2padbits( precision as u64, noise_config.ciphertext_modulus_log, @@ -620,7 +620,7 @@ impl OperationDag { } #[cfg(test)] -mod tests { +pub mod tests { use super::*; use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights}; @@ -641,7 +641,7 @@ mod tests { const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; - const CONFIG: NoiseBoundConfig = NoiseBoundConfig { + pub const CONFIG: NoiseBoundConfig = NoiseBoundConfig { security_level: 128, ciphertext_modulus_log: 64, maximum_acceptable_error_probability: _4_SIGMA,