feat(optimizer): symbolic variance for multi-parameter

Author: rudy, 2023-02-24 11:12:39 +01:00
Committed by: Quentin Bourgerie
Parent: 38646b7559
Commit: 104ec93881
5 changed files with 898 additions and 8 deletions


@@ -0,0 +1,447 @@
use crate::dag::operator::{dot_kind, DotKind, LevelledComplexity, Operator, OperatorIndex, Shape};
use crate::dag::rewrite::round::expand_round;
use crate::dag::unparametrized;
use crate::optimization::config::NoiseBoundConfig;
use crate::optimization::dag::multi_parameters::partitionning::partitionning_with_preferred;
use crate::optimization::dag::multi_parameters::partitions::{
InstructionPartition, PartitionIndex,
};
use crate::optimization::dag::multi_parameters::precision_cut::PrecisionCut;
use crate::optimization::dag::multi_parameters::symbolic_variance::SymbolicVariance;
use crate::optimization::dag::solo_key::analyze::first;
use crate::utils::square;
// private shorthand conventions
use DotKind as DK;
type Op = Operator;
pub struct AnalyzedDag {
pub operators: Vec<Op>,
pub nb_partitions: usize,
pub instrs_partition: Vec<InstructionPartition>,
// All operators' output variances, one entry per partition
pub out_variances: Vec<Vec<SymbolicVariance>>,
// The full dag levelled complexity
pub levelled_complexity: LevelledComplexity,
}
pub fn analyze(
dag: &unparametrized::OperationDag,
_noise_config: &NoiseBoundConfig,
p_cut: &PrecisionCut,
default_partition: PartitionIndex,
) -> AnalyzedDag {
assert!(
p_cut.p_cut.len() <= 1,
"Multi-parameter can only be used 0 or 1 precision cut"
);
let dag = expand_round(dag);
let levelled_complexity = LevelledComplexity::ZERO;
// The precision cut is chosen to work well with rounded pbs
// Note: this is temporary
let partitions = partitionning_with_preferred(&dag, p_cut, default_partition);
let instrs_partition = partitions.instrs_partition;
let nb_partitions = partitions.nb_partitions;
let out_variances = out_variances(&dag, nb_partitions, &instrs_partition);
AnalyzedDag {
operators: dag.operators,
nb_partitions,
instrs_partition,
out_variances,
levelled_complexity,
}
}
fn out_variance(
op: &unparametrized::UnparameterizedOperator,
out_shapes: &[Shape],
out_variances: &mut Vec<Vec<SymbolicVariance>>,
nb_partitions: usize,
instr_partition: &InstructionPartition,
) -> Vec<SymbolicVariance> {
// one variance per partition, in case the result is converted
let partition = instr_partition.instruction_partition;
let out_variance_of = |input: &OperatorIndex| {
assert!(input.i < out_variances.len());
assert!(partition < out_variances[input.i].len());
assert!(out_variances[input.i][partition] != SymbolicVariance::ZERO);
assert!(!out_variances[input.i][partition].coeffs.values[0].is_nan());
assert!(out_variances[input.i][partition].partition != usize::MAX);
out_variances[input.i][partition].clone()
};
let max_variance = |acc: SymbolicVariance, input: SymbolicVariance| acc.max(&input);
let variance = match op {
Op::Input { .. } => SymbolicVariance::input(nb_partitions, partition),
Op::Lut { .. } => SymbolicVariance::after_pbs(nb_partitions, partition),
Op::LevelledOp { inputs, manp, .. } => {
let inputs_variance = inputs.iter().map(out_variance_of);
let max_variance = inputs_variance.reduce(max_variance).unwrap();
max_variance.after_levelled_op(*manp)
}
Op::Dot {
inputs, weights, ..
} => {
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor | DK::Broadcast => {
let inputs_variance = (0..weights.values.len()).map(|j| {
let input = if inputs.len() > 1 {
inputs[j]
} else {
inputs[0]
};
out_variance_of(&input)
});
let mut out_variance = SymbolicVariance::ZERO;
for (input_variance, &weight) in inputs_variance.zip(&weights.values) {
assert!(input_variance != SymbolicVariance::ZERO);
out_variance += input_variance * square(weight);
}
out_variance
}
DK::CompatibleTensor { .. } => todo!("TODO"),
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
Op::UnsafeCast { input, .. } => out_variance_of(input),
Op::Round { .. } => {
unreachable!("Round should have been either expanded or integrated to a lut")
}
};
// Inject NaN into unused symbolic variances to detect misuse
let unused = SymbolicVariance::nan(nb_partitions);
let mut result = vec![unused; nb_partitions];
for &dst_partition in &instr_partition.alternative_output_representation {
let src_partition = partition;
// make converted variance available in dst_partition
result[dst_partition] =
variance.after_partition_keyswitch_to_big(src_partition, dst_partition);
}
result[partition] = variance;
result
}
fn out_variances(
dag: &unparametrized::OperationDag,
nb_partitions: usize,
instrs_partition: &[InstructionPartition],
) -> Vec<Vec<SymbolicVariance>> {
let nb_ops = dag.operators.len();
let mut out_variances = Vec::with_capacity(nb_ops);
for (op, instr_partition) in dag.operators.iter().zip(instrs_partition) {
let vf = out_variance(
op,
&dag.out_shapes,
&mut out_variances,
nb_partitions,
instr_partition,
);
out_variances.push(vf);
}
out_variances
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, Shape};
use crate::dag::unparametrized;
use crate::optimization::dag::multi_parameters::partitionning::tests::{
show_partitionning, HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION,
};
use crate::optimization::dag::solo_key::analyze::tests::CONFIG;
fn analyze(dag: &unparametrized::OperationDag) -> AnalyzedDag {
analyze_with_preferred(dag, LOW_PRECISION_PARTITION)
}
fn analyze_with_preferred(
dag: &unparametrized::OperationDag,
default_partition: PartitionIndex,
) -> AnalyzedDag {
let p_cut = PrecisionCut { p_cut: vec![2] };
super::analyze(dag, &CONFIG, &p_cut, default_partition)
}
#[allow(clippy::float_cmp)]
fn assert_input_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) {
for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] {
let sb = dag.out_variances[op_i][partition].clone();
let coeff = if sb == SymbolicVariance::ZERO {
0.0
} else {
sb.coeff_input(symbolic_variance_partition)
};
if symbolic_variance_partition == partition {
assert!(
coeff == expected_coeff,
"INCORRECT INPUT COEFF ON GOOD PARTITION {:?} {:?} {} {}",
dag.out_variances[op_i],
partition,
coeff,
expected_coeff
);
} else {
assert!(
coeff == 0.0,
"INCORRECT INPUT COEFF ON WRONG PARTITION {:?} {:?} {} {}",
dag.out_variances[op_i],
partition,
coeff,
expected_coeff
);
}
}
}
#[allow(clippy::float_cmp)]
fn assert_pbs_on(dag: &AnalyzedDag, partition: usize, op_i: usize, expected_coeff: f64) {
for symbolic_variance_partition in [LOW_PRECISION_PARTITION, HIGH_PRECISION_PARTITION] {
let sb = dag.out_variances[op_i][partition].clone();
eprintln!("{:?}", dag.out_variances[op_i]);
eprintln!("{:?}", dag.out_variances[op_i][partition]);
let coeff = if sb == SymbolicVariance::ZERO {
0.0
} else {
sb.coeff_pbs(symbolic_variance_partition)
};
if symbolic_variance_partition == partition {
assert!(
coeff == expected_coeff,
"INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}",
dag.out_variances[op_i],
partition,
coeff,
expected_coeff
);
} else {
assert!(
coeff == 0.0,
"INCORRECT PBS COEFF ON GOOD PARTITION {:?} {:?} {} {}",
dag.out_variances[op_i],
partition,
coeff,
expected_coeff
);
}
}
}
#[allow(clippy::needless_range_loop)]
#[test]
fn test_lut_sequence() {
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(8, Shape::number());
let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
let lut2 = dag.add_lut(lut1, FunctionTable::UNKWOWN, 1);
let lut3 = dag.add_lut(lut2, FunctionTable::UNKWOWN, 1);
let lut4 = dag.add_lut(lut3, FunctionTable::UNKWOWN, 8);
let lut5 = dag.add_lut(lut4, FunctionTable::UNKWOWN, 8);
let partitions = [
HIGH_PRECISION_PARTITION,
HIGH_PRECISION_PARTITION,
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION,
LOW_PRECISION_PARTITION,
HIGH_PRECISION_PARTITION,
];
let dag = analyze(&dag);
assert!(dag.nb_partitions == 2);
for op_i in input1.i..=lut5.i {
let p = &dag.instrs_partition[op_i];
let is_input = op_i == input1.i;
assert!(p.instruction_partition == partitions[op_i]);
if is_input {
assert_input_on(&dag, p.instruction_partition, op_i, 1.0);
assert_pbs_on(&dag, p.instruction_partition, op_i, 0.0);
} else {
assert_pbs_on(&dag, p.instruction_partition, op_i, 1.0);
assert_input_on(&dag, p.instruction_partition, op_i, 0.0);
}
}
}
#[test]
fn test_levelled_op() {
let mut dag = unparametrized::OperationDag::new();
let out_shape = Shape::number();
let manp = 8.0;
let input1 = dag.add_input(8, Shape::number());
let input2 = dag.add_input(8, Shape::number());
let lut1 = dag.add_lut(input1, FunctionTable::UNKWOWN, 8);
let _levelled = dag.add_levelled_op(
[lut1, input2],
LevelledComplexity::ZERO,
manp,
&out_shape,
"comment",
);
let dag = analyze(&dag);
assert!(dag.nb_partitions == 1);
}
fn nan_symbolic_variance(sb: &SymbolicVariance) -> bool {
sb.coeffs[0].is_nan()
}
#[allow(clippy::float_cmp)]
#[test]
fn test_rounded_v3_first_layer_and_second_layer() {
let acc_precision = 16;
let precision = 8;
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(acc_precision, Shape::number());
let rounded1 = dag.add_expanded_round(input1, precision);
let lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
let rounded2 = dag.add_expanded_round(lut1, precision);
let lut2 = dag.add_lut(rounded2, FunctionTable::UNKWOWN, acc_precision);
let old_dag = dag;
let dag = analyze(&old_dag);
show_partitionning(&old_dag, &dag.instrs_partition);
// First layer is fully LOW_PRECISION_PARTITION
for op_i in input1.i..lut1.i {
let p = LOW_PRECISION_PARTITION;
let sb = &dag.out_variances[op_i][p];
assert!(sb.coeff_input(p) >= 1.0 || sb.coeff_pbs(p) >= 1.0);
assert!(nan_symbolic_variance(
&dag.out_variances[op_i][HIGH_PRECISION_PARTITION]
));
}
// First lut is HIGH_PRECISION_PARTITION and immediately converted to LOW_PRECISION_PARTITION
let p = HIGH_PRECISION_PARTITION;
let sb = &dag.out_variances[lut1.i][p];
assert!(sb.coeff_input(p) == 0.0);
assert!(sb.coeff_pbs(p) == 1.0);
let sb_after_fast_ks = &dag.out_variances[lut1.i][LOW_PRECISION_PARTITION];
assert!(
sb_after_fast_ks.coeff_partition_keyswitch_to_big(
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION
) == 1.0
);
// The next rounded is on LOW_PRECISION_PARTITION but base noise can come from HIGH_PRECISION_PARTITION + FKS
for op_i in (lut1.i + 1)..lut2.i {
assert!(LOW_PRECISION_PARTITION == dag.instrs_partition[op_i].instruction_partition);
let p = LOW_PRECISION_PARTITION;
let sb = &dag.out_variances[op_i][p];
// The base noise comes either from the other partition (and is shifted) or from the current partition (with coefficient 1)
assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0);
assert!(sb.coeff_input(HIGH_PRECISION_PARTITION) == 0.0);
if sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0 {
assert!(
sb.coeff_pbs(HIGH_PRECISION_PARTITION)
== sb.coeff_partition_keyswitch_to_big(
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION
)
);
} else {
assert!(sb.coeff_pbs(LOW_PRECISION_PARTITION) == 1.0);
assert!(
sb.coeff_partition_keyswitch_to_big(
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION
) == 0.0
);
}
}
assert!(nan_symbolic_variance(
&dag.out_variances[lut2.i][LOW_PRECISION_PARTITION]
));
let sb = &dag.out_variances[lut2.i][HIGH_PRECISION_PARTITION];
assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) >= 1.0);
}
#[allow(clippy::float_cmp, clippy::cognitive_complexity)]
#[test]
fn test_rounded_v3_classic_first_layer_second_layer() {
let acc_precision = 16;
let precision = 8;
let mut dag = unparametrized::OperationDag::new();
let free_input1 = dag.add_input(precision, Shape::number());
let input1 = dag.add_lut(free_input1, FunctionTable::UNKWOWN, acc_precision);
let rounded1 = dag.add_expanded_round(input1, precision);
let _lut1 = dag.add_lut(rounded1, FunctionTable::UNKWOWN, acc_precision);
let old_dag = dag;
let dag = analyze(&old_dag);
show_partitionning(&old_dag, &dag.instrs_partition);
// First layer is fully HIGH_PRECISION_PARTITION
assert!(
dag.out_variances[free_input1.i][HIGH_PRECISION_PARTITION]
.coeff_input(HIGH_PRECISION_PARTITION)
== 1.0
);
// First layer tlu
let sb = &dag.out_variances[input1.i][HIGH_PRECISION_PARTITION];
assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0);
assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0);
assert!(
sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION)
== 0.0
);
// The same ciphertext exists in another partition with additional noise due to the fast keyswitch
let sb = &dag.out_variances[input1.i][LOW_PRECISION_PARTITION];
assert!(sb.coeff_input(LOW_PRECISION_PARTITION) == 0.0);
assert!(sb.coeff_pbs(HIGH_PRECISION_PARTITION) == 1.0);
assert!(
sb.coeff_partition_keyswitch_to_big(HIGH_PRECISION_PARTITION, LOW_PRECISION_PARTITION)
== 1.0
);
// Second layer
let mut first_bit_extract_verified = false;
let mut first_bit_erase_verified = false;
for op_i in (input1.i + 1)..rounded1.i {
if let Op::Dot {
weights, inputs, ..
} = &dag.operators[op_i]
{
let bit_extract = weights.values.len() == 1;
let first_bit_extract = bit_extract && !first_bit_extract_verified;
let bit_erase = weights.values == [1, -1];
let first_bit_erase = bit_erase && !first_bit_erase_verified;
let input0_sb = &dag.out_variances[inputs[0].i][LOW_PRECISION_PARTITION];
let input0_coeff_pbs_high = input0_sb.coeff_pbs(HIGH_PRECISION_PARTITION);
let input0_coeff_pbs_low = input0_sb.coeff_pbs(LOW_PRECISION_PARTITION);
let input0_coeff_fks = input0_sb.coeff_partition_keyswitch_to_big(
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION,
);
if bit_extract {
first_bit_extract_verified |= first_bit_extract;
assert!(input0_coeff_pbs_high >= 1.0);
if first_bit_extract {
assert!(input0_coeff_pbs_low == 0.0);
} else {
assert!(input0_coeff_pbs_low >= 1.0);
}
assert!(input0_coeff_fks == 1.0);
} else if bit_erase {
first_bit_erase_verified |= first_bit_erase;
let input1_sb = &dag.out_variances[inputs[1].i][LOW_PRECISION_PARTITION];
let input1_coeff_pbs_high = input1_sb.coeff_pbs(HIGH_PRECISION_PARTITION);
let input1_coeff_pbs_low = input1_sb.coeff_pbs(LOW_PRECISION_PARTITION);
let input1_coeff_fks = input1_sb.coeff_partition_keyswitch_to_big(
HIGH_PRECISION_PARTITION,
LOW_PRECISION_PARTITION,
);
if first_bit_erase {
assert!(input0_coeff_pbs_low == 0.0);
} else {
assert!(input0_coeff_pbs_low >= 1.0);
}
assert!(input0_coeff_pbs_high == 1.0);
assert!(input0_coeff_fks == 1.0);
assert!(input1_coeff_pbs_low == 1.0);
assert!(input1_coeff_pbs_high == 0.0);
assert!(input1_coeff_fks == 0.0);
}
}
}
assert!(first_bit_extract_verified);
assert!(first_bit_erase_verified);
}
}
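
As a quick numeric check of the Dot rule in out_variance above (each input's symbolic variance is accumulated with its squared weight, i.e. out_variance += input_variance * square(weight)), here is a minimal standalone sketch on plain coefficient vectors rather than the real SymbolicVariance type; the function name and the two-slot [fresh, pbs] layout are illustrative assumptions, not part of the commit.

// Sketch: symbolic variances are linear combinations of base variances,
// so a Dot sums per-coefficient contributions scaled by squared weights.
fn dot_variance(inputs: &[Vec<f64>], weights: &[f64]) -> Vec<f64> {
    let mut out = vec![0.0; inputs[0].len()];
    for (input, &w) in inputs.iter().zip(weights) {
        for (acc, coeff) in out.iter_mut().zip(input) {
            *acc += coeff * w * w; // input_variance * square(weight)
        }
    }
    out
}

fn main() {
    // Two fresh inputs (illustrative layout [fresh, pbs]) with weights [2, 3]:
    let (a, b) = (vec![1.0, 0.0], vec![1.0, 0.0]);
    assert_eq!(dot_variance(&[a, b], &[2.0, 3.0]), vec![13.0, 0.0]); // 2² + 3² = 13
}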


@@ -1,5 +1,8 @@
 pub mod analyze;
 pub mod keys_spec;
-pub mod partitionning;
+pub(crate) mod operations_value;
+pub(crate) mod partitionning;
 pub(crate) mod partitions;
+pub(crate) mod precision_cut;
+pub(crate) mod symbolic_variance;
 pub(crate) mod union_find;


@@ -0,0 +1,179 @@
use std::ops::{Deref, DerefMut};
/**
* Indexes the actual operations (input, ks, pbs, fks, modulus switching, etc.).
*/
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd)]
pub struct Indexing {
/* Values order
[
// Partition 1
// related only to the partition
fresh, pbs, modulus,
// Keyswitches to small, from any partition to 1
ks from 1, ks from 2, ...
// Keyswitches to big, from any partition to 1
ks from 1, ks from 2, ...
// Partition 2
// same
]
*/
pub nb_partitions: usize,
}
pub const VALUE_INDEX_FRESH: usize = 0;
pub const VALUE_INDEX_PBS: usize = 1;
pub const VALUE_INDEX_MODULUS: usize = 2;
// number of values always present for a partition
pub const STABLE_NB_VALUES_BY_PARTITION: usize = 3;
impl Indexing {
fn nb_keyswitchs_per_partition(self) -> usize {
self.nb_partitions
}
pub fn nb_coeff_per_partition(self) -> usize {
STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions
}
pub fn nb_coeff(self) -> usize {
self.nb_partitions * (STABLE_NB_VALUES_BY_PARTITION + 2 * self.nb_partitions)
}
pub fn input(self, partition: usize) -> usize {
partition * self.nb_coeff_per_partition() + VALUE_INDEX_FRESH
}
pub fn pbs(self, partition: usize) -> usize {
partition * self.nb_coeff_per_partition() + VALUE_INDEX_PBS
}
pub fn modulus_switching(self, partition: usize) -> usize {
partition * self.nb_coeff_per_partition() + VALUE_INDEX_MODULUS
}
pub fn keyswitch_to_small(self, src_partition: usize, dst_partition: usize) -> usize {
// Skip other partitions
dst_partition * self.nb_coeff_per_partition()
// Skip the non-keyswitch values
+ STABLE_NB_VALUES_BY_PARTITION
// Select the right keyswitch to small
+ src_partition
}
pub fn keyswitch_to_big(self, src_partition: usize, dst_partition: usize) -> usize {
// Skip other partitions
dst_partition * self.nb_coeff_per_partition()
// Skip the non-keyswitch values
+ STABLE_NB_VALUES_BY_PARTITION
// Skip the keyswitches to small
+ self.nb_keyswitchs_per_partition()
// Select the right keyswitch to big
+ src_partition
}
}
/**
* Represents values indexed by the actual operations (input, pbs, modulus switching, ks, fks, etc.).
*/
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct OperationsValue {
pub index: Indexing,
pub values: Vec<f64>,
}
impl OperationsValue {
pub const ZERO: Self = Self {
index: Indexing { nb_partitions: 0 },
values: vec![],
};
pub fn zero(nb_partitions: usize) -> Self {
let index = Indexing { nb_partitions };
Self {
index,
values: vec![0.0; index.nb_coeff()],
}
}
pub fn nan(nb_partitions: usize) -> Self {
let index = Indexing { nb_partitions };
Self {
index,
values: vec![f64::NAN; index.nb_coeff()],
}
}
pub fn input(&mut self, partition: usize) -> &mut f64 {
&mut self.values[self.index.input(partition)]
}
pub fn pbs(&mut self, partition: usize) -> &mut f64 {
&mut self.values[self.index.pbs(partition)]
}
pub fn ks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 {
&mut self.values[self.index.keyswitch_to_small(src_partition, dst_partition)]
}
pub fn fks(&mut self, src_partition: usize, dst_partition: usize) -> &mut f64 {
&mut self.values[self.index.keyswitch_to_big(src_partition, dst_partition)]
}
pub fn modulus_switching(&mut self, partition: usize) -> &mut f64 {
&mut self.values[self.index.modulus_switching(partition)]
}
pub fn nb_partitions(&self) -> usize {
self.index.nb_partitions
}
}
impl Deref for OperationsValue {
type Target = [f64];
fn deref(&self) -> &Self::Target {
&self.values
}
}
impl DerefMut for OperationsValue {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.values
}
}
impl std::ops::AddAssign for OperationsValue {
fn add_assign(&mut self, rhs: Self) {
if self.values.is_empty() {
*self = rhs;
} else {
for i in 0..self.values.len() {
self.values[i] += rhs.values[i];
}
}
}
}
impl std::ops::AddAssign<&Self> for OperationsValue {
fn add_assign(&mut self, rhs: &Self) {
if self.values.is_empty() {
*self = rhs.clone();
} else {
for i in 0..self.values.len() {
self.values[i] += rhs.values[i];
}
}
}
}
impl std::ops::Mul<f64> for OperationsValue {
type Output = Self;
fn mul(self, sq_weight: f64) -> Self {
Self {
values: self.values.iter().map(|v| v * sq_weight).collect(),
..self
}
}
}
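
For intuition about the layout encoded by Indexing, here is a standalone sketch (not part of the commit) that restates its formulas for nb_partitions = 2; the closures simply mirror input, pbs, modulus_switching, keyswitch_to_small and keyswitch_to_big above.

// Sketch of the coefficient layout for 2 partitions: each partition owns
// 3 stable slots (fresh, pbs, modulus) plus one ks and one fks slot per
// source partition, i.e. 3 + 2 * nb_partitions = 7 coefficients each.
fn main() {
    let nb_partitions = 2;
    let stable = 3; // STABLE_NB_VALUES_BY_PARTITION
    let ncpp = stable + 2 * nb_partitions; // nb_coeff_per_partition() == 7
    let input = |p: usize| p * ncpp; // VALUE_INDEX_FRESH == 0
    let pbs = |p: usize| p * ncpp + 1; // VALUE_INDEX_PBS == 1
    let modulus = |p: usize| p * ncpp + 2; // VALUE_INDEX_MODULUS == 2
    let ks = |src: usize, dst: usize| dst * ncpp + stable + src;
    let fks = |src: usize, dst: usize| dst * ncpp + stable + nb_partitions + src;
    assert_eq!((input(0), pbs(0), modulus(0)), (0, 1, 2));
    assert_eq!((ks(0, 0), ks(1, 0)), (3, 4)); // keyswitches to small, by source
    assert_eq!((fks(0, 0), fks(1, 0)), (5, 6)); // keyswitches to big, by source
    assert_eq!(input(1), 7); // partition 1 starts after partition 0's 7 slots
    assert_eq!(fks(1, 1), 13); // last of nb_coeff() == 2 * 7 == 14 coefficients
}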


@@ -0,0 +1,261 @@
use std::fmt;
use crate::optimization::dag::multi_parameters::operations_value::{
OperationsValue, VALUE_INDEX_FRESH, VALUE_INDEX_PBS,
};
/**
* A variance represented as a linear combination of base variances.
* Only the linear coefficients are known.
* The base variances are unknown.
*
* Possible base variances:
* - fresh,
* - lut output,
* - keyswitch,
* - partition keyswitch,
* - modulus switching
*
* We only know that fresh <= lut output within the same partition.
* Each linear coefficient is a variance factor.
* They are homogeneous to a squared weight (or a sum of squared weights, i.e. a squared norm2).
*/
#[derive(Clone, Debug, PartialEq, PartialOrd)]
pub struct SymbolicVariance {
pub partition: usize,
pub coeffs: OperationsValue,
}
impl SymbolicVariance {
// To be used as an initial accumulator
pub const ZERO: Self = Self {
partition: 0,
coeffs: OperationsValue::ZERO,
};
pub fn nb_partitions(&self) -> usize {
self.coeffs.nb_partitions()
}
pub fn nan(nb_partitions: usize) -> Self {
Self {
partition: usize::MAX,
coeffs: OperationsValue::nan(nb_partitions),
}
}
pub fn input(nb_partitions: usize, partition: usize) -> Self {
let mut r = Self {
partition,
coeffs: OperationsValue::zero(nb_partitions),
};
// Rust limitation: the offset computation cannot be inlined here
*r.coeffs.input(partition) = 1.0;
r
}
pub fn coeff_input(&self, partition: usize) -> f64 {
self.coeffs[self.coeffs.index.input(partition)]
}
pub fn after_pbs(nb_partitions: usize, partition: usize) -> Self {
let mut r = Self {
partition,
coeffs: OperationsValue::zero(nb_partitions),
};
*r.coeffs.pbs(partition) = 1.0;
r
}
pub fn coeff_pbs(&self, partition: usize) -> f64 {
self.coeffs[self.coeffs.index.pbs(partition)]
}
pub fn coeff_modulus_switching(&self, partition: usize) -> f64 {
self.coeffs[self.coeffs.index.modulus_switching(partition)]
}
pub fn after_modulus_switching(&self, partition: usize) -> Self {
let mut new = self.clone();
let index = self.coeffs.index.modulus_switching(partition);
assert!(new.coeffs[index] == 0.0);
new.coeffs[index] = 1.0;
new
}
pub fn coeff_keyswitch_to_small(&self, src_partition: usize, dst_partition: usize) -> f64 {
self.coeffs[self
.coeffs
.index
.keyswitch_to_small(src_partition, dst_partition)]
}
pub fn after_partition_keyswitch_to_small(
&self,
src_partition: usize,
dst_partition: usize,
) -> Self {
let index = self
.coeffs
.index
.keyswitch_to_small(src_partition, dst_partition);
self.after_partition_keyswitch(src_partition, dst_partition, index)
}
pub fn coeff_partition_keyswitch_to_big(
&self,
src_partition: usize,
dst_partition: usize,
) -> f64 {
self.coeffs[self
.coeffs
.index
.keyswitch_to_big(src_partition, dst_partition)]
}
pub fn after_partition_keyswitch_to_big(
&self,
src_partition: usize,
dst_partition: usize,
) -> Self {
let index = self
.coeffs
.index
.keyswitch_to_big(src_partition, dst_partition);
self.after_partition_keyswitch(src_partition, dst_partition, index)
}
pub fn after_partition_keyswitch(
&self,
src_partition: usize,
dst_partition: usize,
index: usize,
) -> Self {
assert!(src_partition < self.nb_partitions());
assert!(dst_partition < self.nb_partitions());
assert!(src_partition == self.partition);
let mut new = self.clone();
new.partition = dst_partition;
new.coeffs[index] = 1.0;
new
}
#[allow(clippy::float_cmp)]
pub fn after_levelled_op(&self, manp: f64) -> Self {
let new_coeff = manp * manp;
// detect the previous base manp level
// it is the maximum of the fresh and pbs base noise coefficients
let mut current_max: f64 = 0.0;
for partition in 0..self.nb_partitions() {
let partition_offset = partition * self.coeffs.index.nb_coeff_per_partition();
let fresh_coeff = self.coeffs[partition_offset + VALUE_INDEX_FRESH];
let pbs_noise_coeff = self.coeffs[partition_offset + VALUE_INDEX_PBS];
current_max = current_max.max(fresh_coeff).max(pbs_noise_coeff);
}
assert!(1.0 <= current_max);
assert!(
current_max <= new_coeff,
"Non monotonious levelled op: {current_max} <= {new_coeff}"
);
// replace all current_max by new_coeff
// multiply everything else by new_coeff / current_max
let mut new = self.clone();
for cell in &mut new.coeffs.values {
if *cell == current_max {
*cell = new_coeff;
} else {
*cell *= new_coeff / current_max;
}
}
new
}
pub fn max(&self, other: &Self) -> Self {
let mut coeffs = self.coeffs.clone();
for (i, coeff) in coeffs.iter_mut().enumerate() {
*coeff = coeff.max(other.coeffs[i]);
}
Self { coeffs, ..*self }
}
}
impl fmt::Display for SymbolicVariance {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
if self == &Self::ZERO {
write!(f, "ZERO x σ²")?;
}
if self.coeffs[0].is_nan() {
write!(f, "NAN x σ²")?;
}
let mut add_plus = "";
for src_partition in 0..self.nb_partitions() {
let coeff = self.coeff_input(src_partition);
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}σ²In[{src_partition}]")?;
add_plus = " + ";
}
let coeff = self.coeff_pbs(src_partition);
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}σ²Br[{src_partition}]")?;
add_plus = " + ";
}
for dst_partition in 0..self.nb_partitions() {
let coeff = self.coeff_partition_keyswitch_to_big(src_partition, dst_partition);
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}σ²FK[{src_partition}→{dst_partition}]")?;
add_plus = " + ";
}
}
}
for src_partition in 0..self.nb_partitions() {
for dst_partition in 0..self.nb_partitions() {
let coeff = self.coeff_keyswitch_to_small(src_partition, dst_partition);
if coeff != 0.0 {
if src_partition == dst_partition {
write!(f, "{add_plus}{coeff}σ²K[{src_partition}]")?;
} else {
write!(f, "{add_plus}{coeff}σ²K[{src_partition}→{dst_partition}]")?;
}
add_plus = " + ";
}
}
}
for partition in 0..self.nb_partitions() {
let coeff = self.coeff_modulus_switching(partition);
if coeff != 0.0 {
write!(f, "{add_plus}{coeff}σ²M[{partition}]")?;
add_plus = " + ";
}
}
Ok(())
}
}
impl std::ops::AddAssign for SymbolicVariance {
fn add_assign(&mut self, rhs: Self) {
if self.coeffs.is_empty() {
*self = rhs;
} else {
for i in 0..self.coeffs.len() {
self.coeffs[i] += rhs.coeffs[i];
}
}
}
}
impl std::ops::Mul<f64> for SymbolicVariance {
type Output = Self;
fn mul(self, sq_weight: f64) -> Self {
Self {
coeffs: self.coeffs * sq_weight,
..self
}
}
}
impl std::ops::Mul<i64> for SymbolicVariance {
type Output = Self;
fn mul(self, sq_weight: i64) -> Self {
self * sq_weight as f64
}
}
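
To make the after_levelled_op rescaling above concrete, here is a hedged sketch on a bare coefficient vector (single partition, illustrative [fresh, pbs, modulus] layout, not the real types): the previous base manp level current_max is replaced by manp², and every other coefficient is scaled by the same ratio.

// Sketch of the `after_levelled_op` rescaling on plain coefficients; illustrative only.
fn after_levelled_op(coeffs: &[f64], manp: f64) -> Vec<f64> {
    let new_coeff = manp * manp;
    // Previous base manp level: max of the fresh and pbs base coefficients.
    let current_max = coeffs[0].max(coeffs[1]);
    assert!(1.0 <= current_max && current_max <= new_coeff);
    coeffs
        .iter()
        .map(|&c| if c == current_max { new_coeff } else { c * new_coeff / current_max })
        .collect()
}

fn main() {
    // A fresh input (coefficient 1.0 on the fresh slot) through a manp = 8 levelled op:
    assert_eq!(after_levelled_op(&[1.0, 0.0, 0.0], 8.0), vec![64.0, 0.0, 0.0]);
}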


@@ -14,7 +14,7 @@ use std::collections::{HashMap, HashSet};
use {DotKind as DK, VarianceOrigin as VO};
type Op = unparametrized::UnparameterizedOperator;
-fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property {
+pub fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property {
&properties[inputs[0].i]
}
@@ -83,7 +83,7 @@ pub fn has_round(dag: &unparametrized::OperationDag) -> bool {
false
}
-fn assert_no_round(dag: &unparametrized::OperationDag) {
+pub fn assert_no_round(dag: &unparametrized::OperationDag) {
assert!(!has_round(dag));
}
@@ -197,7 +197,7 @@ fn out_variances(dag: &unparametrized::OperationDag) -> Vec<SymbolicVariance> {
out_variances
}
-fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool> {
+pub fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool> {
let nb_ops = dag.operators.len();
let mut extra_values_to_check = vec![true; nb_ops];
for op in &dag.operators {
@@ -283,7 +283,7 @@ fn op_levelled_complexity(
}
}
-fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity {
+pub fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity {
let mut levelled_complexity = LevelledComplexity::ZERO;
for op in &dag.operators {
levelled_complexity += op_levelled_complexity(op, &dag.out_shapes);
@@ -301,7 +301,7 @@ pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 {
count
}
-fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f64 {
+pub fn safe_noise_bound(precision: Precision, noise_config: &NoiseBoundConfig) -> f64 {
error::safe_variance_bound_2padbits(
precision as u64,
noise_config.ciphertext_modulus_log,
@@ -620,7 +620,7 @@ impl OperationDag {
}
#[cfg(test)]
-mod tests {
+pub mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
@@ -641,7 +641,7 @@ mod tests {
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
-const CONFIG: NoiseBoundConfig = NoiseBoundConfig {
+pub const CONFIG: NoiseBoundConfig = NoiseBoundConfig {
security_level: 128,
ciphertext_modulus_log: 64,
maximum_acceptable_error_probability: _4_SIGMA,