feat(optimizer): symbolic variance constraints for multiparameters

This commit is contained in:
rudy
2023-02-23 17:37:42 +01:00
committed by Quentin Bourgerie
parent 104ec93881
commit 361244abd0
3 changed files with 330 additions and 4 deletions
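The analysis now records one variance constraint per Lut input, taken after the keyswitch to the small key and the modulus switching of its source partition, plus one constraint for each value flagged by extra_final_values_to_check. Both the full list (variance_constraints) and the pruned list (undominated_variance_constraints) are kept on AnalyzedDag: a constraint is dropped when another constraint has, relative to its own safe variance bound, a coefficient at least as large in every term, so satisfying the dominating constraint is sufficient.

In the constraint strings of the tests, read σ²Br, σ²K, σ²FK, σ²M and σ²In as the pbs (blind rotate), keyswitch-to-small, partition-keyswitch-to-big, modulus-switching and input terms of SymbolicVariance; dom is the dominance index ceil(log2(max coefficient / safe variance bound)) used to sort constraints before pruning.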

View File

@@ -1,14 +1,21 @@
use crate::dag::operator::{dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Shape};
use crate::dag::operator::{
dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
};
use crate::dag::rewrite::round::expand_round;
use crate::dag::unparametrized;
use crate::optimization::config::NoiseBoundConfig;
use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred;
use crate::optimization::dag::multi_parameters::partitions::{
InstructionPartition, PartitionIndex,
InstructionPartition, PartitionIndex, Transition,
};
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
use crate::optimization::dag::solo_key::analyze::first;
use crate::optimization::dag::solo_key::analyze::{
extra_final_values_to_check, first, safe_noise_bound,
};
use super::variance_constraint::VarianceConstraint;
use crate::utils::square;
// private short convention
@@ -24,11 +31,15 @@ pub struct AnalyzedDag {
pub out_variances: Vec<Vec<SymbolicVariance>>,
// The full dag levelled complexity
pub levelled_complexity: LevelledComplexity,
// All variance constraints including dominated ones
pub variance_constraints: Vec<VarianceConstraint>,
// Undominated variance constraints
pub undominated_variance_constraints: Vec<VarianceConstraint>,
}
pub fn analyze(
dag: &unparametrized::OperationDag,
_noise_config: &NoiseBoundConfig,
noise_config: &NoiseBoundConfig,
p_cut: &PrecisionCut,
default_partition: PartitionIndex,
) -> AnalyzedDag {
@@ -45,12 +56,18 @@ pub fn analyze(
let nb_partitions = partitions.nb_partitions;
let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);
let variance_constraints =
collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances);
let undominated_variance_constraints =
VarianceConstraint::remove_dominated(&variance_constraints);
AnalyzedDag {
operators: dag.operators,
nb_partitions,
instrs_partition,
out_variances,
levelled_complexity,
variance_constraints,
undominated_variance_constraints,
}
}
@@ -144,6 +161,85 @@ fn out_variances(
out_variances
}
fn variance_constraint(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
partition: PartitionIndex,
op_i: usize,
precision: Precision,
variance: SymbolicVariance,
) -> VarianceConstraint {
let nb_constraints = dag.out_shapes[op_i].flat_size();
let safe_variance_bound = safe_noise_bound(precision, noise_config);
VarianceConstraint {
precision,
partition,
nb_constraints,
safe_variance_bound,
variance,
}
}
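// Collect one variance constraint per Lut input, evaluated after the keyswitch to
// the small key and the modulus switching of its source partition, and one
// constraint per operator flagged by extra_final_values_to_check.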
#[allow(clippy::float_cmp)]
#[allow(clippy::match_on_vec_items)]
fn collect_all_variance_constraints(
dag: &unparametrized::OperationDag,
noise_config: &NoiseBoundConfig,
instrs_partition: &[InstructionPartition],
out_variances: &[Vec<SymbolicVariance>],
) -> Vec<VarianceConstraint> {
let decryption_points = extra_final_values_to_check(dag);
let mut constraints = vec![];
for (op_i, op) in dag.operators.iter().enumerate() {
let partition = instrs_partition[op_i].instruction_partition;
if let Op::Lut { input, .. } = op {
let precision = dag.out_precisions[input.i];
let dst_partition = partition;
let src_partition = match instrs_partition[op_i].inputs_transition[0] {
None => dst_partition,
Some(Transition::Internal { src_partition }) => {
assert!(src_partition != dst_partition);
src_partition
}
Some(Transition::Additional { src_partition }) => {
assert!(src_partition != dst_partition);
let variance = &out_variances[input.i][dst_partition];
assert!(
variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition)
== 1.0
);
dst_partition
}
};
let variance = &out_variances[input.i][src_partition].clone();
let variance = variance
.after_partition_keyswitch_to_small(src_partition, dst_partition)
.after_modulus_switching(partition);
constraints.push(variance_constraint(
dag,
noise_config,
partition,
op_i,
precision,
variance,
));
}
if decryption_points[op_i] {
let precision = dag.out_precisions[op_i];
let variance = out_variances[op_i][partition].clone();
constraints.push(variance_constraint(
dag,
noise_config,
partition,
op_i,
precision,
variance,
));
}
}
constraints
}
#[cfg(test)]
mod tests {
use super::*;
@@ -444,4 +540,122 @@ mod tests {
assert!(first_bit_extract_verified);
assert!(first_bit_erase_verified);
}
#[test]
fn test_rounded_v3_classic_first_layer_second_layer_constraints() {
let acc_precision = 7;
let precision = 4;
let mut dag = unparametrized::OperationDag::new();
let free_input1 = dag.add_input(precision, Shape::number());
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
let rounded1 = dag.add_expanded_round(input1, precision);
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
let old_dag = dag;
let dag = analyze(&old_dag);
show_partitionning(&old_dag, &dag.instrs_partition);
let constraints: Vec<_> = dag
.variance_constraints
.iter()
.map(VarianceConstraint::to_string)
.collect();
let expected_constraints = [
// First lut to force partition HIGH_PRECISION_PARTITION
"1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
// 16384(shift) = (2**7)², for Br[1]
"16384σ²Br[1] + 16384σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)",
// 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1]
"4096σ²Br[0] + 4096σ²Br[1] + 4096σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)",
// 1024(shift) = (2**5)², 2(due to 2 erase bits) for Br[0] and 1 for Br[1]
"2048σ²Br[0] + 1024σ²Br[1] + 1024σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)",
// 3(erase bits) for Br[0] and 1 for the initial Br[1]
"3σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)",
// Last lut to close the cycle
"1σ²Br[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
];
for (c, ec) in constraints.iter().zip(expected_constraints) {
assert!(
c == ec,
"\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n"
);
}
let simplified_constraints: Vec<_> = dag
.undominated_variance_constraints
.iter()
.map(VarianceConstraint::to_string)
.collect();
let expected_simplified_constraints = [
expected_constraints[1], // biggest weights on Br[1]
expected_constraints[2], // biggest weights on Br[0]
expected_constraints[4], // only one to have K[0→1]
expected_constraints[0], // only one to have K[1]
// 3 is dominated by 2
];
for (c, ec) in simplified_constraints
.iter()
.zip(expected_simplified_constraints)
{
assert!(
c == ec,
"\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n"
);
}
}
#[test]
fn test_rounded_v1_classic_first_layer_second_layer_constraints() {
let acc_precision = 7;
let precision = 4;
let mut dag = unparametrized::OperationDag::new();
let free_input1 = dag.add_input(precision, Shape::number());
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
// let input1 = dag.add_input(acc_precision, Shape::number());
let rounded1 = dag.add_expanded_round(input1, precision);
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
let old_dag = dag;
let dag = analyze_with_preferred(&old_dag, HIGH_PRECISION_PARTITION);
show_partitionning(&old_dag, &dag.instrs_partition);
let constraints: Vec<_> = dag
.variance_constraints
.iter()
.map(VarianceConstraint::to_string)
.collect();
let expected_constraints = [
// First lut to force partition HIGH_PRECISION_PARTITION
"1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
// 16384(shift) = (2**7)², for Br[1]
"16384σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)",
// 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1]
"4096σ²Br[0] + 4096σ²FK[0→1] + 4096σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)",
// 1024(shift) = (2**5)², 2(due to 2 erase bits) for Br[0] and 1 for Br[1]
"2048σ²Br[0] + 2048σ²FK[0→1] + 1024σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)",
"3σ²Br[0] + 3σ²FK[0→1] + 1σ²Br[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)",
];
for (c, ec) in constraints.iter().zip(expected_constraints) {
assert!(
c == ec,
"\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n"
);
}
let simplified_constraints: Vec<_> = dag
.undominated_variance_constraints
.iter()
.map(VarianceConstraint::to_string)
.collect();
let expected_simplified_constraints = [
expected_constraints[1], // biggest weights on Br[1]
expected_constraints[2], // biggest weights on Br[0]
expected_constraints[4], // only one to combine Br[0]/FK[0→1] with K[1]/M[1]
expected_constraints[0], // only one to have In[1]
// 3 is dominated by 2
];
for (c, ec) in simplified_constraints
.iter()
.zip(expected_simplified_constraints)
{
assert!(
c == ec,
"\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n"
);
}
}
}
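As noted in both tests above, constraint 3 is dominated by constraint 2: the two share the same safe variance bound and constraint 2's coefficients are at least as large term by term (4096 against 2048 on σ²Br[0], 4096 against at most 2048 on the σ²Br[1] and σ²FK terms, 1 against 1 on σ²K and σ²M), so any parameters meeting constraint 2 also meet constraint 3 and remove_dominated drops it.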

View File

@@ -6,3 +6,4 @@ pub(crate) mod partitions;
pub(crate) mod precision_cut;
pub(crate) mod symbolic_variance;
pub(crate) mod union_find;
pub(crate) mod variance_constraint;

View File

@@ -0,0 +1,111 @@
use std::fmt;
use crate::dag::operator::Precision;
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
#[derive(Clone, Debug)]
pub struct VarianceConstraint {
pub precision: Precision,
pub partition: PartitionIndex,
pub nb_constraints: u64,
pub safe_variance_bound: f64,
pub variance: SymbolicVariance,
}
impl fmt::Display for VarianceConstraint {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"{} < (2²)**{} ({}bits partition:{} count:{}, dom={})",
self.variance,
self.safe_variance_bound.log2().round() / 2.0,
self.precision,
self.partition,
self.nb_constraints,
self.dominance_index()
)?;
Ok(())
}
}
impl VarianceConstraint {
#[allow(clippy::cast_sign_loss)]
fn dominance_index(&self) -> u64 {
let max_coeff = self
.variance
.coeffs
.iter()
.copied()
.reduce(f64::max)
.unwrap();
(max_coeff / self.safe_variance_bound).log2().ceil() as u64
}
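// `self` dominates `other` when, after scaling `self`'s coefficients by the ratio
// of safe variance bounds, every grouped coefficient (pbs, pbs + input and modulus
// switching per partition, plus both keyswitch kinds per partition pair) is at
// least as large as `other`'s. Assuming bootstrap noise never falls below fresh
// input noise (the `BR > Fresh` note below), any parameter set satisfying `self`
// then also satisfies `other`.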
fn dominate_or_equal(&self, other: &Self) -> bool {
// With BR > Fresh
let self_var = &self.variance;
let other_var = &other.variance;
let self_renorm = other.safe_variance_bound / self.safe_variance_bound;
let rel_diff =
|f: &dyn Fn(&SymbolicVariance) -> f64| self_renorm * f(self_var) - f(other_var);
for partition in 0..self.variance.nb_partitions() {
let diffs = [
rel_diff(&|var| var.coeff_pbs(partition)),
rel_diff(&|var| var.coeff_pbs(partition) + var.coeff_input(partition)),
rel_diff(&|var| var.coeff_modulus_switching(partition)),
];
for diff in diffs {
if diff < 0.0 {
return false;
}
}
}
for src_partition in 0..self.variance.nb_partitions() {
for dst_partition in 0..self.variance.nb_partitions() {
let diffs = [
rel_diff(&|var| var.coeff_keyswitch_to_small(src_partition, dst_partition)),
rel_diff(&|var| {
var.coeff_partition_keyswitch_to_big(src_partition, dst_partition)
}),
];
for diff in diffs {
if diff < 0.0 {
return false;
}
}
}
}
true
}
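// Remove every constraint dominated by (or equal to) another one. Constraints are
// sorted by decreasing dominance index so the likeliest dominators are visited
// first; the pairwise sweep skips constraints already marked dominated and stops
// comparing a constraint once it is dominated itself.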
pub fn remove_dominated(constraints: &[Self]) -> Vec<Self> {
let mut constraints = constraints.to_vec();
constraints.sort_by_cached_key(Self::dominance_index);
constraints.reverse();
let mut dominated = vec![false; constraints.len()];
for (i, constraint) in constraints.iter().enumerate() {
if dominated[i] {
continue;
}
for (j, other_constraint) in constraints.iter().enumerate() {
if j <= i {
continue;
}
if constraint.dominate_or_equal(other_constraint) {
dominated[j] = true;
} else if other_constraint.dominate_or_equal(constraint) {
dominated[i] = true;
break;
}
}
}
let mut result = vec![];
for (i, c) in constraints.iter().enumerate() {
if !dominated[i] {
result.push(c.clone());
}
}
result
}
}