mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 03:25:05 -05:00
feat(optimizer): symbolic variance constraints for multiparameters
This commit is contained in:
@@ -1,14 +1,21 @@
|
||||
use crate::dag::operator::{dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Shape};
|
||||
use crate::dag::operator::{
|
||||
dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Precision, Shape,
|
||||
};
|
||||
use crate::dag::rewrite::round::expand_round;
|
||||
use crate::dag::unparametrized;
|
||||
use crate::optimization::config::NoiseBoundConfig;
|
||||
use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred;
|
||||
use crate::optimization::dag::multi_parameters::partitions::{
|
||||
InstructionPartition, PartitionIndex,
|
||||
InstructionPartition, PartitionIndex, Transition,
|
||||
};
|
||||
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
|
||||
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
|
||||
use crate::optimization::dag::solo_key::analyze::first;
|
||||
use crate::optimization::dag::solo_key::analyze::{
|
||||
extra_final_values_to_check, first, safe_noise_bound,
|
||||
};
|
||||
|
||||
use super::variance_constraint::VarianceConstraint;
|
||||
|
||||
use crate::utils::square;
|
||||
|
||||
// private short convention
|
||||
@@ -24,11 +31,15 @@ pub struct AnalyzedDag {
|
||||
pub out_variances: Vec<Vec<SymbolicVariance>>,
|
||||
// The full dag levelled complexity
|
||||
pub levelled_complexity: LevelledComplexity,
|
||||
// All variance constraints including dominated ones
|
||||
pub variance_constraints: Vec<VarianceConstraint>,
|
||||
// Undominated variance constraints
|
||||
pub undominated_variance_constraints: Vec<VarianceConstraint>,
|
||||
}
|
||||
|
||||
pub fn analyze(
|
||||
dag: &unparametrized::OperationDag,
|
||||
_noise_config: &NoiseBoundConfig,
|
||||
noise_config: &NoiseBoundConfig,
|
||||
p_cut: &PrecisionCut,
|
||||
default_partition: PartitionIndex,
|
||||
) -> AnalyzedDag {
|
||||
@@ -45,12 +56,18 @@ pub fn analyze(
|
||||
let nb_partitions = partitions.nb_partitions;
|
||||
let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);
|
||||
|
||||
let variance_constraints =
|
||||
collect_all_variance_constraints(&dag, noise_config, &instrs_partition, &out_variances);
|
||||
let undominated_variance_constraints =
|
||||
VarianceConstraint::remove_dominated(&variance_constraints);
|
||||
AnalyzedDag {
|
||||
operators: dag.operators,
|
||||
nb_partitions,
|
||||
instrs_partition,
|
||||
out_variances,
|
||||
levelled_complexity,
|
||||
variance_constraints,
|
||||
undominated_variance_constraints,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +161,85 @@ fn out_variances(
|
||||
out_variances
|
||||
}
|
||||
|
||||
fn variance_constraint(
|
||||
dag: &unparametrized::OperationDag,
|
||||
noise_config: &NoiseBoundConfig,
|
||||
partition: PartitionIndex,
|
||||
op_i: usize,
|
||||
precision: Precision,
|
||||
variance: SymbolicVariance,
|
||||
) -> VarianceConstraint {
|
||||
let nb_constraints = dag.out_shapes[op_i].flat_size();
|
||||
let safe_variance_bound = safe_noise_bound(precision, noise_config);
|
||||
VarianceConstraint {
|
||||
precision,
|
||||
partition,
|
||||
nb_constraints,
|
||||
safe_variance_bound,
|
||||
variance,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::float_cmp)]
|
||||
#[allow(clippy::match_on_vec_items)]
|
||||
fn collect_all_variance_constraints(
|
||||
dag: &unparametrized::OperationDag,
|
||||
noise_config: &NoiseBoundConfig,
|
||||
instrs_partition: &[InstructionPartition],
|
||||
out_variances: &[Vec<SymbolicVariance>],
|
||||
) -> Vec<VarianceConstraint> {
|
||||
let decryption_points = extra_final_values_to_check(dag);
|
||||
let mut constraints = vec![];
|
||||
for (op_i, op) in dag.operators.iter().enumerate() {
|
||||
let partition = instrs_partition[op_i].instruction_partition;
|
||||
if let Op::Lut { input, .. } = op {
|
||||
let precision = dag.out_precisions[input.i];
|
||||
let dst_partition = partition;
|
||||
let src_partition = match instrs_partition[op_i].inputs_transition[0] {
|
||||
None => dst_partition,
|
||||
Some(Transition::Internal { src_partition }) => {
|
||||
assert!(src_partition != dst_partition);
|
||||
src_partition
|
||||
}
|
||||
Some(Transition::Additional { src_partition }) => {
|
||||
assert!(src_partition != dst_partition);
|
||||
let variance = &out_variances[input.i][dst_partition];
|
||||
assert!(
|
||||
variance.coeff_partition_keyswitch_to_big(src_partition, dst_partition)
|
||||
== 1.0
|
||||
);
|
||||
dst_partition
|
||||
}
|
||||
};
|
||||
let variance = &out_variances[input.i][src_partition].clone();
|
||||
let variance = variance
|
||||
.after_partition_keyswitch_to_small(src_partition, dst_partition)
|
||||
.after_modulus_switching(partition);
|
||||
constraints.push(variance_constraint(
|
||||
dag,
|
||||
noise_config,
|
||||
partition,
|
||||
op_i,
|
||||
precision,
|
||||
variance,
|
||||
));
|
||||
}
|
||||
if decryption_points[op_i] {
|
||||
let precision = dag.out_precisions[op_i];
|
||||
let variance = out_variances[op_i][partition].clone();
|
||||
constraints.push(variance_constraint(
|
||||
dag,
|
||||
noise_config,
|
||||
partition,
|
||||
op_i,
|
||||
precision,
|
||||
variance,
|
||||
));
|
||||
}
|
||||
}
|
||||
constraints
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@@ -444,4 +540,122 @@ mod tests {
|
||||
assert!(first_bit_extract_verified);
|
||||
assert!(first_bit_erase_verified);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rounded_v3_classic_first_layer_second_layer_constraints() {
|
||||
let acc_precision = 7;
|
||||
let precision = 4;
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let free_input1 = dag.add_input(precision, Shape::number());
|
||||
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
|
||||
let rounded1 = dag.add_expanded_round(input1, precision);
|
||||
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
|
||||
let old_dag = dag;
|
||||
let dag = analyze(&old_dag);
|
||||
show_partitionning(&old_dag, &dag.instrs_partition);
|
||||
let constraints: Vec<_> = dag
|
||||
.variance_constraints
|
||||
.iter()
|
||||
.map(VarianceConstraint::to_string)
|
||||
.collect();
|
||||
let expected_constraints = [
|
||||
// First lut to force partition HIGH_PRECISION_PARTITION
|
||||
"1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
|
||||
// 16384(shift) = (2**7)², for Br[1]
|
||||
"16384σ²Br[1] + 16384σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)",
|
||||
// 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1]
|
||||
"4096σ²Br[0] + 4096σ²Br[1] + 4096σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)",
|
||||
// 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1]
|
||||
"2048σ²Br[0] + 1024σ²Br[1] + 1024σ²FK[1→0] + 1σ²K[0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)",
|
||||
// 3(erase bit) Br[0] and 1 initial Br[1]
|
||||
"3σ²Br[0] + 1σ²Br[1] + 1σ²FK[1→0] + 1σ²K[0→1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)",
|
||||
// Last lut to close the cycle
|
||||
"1σ²Br[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
|
||||
];
|
||||
for (c, ec) in constraints.iter().zip(expected_constraints) {
|
||||
assert!(
|
||||
c == ec,
|
||||
"\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n"
|
||||
);
|
||||
}
|
||||
let simplified_constraints: Vec<_> = dag
|
||||
.undominated_variance_constraints
|
||||
.iter()
|
||||
.map(VarianceConstraint::to_string)
|
||||
.collect();
|
||||
let expected_simplified_constraints = [
|
||||
expected_constraints[1], // biggest weights on Br[1]
|
||||
expected_constraints[2], // biggest weights on Br[0]
|
||||
expected_constraints[4], // only one to have K[0→1]
|
||||
expected_constraints[0], // only one to have K[1]
|
||||
// 3 is dominated by 2
|
||||
];
|
||||
for (c, ec) in simplified_constraints
|
||||
.iter()
|
||||
.zip(expected_simplified_constraints)
|
||||
{
|
||||
assert!(
|
||||
c == ec,
|
||||
"\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rounded_v1_classic_first_layer_second_layer_constraints() {
|
||||
let acc_precision = 7;
|
||||
let precision = 4;
|
||||
let mut dag = unparametrized::OperationDag::new();
|
||||
let free_input1 = dag.add_input(precision, Shape::number());
|
||||
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
|
||||
// let input1 = dag.add_input(acc_precision, Shape::number());
|
||||
let rounded1 = dag.add_expanded_round(input1, precision);
|
||||
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, precision);
|
||||
let old_dag = dag;
|
||||
let dag = analyze_with_preferred(&old_dag, HIGH_PRECISION_PARTITION);
|
||||
show_partitionning(&old_dag, &dag.instrs_partition);
|
||||
let constraints: Vec<_> = dag
|
||||
.variance_constraints
|
||||
.iter()
|
||||
.map(VarianceConstraint::to_string)
|
||||
.collect();
|
||||
let expected_constraints = [
|
||||
// First lut to force partition HIGH_PRECISION_PARTITION
|
||||
"1σ²In[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=16)",
|
||||
// 16384(shift) = (2**7)², for Br[1]
|
||||
"16384σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=22)",
|
||||
// 4096(shift) = (2**6)², 1(due to 1 erase bit) for Br[0] and 1 for Br[1]
|
||||
"4096σ²Br[0] + 4096σ²FK[0→1] + 4096σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=20)",
|
||||
// 1024(shift) = (2**5)², 2(due to 2 erase bit for Br[0] and 1 for Br[1]
|
||||
"2048σ²Br[0] + 2048σ²FK[0→1] + 1024σ²Br[1] + 1σ²K[1→0] + 1σ²M[0] < (2²)**-4 (0bits partition:0 count:1, dom=19)",
|
||||
"3σ²Br[0] + 3σ²FK[0→1] + 1σ²Br[1] + 1σ²K[1] + 1σ²M[1] < (2²)**-8 (4bits partition:1 count:1, dom=18)",
|
||||
];
|
||||
for (c, ec) in constraints.iter().zip(expected_constraints) {
|
||||
assert!(
|
||||
c == ec,
|
||||
"\nBad constraint\nActual: {c}\nTruth : {ec} (expected)\n"
|
||||
);
|
||||
}
|
||||
let simplified_constraints: Vec<_> = dag
|
||||
.undominated_variance_constraints
|
||||
.iter()
|
||||
.map(VarianceConstraint::to_string)
|
||||
.collect();
|
||||
let expected_simplified_constraints = [
|
||||
expected_constraints[1], // biggest weights on Br[1]
|
||||
expected_constraints[2], // biggest weights on Br[0]
|
||||
expected_constraints[4], // only one to have K[0→1]
|
||||
expected_constraints[0], // only one to have K[1]
|
||||
// 3 is dominated by 2
|
||||
];
|
||||
for (c, ec) in simplified_constraints
|
||||
.iter()
|
||||
.zip(expected_simplified_constraints)
|
||||
{
|
||||
assert!(
|
||||
c == ec,
|
||||
"\nBad simplified constraint\nActual: {c}\nTruth : {ec} (expected)\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,3 +6,4 @@ pub(crate) mod partitions;
|
||||
pub(crate) mod precision_cut;
|
||||
pub(crate) mod symbolic_variance;
|
||||
pub(crate) mod union_find;
|
||||
pub(crate) mod variance_constraint;
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
use std::fmt;
|
||||
|
||||
use crate::dag::operator::Precision;
|
||||
use crate::optimization::dag::multi_parameters::partitions::PartitionIndex;
|
||||
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct VarianceConstraint {
|
||||
pub precision: Precision,
|
||||
pub partition: PartitionIndex,
|
||||
pub nb_constraints: u64,
|
||||
pub safe_variance_bound: f64,
|
||||
pub variance: SymbolicVariance,
|
||||
}
|
||||
|
||||
impl fmt::Display for VarianceConstraint {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{} < (2²)**{} ({}bits partition:{} count:{}, dom={})",
|
||||
self.variance,
|
||||
self.safe_variance_bound.log2().round() / 2.0,
|
||||
self.precision,
|
||||
self.partition,
|
||||
self.nb_constraints,
|
||||
self.dominance_index()
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl VarianceConstraint {
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
fn dominance_index(&self) -> u64 {
|
||||
let max_coeff = self
|
||||
.variance
|
||||
.coeffs
|
||||
.iter()
|
||||
.copied()
|
||||
.reduce(f64::max)
|
||||
.unwrap();
|
||||
(max_coeff / self.safe_variance_bound).log2().ceil() as u64
|
||||
}
|
||||
|
||||
fn dominate_or_equal(&self, other: &Self) -> bool {
|
||||
// With BR > Fresh
|
||||
let self_var = &self.variance;
|
||||
let other_var = &other.variance;
|
||||
let self_renorm = other.safe_variance_bound / self.safe_variance_bound;
|
||||
let rel_diff =
|
||||
|f: &dyn Fn(&SymbolicVariance) -> f64| self_renorm * f(self_var) - f(other_var);
|
||||
for partition in 0..self.variance.nb_partitions() {
|
||||
let diffs = [
|
||||
rel_diff(&|var| var.coeff_pbs(partition)),
|
||||
rel_diff(&|var| var.coeff_pbs(partition) + var.coeff_input(partition)),
|
||||
rel_diff(&|var| var.coeff_modulus_switching(partition)),
|
||||
];
|
||||
for diff in diffs {
|
||||
if diff < 0.0 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
for src_partition in 0..self.variance.nb_partitions() {
|
||||
for dst_partition in 0..self.variance.nb_partitions() {
|
||||
let diffs = [
|
||||
rel_diff(&|var| var.coeff_keyswitch_to_small(src_partition, dst_partition)),
|
||||
rel_diff(&|var| {
|
||||
var.coeff_partition_keyswitch_to_big(src_partition, dst_partition)
|
||||
}),
|
||||
];
|
||||
for diff in diffs {
|
||||
if diff < 0.0 {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
pub fn remove_dominated(constraints: &[Self]) -> Vec<Self> {
|
||||
let mut constraints = constraints.to_vec();
|
||||
constraints.sort_by_cached_key(Self::dominance_index);
|
||||
constraints.reverse();
|
||||
let mut dominated = vec![false; constraints.len()];
|
||||
for (i, constraint) in constraints.iter().enumerate() {
|
||||
if dominated[i] {
|
||||
continue;
|
||||
}
|
||||
for (j, other_constraint) in constraints.iter().enumerate() {
|
||||
if j <= i {
|
||||
continue;
|
||||
}
|
||||
if constraint.dominate_or_equal(other_constraint) {
|
||||
dominated[j] = true;
|
||||
} else if other_constraint.dominate_or_equal(constraint) {
|
||||
dominated[i] = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
let mut result = vec![];
|
||||
for (i, c) in constraints.iter().enumerate() {
|
||||
if !dominated[i] {
|
||||
result.push(c.clone());
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user