feat: dag + solo key optimization

This commit is contained in:
rudy
2022-05-12 10:44:53 +02:00
committed by rudy-6-4
parent 33253a7582
commit 8f2c21ddbe
21 changed files with 1561 additions and 101 deletions

View File

@@ -1,7 +1,7 @@
use concrete_optimizer::graph::operator::{
self, FunctionTable, LevelledComplexity, OperatorIndex, Shape,
use concrete_optimizer::dag::operator::{
self, FunctionTable, LevelledComplexity, OperatorIndex, Precision, Shape,
};
use concrete_optimizer::graph::unparametrized;
use concrete_optimizer::dag::unparametrized;
fn no_solution() -> ffi::Solution {
ffi::Solution {
@@ -63,7 +63,7 @@ fn empty() -> Box<OperationDag> {
}
impl OperationDag {
fn add_input(&mut self, out_precision: u8, out_shape: &[u64]) -> ffi::OperatorIndex {
fn add_input(&mut self, out_precision: Precision, out_shape: &[u64]) -> ffi::OperatorIndex {
let out_shape = Shape {
dimensions_size: out_shape.to_owned(),
};
@@ -87,7 +87,7 @@ impl OperationDag {
) -> ffi::OperatorIndex {
let inputs: Vec<OperatorIndex> = inputs.iter().copied().map(Into::into).collect();
self.0.add_dot(&inputs, &weights.0).into()
self.0.add_dot(inputs, weights.0).into()
}
fn add_levelled_op(
@@ -111,7 +111,7 @@ impl OperationDag {
};
self.0
.add_levelled_op(&inputs, complexity, manp, out_shape, comment)
.add_levelled_op(inputs, complexity, manp, out_shape, comment)
.into()
}
}

View File

@@ -0,0 +1,84 @@
use super::{ClearTensor, Shape};
#[derive(PartialEq, Eq, Debug)]
pub enum DotKind {
// inputs = [x,y,z], weights = [a,b,c], = x*a + y*b + z*c
Simple,
// inputs = [[x, y, z]], weights = [a,b,c], = same
Tensor,
// inputs = [[x], [y], [z]], weights = [[a],[b],[c]], = same
CompatibleTensor,
// inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
Broadcast,
Unsupported,
}
pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind {
let inputs_shape = Shape::duplicated(nb_inputs, input_shape);
if input_shape.is_number() && inputs_shape == weights.shape {
DotKind::Simple
} else if nb_inputs == 1 && *input_shape == weights.shape {
DotKind::Tensor
} else if inputs_shape == weights.shape {
DotKind::CompatibleTensor
} else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
DotKind::Broadcast
} else {
DotKind::Unsupported
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dag::operator::{Shape, Weights};
#[test]
fn test_simple() {
assert_eq!(
dot_kind(2, &Shape::number(), &Weights::vector([1, 2])),
DotKind::Simple
);
}
#[test]
fn test_tensor() {
assert_eq!(
dot_kind(1, &Shape::vector(2), &Weights::vector([1, 2])),
DotKind::Tensor
);
}
#[test]
fn test_broadcast() {
let s2x2 = Shape {
dimensions_size: vec![2, 2],
};
assert_eq!(
dot_kind(1, &s2x2, &Weights::vector([1, 2])),
DotKind::Broadcast
);
}
#[test]
fn test_compatible() {
let weights = ClearTensor {
shape: Shape {
dimensions_size: vec![2, 1],
},
values: vec![1, 2],
};
assert_eq!(
dot_kind(2, &Shape::vector(1), &weights),
DotKind::CompatibleTensor
);
}
#[test]
fn test_unsupported() {
assert_eq!(
dot_kind(3, &Shape::number(), &Weights::vector([1, 2])),
DotKind::Unsupported
);
}
}

View File

@@ -1,6 +1,8 @@
#![allow(clippy::module_inception)]
pub mod dot_kind;
pub mod operator;
pub mod tensor;
pub use self::dot_kind::*;
pub use self::operator::*;
pub use self::tensor::*;

View File

@@ -1,4 +1,4 @@
use crate::graph::operator::tensor::{ClearTensor, Shape};
use crate::dag::operator::tensor::{ClearTensor, Shape};
use derive_more::{Add, AddAssign};
pub type Weights = ClearTensor;
@@ -50,11 +50,13 @@ impl std::ops::Mul<u64> for LevelledComplexity {
}
}
}
pub type Precision = u8;
pub const MIN_PRECISION: Precision = 1;
#[derive(Clone, PartialEq, Debug)]
pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
Input {
out_precision: u8,
out_precision: Precision,
out_shape: Shape,
extra_data: InputExtraData,
},

View File

@@ -1,10 +1,17 @@
use delegate::delegate;
use crate::utils::square_ref;
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Shape {
pub dimensions_size: Vec<u64>,
}
impl Shape {
pub fn first_dim_size(&self) -> u64 {
self.dimensions_size[0]
}
pub fn rank(&self) -> usize {
self.dimensions_size.len()
}
@@ -43,6 +50,12 @@ impl Shape {
dimensions_size.extend_from_slice(&other.dimensions_size);
Self { dimensions_size }
}
pub fn erase_first_dim(&self) -> Self {
Self {
dimensions_size: self.dimensions_size[1..].to_vec(),
}
}
}
#[derive(Clone, PartialEq, Eq, Debug)]
@@ -51,11 +64,6 @@ pub struct ClearTensor {
pub values: Vec<u64>,
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn square(v: &u64) -> u64 {
v * v
}
impl ClearTensor {
pub fn number(value: u64) -> Self {
Self {
@@ -64,10 +72,11 @@ impl ClearTensor {
}
}
pub fn vector(values: &[u64]) -> Self {
pub fn vector(values: impl Into<Vec<u64>>) -> Self {
let values = values.into();
Self {
shape: Shape::vector(values.len() as u64),
values: values.to_vec(),
values,
}
}
@@ -81,6 +90,27 @@ impl ClearTensor {
}
pub fn square_norm2(&self) -> u64 {
self.values.iter().map(square).sum()
self.values.iter().map(square_ref).sum()
}
}
// helps using shared shapes
impl From<&Self> for Shape {
fn from(item: &Self) -> Self {
item.clone()
}
}
// helps using shared weights
impl From<&Self> for ClearTensor {
fn from(item: &Self) -> Self {
item.clone()
}
}
// helps using array as weights
impl<const N: usize> From<[u64; N]> for ClearTensor {
fn from(item: [u64; N]) -> Self {
Self::vector(item)
}
}

View File

@@ -1,5 +1,5 @@
use crate::dag::parameter_indexed::OperatorParameterIndexed;
use crate::global_parameters::{ParameterRanges, ParameterToOperation};
use crate::graph::parameter_indexed::OperatorParameterIndexed;
#[allow(dead_code)]
pub struct OperationDag {

View File

@@ -1,5 +1,5 @@
use crate::graph::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Shape, Weights,
use crate::dag::operator::{
FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
};
pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
@@ -21,7 +21,12 @@ impl OperationDag {
OperatorIndex { i }
}
pub fn add_input(&mut self, out_precision: u8, out_shape: Shape) -> OperatorIndex {
pub fn add_input(
&mut self,
out_precision: Precision,
out_shape: impl Into<Shape>,
) -> OperatorIndex {
let out_shape = out_shape.into();
self.add_operator(Operator::Input {
out_precision,
out_shape,
@@ -37,23 +42,30 @@ impl OperationDag {
})
}
pub fn add_dot(&mut self, inputs: &[OperatorIndex], weights: &Weights) -> OperatorIndex {
pub fn add_dot(
&mut self,
inputs: impl Into<Vec<OperatorIndex>>,
weights: impl Into<Weights>,
) -> OperatorIndex {
let inputs = inputs.into();
let weights = weights.into();
self.add_operator(Operator::Dot {
inputs: inputs.to_vec(),
weights: weights.clone(),
inputs,
weights,
extra_data: (),
})
}
pub fn add_levelled_op(
&mut self,
inputs: &[OperatorIndex],
inputs: impl Into<Vec<OperatorIndex>>,
complexity: LevelledComplexity,
manp: f64,
out_shape: Shape,
out_shape: impl Into<Shape>,
comment: &str,
) -> OperatorIndex {
let inputs = inputs.to_vec();
let inputs = inputs.into();
let out_shape = out_shape.into();
let comment = comment.to_string();
let op = Operator::LevelledOp {
inputs,
@@ -75,7 +87,7 @@ impl OperationDag {
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::operator::Shape;
use crate::dag::operator::Shape;
#[test]
fn graph_creation() {
@@ -86,14 +98,14 @@ mod tests {
let input2 = graph.add_input(2, Shape::number());
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let concat =
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
let dot = graph.add_dot([concat], [1, 2]);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);

View File

@@ -1,5 +1,5 @@
use crate::dag::parameter_indexed::OperatorParameterIndexed;
use crate::global_parameters::{ParameterToOperation, ParameterValues};
use crate::graph::parameter_indexed::OperatorParameterIndexed;
#[allow(dead_code)]
pub struct OperationDag {

View File

@@ -1,11 +1,11 @@
use std::collections::HashSet;
use crate::graph::operator::{Operator, OperatorIndex};
use crate::graph::parameter_indexed::{
use crate::dag::operator::{Operator, OperatorIndex};
use crate::dag::parameter_indexed::{
InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
};
use crate::graph::unparametrized::UnparameterizedOperator;
use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
use crate::dag::unparametrized::UnparameterizedOperator;
use crate::dag::{parameter_indexed, range_parametrized, unparametrized};
use crate::parameters::{
BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
KsDecompositionParameterRanges, KsDecompositionParameters,
@@ -244,7 +244,7 @@ pub fn domains_to_ranges(
#[cfg(test)]
mod tests {
use super::*;
use crate::graph::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
#[test]
fn test_maximal_unify() {
@@ -255,14 +255,13 @@ mod tests {
let input2 = graph.add_input(2, Shape::number());
let cpx_add = LevelledComplexity::ADDITION;
let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
let concat =
graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
let dot = graph.add_dot([concat], [1, 2]);
let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
@@ -336,7 +335,7 @@ mod tests {
let cpx_add = LevelledComplexity::ADDITION;
let concat =
graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);

View File

@@ -1,27 +1,29 @@
#![warn(clippy::nursery)]
#![warn(clippy::pedantic)]
#![warn(clippy::style)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::cast_precision_loss)] // u64 to f64
#![allow(clippy::cast_possible_truncation)] // u64 to usize
#![allow(clippy::inline_always)] // needed by delegate
#![allow(clippy::match_wildcard_for_single_variants)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::missing_const_for_fn)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::similar_names)]
#![allow(clippy::suboptimal_flops)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::match_wildcard_for_single_variants)]
#![allow(clippy::cast_lossless)]
#![warn(unused_results)]
pub mod computing_cost;
pub mod dag;
pub mod global_parameters;
pub mod graph;
pub mod noise_estimator;
pub mod optimization;
pub mod parameters;
pub mod pareto;
pub mod security;
pub mod utils;
pub mod weight;

View File

@@ -1,3 +1,8 @@
use concrete_commons::dispersion::DispersionParameter;
use super::utils;
use crate::utils::square;
pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 {
// https://en.wikipedia.org/wiki/Error_function#Applications
let p_in = 1.0 - p_error;
@@ -9,6 +14,31 @@ pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
}
const LEFT_PADDING_BITS: u64 = 1;
const RIGHT_PADDING_BITS: u64 = 1;
pub fn fatal_noise_limit(precision: u64, ciphertext_modulus_log: u64) -> f64 {
let no_noise_bits = LEFT_PADDING_BITS + precision + RIGHT_PADDING_BITS;
let noise_bits = ciphertext_modulus_log - no_noise_bits;
2_f64.powi(noise_bits as i32)
}
pub fn variance_max(
precision: u64,
ciphertext_modulus_log: u64,
maximum_acceptable_error_probability: f64,
) -> f64 {
let fatal_noise_limit = fatal_noise_limit(precision, ciphertext_modulus_log);
// We want safe_sigma such that:
// P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
// <=> P(x not in [-+fatal_noise_limit/safe_sigma] | σ = 1) = p_error
// <=> P(x not in [-+kappa] | σ = 1) = p_error, with safe_sigma = fatal_noise_limit / kappa
let kappa = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
let safe_sigma = fatal_noise_limit / kappa;
let modular_variance = square(safe_sigma);
utils::from_modular_variance(modular_variance, ciphertext_modulus_log).get_variance()
}
#[cfg(test)]
mod tests {
use super::*;

View File

@@ -1,9 +1,8 @@
use crate::computing_cost::operators::atomic_pattern as complexity_atomic_pattern;
use crate::computing_cost::operators::keyswitch_lwe::KeySwitchLWEComplexity;
use crate::computing_cost::operators::pbs::PbsComplexity;
use crate::noise_estimator::error::{
error_probability_of_sigma_scale, sigma_scale_of_error_probability,
};
use crate::noise_estimator::error;
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::parameters::{
AtomicPatternParameters, BrDecompositionParameters, GlweParameters, KeyswitchParameters,
@@ -11,14 +10,11 @@ use crate::parameters::{
};
use crate::pareto;
use crate::security;
use crate::utils::square;
use complexity_atomic_pattern::{AtomicPatternComplexity, DEFAULT as DEFAULT_COMPLEXITY};
use concrete_commons::dispersion::{DispersionParameter, Variance};
use concrete_commons::numeric::UnsignedInteger;
fn square(v: f64) -> f64 {
v * v
}
/* enable to debug */
const CHECKS: bool = false;
/* disable to debug */
@@ -27,7 +23,7 @@ const CUTS: bool = true; // 80ms
const PARETO_CUTS: bool = true; // 75ms
const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true; // 70ms
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Solution {
pub input_lwe_dimension: u64, //n_big
pub internal_ks_output_lwe_dimension: u64, //n_small
@@ -38,27 +34,28 @@ pub struct Solution {
pub br_decomposition_level_count: u64, //l(BR)
pub br_decomposition_base_log: u64, //b(BR)
pub complexity: f64,
pub lut_complexity: f64,
pub noise_max: f64,
pub p_error: f64, // error probability
}
// Constants during optimization of decompositions
struct OptimizationDecompositionsConsts {
kappa: f64,
sum_size: u64,
security_level: u64,
noise_factor: f64,
ciphertext_modulus_log: u64,
keyswitch_decompositions: Vec<KsDecompositionParameters>,
blind_rotate_decompositions: Vec<BrDecompositionParameters>,
variance_max: f64,
// Constants during optimisation of decompositions
pub(crate) struct OptimizationDecompositionsConsts {
pub kappa: f64,
pub sum_size: u64,
pub security_level: u64,
pub noise_factor: f64,
pub ciphertext_modulus_log: u64,
pub keyswitch_decompositions: Vec<KsDecompositionParameters>,
pub blind_rotate_decompositions: Vec<BrDecompositionParameters>,
pub safe_variance: f64,
}
#[derive(Clone, Copy)]
struct ComplexityNoise {
index: usize,
complexity: f64,
noise: f64,
pub(crate) struct ComplexityNoise {
pub index: usize,
pub complexity: f64,
pub noise: f64,
}
impl ComplexityNoise {
@@ -69,7 +66,7 @@ impl ComplexityNoise {
};
}
fn blind_rotate_quantities<W: UnsignedInteger>(
pub(crate) fn pareto_blind_rotate<W: UnsignedInteger>(
consts: &OptimizationDecompositionsConsts,
internal_dim: u64,
glwe_params: GlweParameters,
@@ -107,11 +104,11 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
variance_bsk,
);
let noise_in = base_noise.get_variance() * square(consts.noise_factor);
if cut_noise < noise_in && CUTS {
let noise_out = base_noise.get_variance();
if cut_noise < noise_out && CUTS {
continue; // noise is decreasing
}
if decreasing_variance < noise_in && PARETO_CUTS {
if decreasing_variance < noise_out && PARETO_CUTS {
// the current case is dominated
continue;
}
@@ -124,14 +121,14 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
quantities[size] = ComplexityNoise {
index: i_br,
complexity: complexity_pbs,
noise: noise_in,
noise: noise_out,
};
assert!(
0.0 <= delta_complexity,
"blind_rotate_decompositions should be by increasing complexity"
);
increasing_complexity = complexity_pbs;
decreasing_variance = noise_in;
decreasing_variance = noise_out;
size += 1;
}
assert!(!PARETO_CUTS || size < 64);
@@ -139,7 +136,7 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
quantities
}
fn keyswitch_quantities<W: UnsignedInteger>(
pub(crate) fn pareto_keyswitch<W: UnsignedInteger>(
consts: &OptimizationDecompositionsConsts,
in_dim: u64,
internal_dim: u64,
@@ -226,8 +223,8 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
glwe_poly_size,
)
.get_variance();
let variance_max = consts.variance_max;
if CUTS && noise_modulus_switching > variance_max {
let safe_variance = consts.safe_variance;
if CUTS && noise_modulus_switching > safe_variance {
return;
}
@@ -236,9 +233,9 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
let complexity_multisum = (consts.sum_size * input_lwe_dimension) as f64;
let mut cut_complexity = best_complexity - complexity_multisum;
let mut cut_noise = variance_max - noise_modulus_switching;
let mut cut_noise = safe_variance - noise_modulus_switching;
let br_quantities =
blind_rotate_quantities::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
if br_quantities.is_empty() {
return;
}
@@ -246,7 +243,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
cut_noise -= br_quantities[br_quantities.len() - 1].noise;
cut_complexity -= br_quantities[0].complexity;
}
let ks_quantities = keyswitch_quantities::<W>(
let ks_quantities = pareto_keyswitch::<W>(
consts,
input_lwe_dimension,
internal_dim,
@@ -259,11 +256,12 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
let i_max_ks = ks_quantities.len() - 1;
let mut i_current_max_ks = i_max_ks;
let square_noise_factor = square(consts.noise_factor);
for br_quantity in br_quantities {
// increasing complexity, decreasing variance
let noise_in = br_quantity.noise;
let noise_in = br_quantity.noise * square_noise_factor;
let noise_max = noise_in + noise_modulus_switching;
if noise_max > variance_max && CUTS {
if noise_max > safe_variance && CUTS {
continue;
}
let complexity_pbs = br_quantity.complexity;
@@ -298,7 +296,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
);
}
if noise_max > variance_max {
if noise_max > safe_variance {
if CROSS_PARETO_CUTS {
// the pareto of 2 added pareto is scanned linearly
// but with all cuts, pre-computing => no gain
@@ -316,9 +314,9 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
// feasible and at least as good complexity
if complexity < best_complexity || noise_max < best_variance {
let sigma = Variance(variance_max).get_standard_dev() * consts.kappa;
let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
let sigma_scale = sigma / Variance(noise_max).get_standard_dev();
let p_error = error_probability_of_sigma_scale(sigma_scale);
let p_error = error::error_probability_of_sigma_scale(sigma_scale);
let i_br = br_quantity.index;
let i_ks = ks_quantity.index;
@@ -344,6 +342,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
br_decomposition_base_log: br_b,
noise_max,
complexity,
lut_complexity: complexity_keyswitch + complexity_pbs,
p_error,
});
}
@@ -351,7 +350,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
} // br ks
}
// This function provides reference values with unoptimized code, until we have non regeression tests
// This function provides reference values with unoptimised code, until we have non regeression tests
#[allow(clippy::float_cmp)]
#[allow(clippy::too_many_lines)]
fn assert_checks<W: UnsignedInteger>(
@@ -367,7 +366,7 @@ fn assert_checks<W: UnsignedInteger>(
) {
let i_ks = ks_c_n.index;
let i_br = br_c_n.index;
let noise_in = br_c_n.noise;
let noise_out = br_c_n.noise;
let noise_keyswitch = ks_c_n.noise;
let complexity_keyswitch = ks_c_n.complexity;
let complexity_pbs = br_c_n.complexity;
@@ -398,7 +397,7 @@ fn assert_checks<W: UnsignedInteger>(
.complexity(pbs_parameters, ciphertext_modulus_log);
assert_eq!(complexity_pbs, complexity_pbs_);
assert_eq!(noise_in, noise_in_);
assert_eq!(noise_out * square(consts.noise_factor), noise_in_);
let variance_ksk =
noise_atomic_pattern::variance_ksk(internal_dim, ciphertext_modulus_log, security_level);
@@ -429,7 +428,7 @@ fn assert_checks<W: UnsignedInteger>(
};
let check_max_noise = noise_atomic_pattern::maximal_noise::<Variance, W>(
Variance(noise_in),
Variance(noise_in_),
atomic_pattern_parameters,
ciphertext_modulus_log,
security_level,
@@ -452,8 +451,6 @@ fn assert_checks<W: UnsignedInteger>(
assert!(diff_complexity < 0.0001);
}
const BITS_CARRY: u64 = 1;
const BITS_PADDING_WITHOUT_NOISE: u64 = 1;
const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
#[allow(clippy::too_many_lines)]
@@ -479,16 +476,12 @@ pub fn optimize_one<W: UnsignedInteger>(
// the blind rotate decomposition
let ciphertext_modulus_log = W::BITS as u64;
let no_noise_bits = BITS_CARRY + precision + BITS_PADDING_WITHOUT_NOISE;
let noise_bits = ciphertext_modulus_log - no_noise_bits;
let fatal_noise_limit = (1_u64 << noise_bits) as f64;
// Now we search for P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
// P(x not in [-+kappa] | σ = 1) = p_error
let kappa: f64 = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
let safe_sigma = fatal_noise_limit / kappa;
let variance_max = Variance::from_modular_variance::<W>(square(safe_sigma));
let safe_variance = error::variance_max(
precision,
ciphertext_modulus_log,
maximum_acceptable_error_probability,
);
let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);
let consts = OptimizationDecompositionsConsts {
kappa,
@@ -502,7 +495,7 @@ pub fn optimize_one<W: UnsignedInteger>(
blind_rotate_decompositions: pareto::BR_BL
.map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
.to_vec(),
variance_max: variance_max.get_variance(),
safe_variance,
};
let mut state = OptimizationState {
@@ -524,7 +517,7 @@ pub fn optimize_one<W: UnsignedInteger>(
glwe_poly_size,
)
.get_variance()
> consts.variance_max
> consts.safe_variance
};
let skip = |glwe_dim, glwe_poly_size| match restart_at {

View File

@@ -0,0 +1 @@
pub mod solo_key;

View File

@@ -0,0 +1,587 @@
use super::symbolic_variance::{SymbolicVariance, VarianceOrigin};
use crate::dag::operator::{
dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape,
};
use crate::dag::unparametrized;
use crate::utils::square;
// private short convention
use DotKind as DK;
use VarianceOrigin as VO;
type Op = unparametrized::UnparameterizedOperator;
fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property {
&properties[inputs[0].i]
}
fn assert_all_same<Property: PartialEq + std::fmt::Debug>(
inputs: &[OperatorIndex],
properties: &[Property],
) {
let first = first(inputs, properties);
for input in inputs.iter().skip(1) {
assert_eq!(first, &properties[input.i]);
}
}
fn assert_inputs_uniform_precisions(
op: &unparametrized::UnparameterizedOperator,
out_precisions: &[Precision],
) {
if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
assert_all_same(inputs, out_precisions);
}
}
fn assert_dot_uniform_inputs_shape(
op: &unparametrized::UnparameterizedOperator,
out_shapes: &[Shape],
) {
if let Op::Dot { inputs, .. } = op {
assert_all_same(inputs, out_shapes);
}
}
fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
assert!(!inputs.is_empty());
}
}
fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
for op in &dag.operators {
assert_non_empty_inputs(op);
}
}
fn assert_valid_variances(dag: &OperationDag) {
for &out_variance in &dag.out_variances {
assert!(
SymbolicVariance::ZERO == out_variance // Special case of multiply by 0
|| 1.0 <= out_variance.input_vf
|| 1.0 <= out_variance.lut_vf
);
}
}
fn assert_properties_correctness(dag: &OperationDag) {
for op in &dag.operators {
assert_inputs_uniform_precisions(op, &dag.out_precisions);
assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
}
assert_valid_variances(dag);
}
fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) -> VarianceOrigin {
let first_origin = first(inputs, out_variances).origin();
for input in inputs.iter().skip(1) {
let item = &out_variances[input.i];
if first_origin != item.origin() {
return VO::Mixed;
}
}
first_origin
}
#[derive(Clone, Debug)]
pub struct OperationDag {
pub operators: Vec<Op>,
// Collect all operators ouput shape
pub out_shapes: Vec<Shape>,
// Collect all operators ouput precision
pub out_precisions: Vec<Precision>,
// Collect all operators ouput variances
pub out_variances: Vec<SymbolicVariance>,
pub nb_luts: u64,
// The full dag levelled complexity
pub levelled_complexity: LevelledComplexity,
// Global summaries of worst noise cases
pub noise_summary: NoiseSummary,
}
#[derive(Clone, Debug)]
pub struct NoiseSummary {
// All final variance factor not entering a lut (usually final levelledOp)
pub pareto_vfs_final: Vec<SymbolicVariance>,
// All variance factor entering a lut
pub pareto_vfs_in_lut: Vec<SymbolicVariance>,
}
impl OperationDag {
pub fn peek_variance(
&self,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
) -> f64 {
peek_variance(
self,
input_noise_out,
blind_rotate_noise_out,
noise_keyswitch,
noise_modulus_switching,
)
}
pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
let luts_cost = one_lut_cost * (self.nb_luts as f64);
let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
luts_cost + levelled_cost
}
}
fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
match op {
Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(),
Op::Lut { input, .. } => out_shapes[input.i].clone(),
Op::Dot {
inputs, weights, ..
} => {
if inputs.is_empty() {
return Shape::number();
}
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => Shape::number(),
DK::CompatibleTensor => weights.shape.clone(),
DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
}
}
fn out_shapes(dag: &unparametrized::OperationDag) -> Vec<Shape> {
let nb_ops = dag.operators.len();
let mut out_shapes = Vec::<Shape>::with_capacity(nb_ops);
for op in &dag.operators {
let shape = out_shape(op, &mut out_shapes);
out_shapes.push(shape);
}
out_shapes
}
fn out_precision(
op: &unparametrized::UnparameterizedOperator,
out_precisions: &mut [Precision],
) -> Precision {
match op {
Op::Input { out_precision, .. } => *out_precision,
Op::Lut { input, .. } => out_precisions[input.i],
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
}
}
fn out_precisions(dag: &unparametrized::OperationDag) -> Vec<Precision> {
let nb_ops = dag.operators.len();
let mut out_precisions = Vec::<Precision>::with_capacity(nb_ops);
for op in &dag.operators {
let precision = out_precision(op, &mut out_precisions);
out_precisions.push(precision);
}
out_precisions
}
fn out_variance(
op: &unparametrized::UnparameterizedOperator,
out_shapes: &[Shape],
out_variances: &mut [SymbolicVariance],
) -> SymbolicVariance {
// Maintain a linear combination of input_variance and lut_out_variance
// TODO: track each elements instead of container
match op {
Op::Input { .. } => SymbolicVariance::INPUT,
Op::Lut { .. } => SymbolicVariance::LUT,
Op::LevelledOp { inputs, manp, .. } => {
let variance_factor = SymbolicVariance::manp_to_variance_factor(*manp);
let origin = match variance_origin(inputs, out_variances) {
VO::Input => SymbolicVariance::INPUT,
VO::Lut | VO::Mixed /* Mixed: assume the worst */
=> SymbolicVariance::LUT
};
origin * variance_factor
}
Op::Dot {
inputs, weights, ..
} => {
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => {
let first_input = inputs[0];
let mut out_variance = SymbolicVariance::ZERO;
for (j, &weight) in weights.values.iter().enumerate() {
let k = if kind == DK::Simple {
inputs[j].i
} else {
first_input.i
};
out_variance += out_variances[k] * square(weight);
}
out_variance
}
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
}
}
fn out_variances(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
) -> Vec<SymbolicVariance> {
let nb_ops = dag.operators.len();
let mut out_variances = Vec::with_capacity(nb_ops);
for op in &dag.operators {
let vf = out_variance(op, out_shapes, &mut out_variances);
out_variances.push(vf);
}
out_variances
}
fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool> {
let nb_ops = dag.operators.len();
let mut extra_values_to_check = vec![true; nb_ops];
for op in &dag.operators {
match op {
Op::Input { .. } => (),
Op::Lut { input, .. } => {
extra_values_to_check[input.i] = false;
}
Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
for input in inputs {
extra_values_to_check[input.i] = false;
}
}
}
}
extra_values_to_check
}
fn extra_final_variances(
dag: &unparametrized::OperationDag,
out_variances: &[SymbolicVariance],
) -> Vec<SymbolicVariance> {
extra_final_values_to_check(dag)
.iter()
.enumerate()
.filter_map(|(i, &is_final)| {
if is_final {
Some(out_variances[i])
} else {
None
}
})
.collect()
}
fn in_luts_variance(
dag: &unparametrized::OperationDag,
out_variances: &[SymbolicVariance],
) -> Vec<SymbolicVariance> {
let only_luts = |op| {
if let &Op::Lut { input, .. } = op {
Some(out_variances[input.i])
} else {
None
}
};
dag.operators.iter().filter_map(only_luts).collect()
}
fn op_levelled_complexity(
op: &unparametrized::UnparameterizedOperator,
out_shapes: &[Shape],
) -> LevelledComplexity {
match op {
Op::Dot {
inputs, weights, ..
} => {
let input_shape = first(inputs, out_shapes);
let kind = dot_kind(inputs.len() as u64, input_shape, weights);
match kind {
DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(),
DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
DK::Unsupported { .. } => panic!("Unsupported"),
}
}
Op::LevelledOp { complexity, .. } => *complexity,
Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO,
}
}
fn levelled_complexity(
dag: &unparametrized::OperationDag,
out_shapes: &[Shape],
) -> LevelledComplexity {
let mut levelled_complexity = LevelledComplexity::ZERO;
for op in &dag.operators {
levelled_complexity += op_levelled_complexity(op, out_shapes);
}
levelled_complexity
}
fn max_update(current: &mut f64, candidate: f64) {
if candidate > *current {
*current = candidate;
}
}
fn noise_summary(
final_variances: Vec<SymbolicVariance>,
in_luts_variance: Vec<SymbolicVariance>,
) -> NoiseSummary {
let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(final_variances);
let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance);
NoiseSummary {
pareto_vfs_final,
pareto_vfs_in_lut,
}
}
pub fn analyze(dag: &unparametrized::OperationDag) -> OperationDag {
assert_dag_correctness(dag);
let out_shapes = out_shapes(dag);
let out_precisions = out_precisions(dag);
let out_variances = out_variances(dag, &out_shapes);
let in_luts_variance = in_luts_variance(dag, &out_variances);
let nb_luts = in_luts_variance.len() as u64;
let extra_final_variances = extra_final_variances(dag, &out_variances);
let levelled_complexity = levelled_complexity(dag, &out_shapes);
let noise_summary = noise_summary(extra_final_variances, in_luts_variance);
let result = OperationDag {
operators: dag.operators.clone(),
out_shapes,
out_precisions,
out_variances,
nb_luts,
levelled_complexity,
noise_summary,
};
assert_properties_correctness(&result);
result
}
// Compute the maximum attained variance for the full dag
// TODO take a noise summary => peek_error or global error
fn peek_variance(
dag: &OperationDag,
input_noise_out: f64,
blind_rotate_noise_out: f64,
noise_keyswitch: f64,
noise_modulus_switching: f64,
) -> f64 {
assert!(input_noise_out < blind_rotate_noise_out);
let mut variance_peek_final = 0.0; // updated by the loop
for vf in &dag.noise_summary.pareto_vfs_final {
max_update(
&mut variance_peek_final,
vf.eval(input_noise_out, blind_rotate_noise_out),
);
}
let mut variance_peek_in_lut = 0.0; // updated by the loop
for vf in &dag.noise_summary.pareto_vfs_in_lut {
max_update(
&mut variance_peek_in_lut,
vf.eval(input_noise_out, blind_rotate_noise_out),
);
}
let peek_in_lut = variance_peek_in_lut + noise_keyswitch + noise_modulus_switching;
peek_in_lut.max(variance_peek_final)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
use crate::dag::unparametrized;
use crate::utils::square;
fn assert_f64_eq(v: f64, expected: f64) {
approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
}
#[test]
fn test_1_input() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
assert_eq!(analysis.out_shapes[input1.i], Shape::number());
assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO);
assert_eq!(analysis.out_precisions[input1.i], 1);
assert_f64_eq(complexity_cost, 0.0);
assert!(analysis.nb_luts == 0);
let summary = analysis.noise_summary;
assert!(summary.pareto_vfs_final.len() == 1);
assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 1.0);
assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
assert!(summary.pareto_vfs_in_lut.is_empty());
}
#[test]
fn test_1_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(8, Shape::number());
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
assert!(analysis.out_shapes[lut1.i] == Shape::number());
assert!(analysis.levelled_complexity == LevelledComplexity::ZERO);
assert_eq!(analysis.out_precisions[lut1.i], 8);
assert_f64_eq(one_lut_cost, complexity_cost);
let summary = analysis.noise_summary;
assert!(summary.pareto_vfs_final.len() == 1);
assert!(summary.pareto_vfs_in_lut.len() == 1);
assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 0.0);
assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
assert_f64_eq(summary.pareto_vfs_in_lut[0].input_vf, 1.0);
assert_f64_eq(summary.pareto_vfs_in_lut[0].lut_vf, 0.0);
}
#[test]
fn test_1_dot() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let weights = Weights::vector([1, 2]);
let norm2: f64 = 1.0 * 1.0 + 2.0 * 2.0;
let dot = graph.add_dot([input1, input1], weights);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let expected_var = SymbolicVariance {
input_vf: norm2,
lut_vf: 0.0,
};
assert!(analysis.out_variances[dot.i] == expected_var);
assert!(analysis.out_shapes[dot.i] == Shape::number());
assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2);
assert_eq!(analysis.out_precisions[dot.i], 1);
let expected_dot_cost = (2 * lwe_dim) as f64;
assert_f64_eq(expected_dot_cost, complexity_cost);
let summary = analysis.noise_summary;
assert!(summary.pareto_vfs_in_lut.is_empty());
assert!(summary.pareto_vfs_final.len() == 1);
assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
}
#[test]
fn test_1_dot_levelled() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(3, Shape::number());
let cpx_dot = LevelledComplexity::ADDITION;
let weights = Weights::vector([1, 2]);
let manp = 1.0 * 1.0 + 2.0 * 2_f64;
let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot");
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
assert!(analysis.out_variances[dot.i].origin() == VO::Input);
assert_eq!(analysis.out_precisions[dot.i], 3);
let expected_square_norm2 = weights.square_norm2() as f64;
let actual_square_norm2 = analysis.out_variances[dot.i].input_vf;
// Due to call on log2() to compute manp the result is not exact
assert_f64_eq(actual_square_norm2, expected_square_norm2);
assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION);
assert_f64_eq(lwe_dim as f64, complexity_cost);
let summary = analysis.noise_summary;
assert!(summary.pareto_vfs_in_lut.is_empty());
assert!(summary.pareto_vfs_final.len() == 1);
assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Input);
assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
}
#[test]
fn test_dot_tensorized_lut_dot_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::vector(2));
let weights = &Weights::vector([1, 2]);
let dot1 = graph.add_dot([input1], weights);
let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
let dot2 = graph.add_dot([lut1, lut1], weights);
let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let expected_var_dot1 = SymbolicVariance {
input_vf: weights.square_norm2() as f64,
lut_vf: 0.0,
};
let expected_var_lut1 = SymbolicVariance {
input_vf: 0.0,
lut_vf: 1.0,
};
let expected_var_dot2 = SymbolicVariance {
input_vf: 0.0,
lut_vf: weights.square_norm2() as f64,
};
let expected_var_lut2 = SymbolicVariance {
input_vf: 0.0,
lut_vf: 1.0,
};
assert!(analysis.out_variances[dot1.i] == expected_var_dot1);
assert!(analysis.out_variances[lut1.i] == expected_var_lut1);
assert!(analysis.out_variances[dot2.i] == expected_var_dot2);
assert!(analysis.out_variances[lut2.i] == expected_var_lut2);
assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 4);
let expected_cost = (lwe_dim * 4) as f64 + 2.0 * one_lut_cost;
assert_f64_eq(expected_cost, complexity_cost);
let summary = analysis.noise_summary;
assert_eq!(summary.pareto_vfs_final.len(), 1);
assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Lut);
assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Lut);
assert_f64_eq(
summary.pareto_vfs_in_lut[0].lut_vf,
weights.square_norm2() as f64,
);
}
#[test]
fn test_lut_dot_mixed_lut() {
let mut graph = unparametrized::OperationDag::new();
let input1 = graph.add_input(1, Shape::number());
let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
let weights = &Weights::vector([2, 3]);
let dot1 = graph.add_dot([input1, lut1], weights);
let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
let analysis = analyze(&graph);
let one_lut_cost = 100.0;
let lwe_dim = 1024;
let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost;
assert_f64_eq(expected_cost, complexity_cost);
let expected_mixed = SymbolicVariance {
input_vf: square(weights.values[0] as f64),
lut_vf: square(weights.values[1] as f64),
};
let summary = analysis.noise_summary;
assert_eq!(summary.pareto_vfs_final.len(), 1);
assert_eq!(summary.pareto_vfs_final[0], SymbolicVariance::LUT);
assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Mixed);
assert_eq!(summary.pareto_vfs_in_lut[0], expected_mixed);
}
}

View File

@@ -0,0 +1,3 @@
pub(crate) mod analyze;
pub mod optimize;
pub(crate) mod symbolic_variance;

View File

@@ -0,0 +1,590 @@
use concrete_commons::dispersion::{DispersionParameter, Variance};
use concrete_commons::numeric::UnsignedInteger;
use crate::dag::operator::LevelledComplexity;
use crate::dag::unparametrized;
use crate::noise_estimator::error;
use crate::noise_estimator::error::error_probability_of_sigma_scale;
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
use crate::optimization::atomic_pattern::{
pareto_blind_rotate, pareto_keyswitch, OptimizationDecompositionsConsts, OptimizationState,
Solution,
};
use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositionParameters};
use crate::pareto;
use crate::security::glwe::minimal_variance;
use crate::utils::square;
use super::analyze;
const CUTS: bool = true;
const PARETO_CUTS: bool = true;
const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true;
#[allow(clippy::too_many_lines)]
fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
state: &mut OptimizationState,
consts: &OptimizationDecompositionsConsts,
dag: &analyze::OperationDag,
internal_dim: u64,
glwe_params: GlweParameters,
noise_modulus_switching: f64,
) {
let safe_variance = consts.safe_variance;
let glwe_poly_size = glwe_params.polynomial_size();
let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
let mut best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
let mut best_lut_complexity = state
.best_solution
.map_or(f64::INFINITY, |s| s.lut_complexity);
let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
let mut cut_complexity =
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64);
let mut cut_noise = safe_variance - noise_modulus_switching;
if dag.nb_luts == 0 {
cut_noise = f64::INFINITY;
cut_complexity = f64::INFINITY;
}
let br_pareto =
pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
if br_pareto.is_empty() {
return;
}
if PARETO_CUTS {
cut_noise -= br_pareto[br_pareto.len() - 1].noise;
cut_complexity -= br_pareto[0].complexity;
}
let ks_pareto = pareto_keyswitch::<W>(
consts,
input_lwe_dimension,
internal_dim,
cut_complexity,
cut_noise,
);
if ks_pareto.is_empty() {
return;
}
let i_max_ks = ks_pareto.len() - 1;
let mut i_current_max_ks = i_max_ks;
let input_noise_out = minimal_variance(
glwe_params,
consts.ciphertext_modulus_log,
consts.security_level,
)
.get_variance();
let mut best_br_i = 0;
let mut best_ks_i = 0;
let mut update_best_solution = false;
for br_quantity in br_pareto {
// increasing complexity, decreasing variance
let peek_variance = dag.peek_variance(
input_noise_out,
br_quantity.noise,
0.0,
noise_modulus_switching,
);
if peek_variance > safe_variance && CUTS {
continue;
}
let one_pbs_cost = br_quantity.complexity;
let complexity = dag.complexity_cost(input_lwe_dimension, one_pbs_cost);
if complexity > best_complexity {
// As best can evolves it is complementary to blind_rotate_quantities cuts.
if PARETO_CUTS {
break;
} else if CUTS {
continue;
}
}
for i_ks_pareto in (0..=i_current_max_ks).rev() {
// increasing variance, decreasing complexity
let ks_quantity = ks_pareto[i_ks_pareto];
let peek_variance = dag.peek_variance(
input_noise_out,
br_quantity.noise,
ks_quantity.noise,
noise_modulus_switching,
);
// let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching;
if peek_variance > safe_variance {
if CROSS_PARETO_CUTS {
// the pareto of 2 added pareto is scanned linearly
// but with all cuts, pre-computing => no gain
i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks);
break;
// it's compatible with next i_br but with the worst complexity
} else if PARETO_CUTS {
// increasing variance => we can skip all remaining
break;
}
continue;
}
let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
let better_complexity = complexity < best_complexity;
#[allow(clippy::float_cmp)]
let same_complexity_with_less_errors =
complexity == best_complexity && peek_variance < best_variance;
if better_complexity || same_complexity_with_less_errors {
best_lut_complexity = one_lut_cost;
best_complexity = complexity;
best_variance = peek_variance;
best_br_i = br_quantity.index;
best_ks_i = ks_quantity.index;
update_best_solution = true;
}
}
} // br ks
if update_best_solution {
let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
let sigma_scale = sigma / Variance(best_variance).get_standard_dev();
let p_error = error_probability_of_sigma_scale(sigma_scale);
let BrDecompositionParameters {
level: br_l,
log2_base: br_b,
} = consts.blind_rotate_decompositions[best_br_i];
let KsDecompositionParameters {
level: ks_l,
log2_base: ks_b,
} = consts.keyswitch_decompositions[best_ks_i];
state.best_solution = Some(Solution {
input_lwe_dimension,
internal_ks_output_lwe_dimension: internal_dim,
ks_decomposition_level_count: ks_l,
ks_decomposition_base_log: ks_b,
glwe_polynomial_size: glwe_params.polynomial_size(),
glwe_dimension: glwe_params.glwe_dimension,
br_decomposition_level_count: br_l,
br_decomposition_base_log: br_b,
noise_max: best_variance,
complexity: best_complexity,
lut_complexity: best_lut_complexity,
p_error,
});
}
}
const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
#[allow(clippy::too_many_lines)]
pub fn optimize<W: UnsignedInteger>(
dag: &unparametrized::OperationDag,
security_level: u64,
maximum_acceptable_error_probability: f64,
glwe_log_polynomial_sizes: &[u64],
glwe_dimensions: &[u64],
internal_lwe_dimensions: &[u64],
) -> OptimizationState {
let ciphertext_modulus_log = W::BITS as u64;
let dag = analyze::analyze(dag);
let &max_precision = dag.out_precisions.iter().max().unwrap();
let safe_variance = error::variance_max(
max_precision as u64,
ciphertext_modulus_log,
maximum_acceptable_error_probability,
);
let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);
let consts = OptimizationDecompositionsConsts {
kappa,
sum_size: 0, // superseeded by dag.complexity_cost
security_level,
noise_factor: f64::NAN, // superseeded by dag.lut_variance_max
ciphertext_modulus_log,
keyswitch_decompositions: pareto::KS_BL
.map(|(log2_base, level)| KsDecompositionParameters { level, log2_base })
.to_vec(),
blind_rotate_decompositions: pareto::BR_BL
.map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
.to_vec(),
safe_variance,
};
let mut state = OptimizationState {
best_solution: None,
count_domain: glwe_dimensions.len()
* glwe_log_polynomial_sizes.len()
* internal_lwe_dimensions.len()
* consts.keyswitch_decompositions.len()
* consts.blind_rotate_decompositions.len(),
};
let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::<W>(
internal_lwe_dimensions,
glwe_poly_size,
)
.get_variance()
};
for &glwe_dim in glwe_dimensions {
for &glwe_log_poly_size in glwe_log_polynomial_sizes {
let glwe_poly_size = 1 << glwe_log_poly_size;
let glwe_params = GlweParameters {
log2_polynomial_size: glwe_log_poly_size,
glwe_dimension: glwe_dim,
};
for &internal_dim in internal_lwe_dimensions {
let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim);
if CUTS && noise_modulus_switching > consts.safe_variance {
// assume this noise is increasing with internal_dim
break;
}
update_best_solution_with_best_decompositions::<W>(
&mut state,
&consts,
&dag,
internal_dim,
glwe_params,
noise_modulus_switching,
);
if dag.nb_luts == 0 && state.best_solution.is_some() {
return state;
}
}
}
}
if let Some(sol) = state.best_solution {
assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA);
}
state
}
pub fn optimize_v0<W: UnsignedInteger>(
sum_size: u64,
precision: u64,
security_level: u64,
noise_factor: f64,
maximum_acceptable_error_probability: f64,
glwe_log_polynomial_sizes: &[u64],
glwe_dimensions: &[u64],
internal_lwe_dimensions: &[u64],
) -> OptimizationState {
use crate::dag::operator::{FunctionTable, Shape};
let same_scale_manp = 0.0;
let manp = square(noise_factor);
let out_shape = &Shape::number();
let complexity = LevelledComplexity::ADDITION * sum_size;
let comment = "dot";
let mut dag = unparametrized::OperationDag::new();
let input1 = dag.add_input(precision as u8, out_shape);
let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
let mut state = optimize::<u64>(
&dag,
security_level,
maximum_acceptable_error_probability,
glwe_log_polynomial_sizes,
glwe_dimensions,
internal_lwe_dimensions,
);
if let Some(sol) = &mut state.best_solution {
sol.complexity /= 2.0;
}
state
}
#[cfg(test)]
mod tests {
use std::time::Instant;
use crate::dag::operator::{FunctionTable, Shape, Weights};
use crate::global_parameters::DEFAUT_DOMAINS;
use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
use crate::utils::square;
use super::*;
use crate::optimization::atomic_pattern;
fn small_relative_diff(v1: f64, v2: f64) -> bool {
f64::abs(v1 - v2) / f64::max(v1, v2) <= f64::EPSILON
}
impl Solution {
fn same(&self, other: Self) -> bool {
let mut other = other;
if small_relative_diff(self.noise_max, other.noise_max)
&& small_relative_diff(self.p_error, other.p_error)
{
other.noise_max = self.noise_max;
other.p_error = self.p_error;
}
self == &other
}
}
const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
struct Times {
worst_time: u128,
dag_time: u128,
}
fn assert_f64_eq(v: f64, expected: f64) {
approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
}
#[test]
fn test_v0_parameter_ref() {
let mut times = Times {
worst_time: 0,
dag_time: 0,
};
for log_weight in 0..=16 {
let weight = 1 << log_weight;
for precision in 1..=9 {
v0_parameter_ref(precision, weight, &mut times);
}
}
assert!(times.worst_time * 2 > times.dag_time);
}
fn v0_parameter_ref(precision: u64, weight: u64, times: &mut Times) {
let security_level = 128;
let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
.glwe_pbs_constrained
.log2_polynomial_size
.as_vec();
let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
let sum_size = 1;
let maximum_acceptable_error_probability = _4_SIGMA;
let chrono = Instant::now();
let state = optimize_v0::<u64>(
sum_size,
precision,
security_level,
weight as f64,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
);
times.dag_time += chrono.elapsed().as_nanos();
let chrono = Instant::now();
let state_ref = atomic_pattern::optimize_one::<u64>(
sum_size,
precision,
security_level,
weight as f64,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
None,
);
times.worst_time += chrono.elapsed().as_nanos();
assert_eq!(
state.best_solution.is_some(),
state_ref.best_solution.is_some()
);
if state.best_solution.is_none() {
return;
}
let sol = state.best_solution.unwrap();
let sol_ref = state_ref.best_solution.unwrap();
assert!(sol.same(sol_ref));
}
#[test]
fn test_v0_parameter_ref_with_dot() {
for log_weight in 0..=16 {
let weight = 1 << log_weight;
for precision in 1..=9 {
v0_parameter_ref_with_dot(precision, weight);
}
}
}
fn v0_parameter_ref_with_dot(precision: u64, weight: u64) {
let mut dag = unparametrized::OperationDag::new();
{
let input1 = dag.add_input(precision as u8, Shape::number());
let dot1 = dag.add_dot([input1], [1]);
let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
let dot2 = dag.add_dot([lut1], [weight]);
let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
}
{
let dag2 = analyze::analyze(&dag);
let summary = dag2.noise_summary;
assert_eq!(summary.pareto_vfs_final.len(), 1);
assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
assert_eq!(summary.pareto_vfs_final[0].origin(), VarianceOrigin::Lut);
assert_f64_eq(1.0, summary.pareto_vfs_final[0].lut_vf);
assert!(summary.pareto_vfs_in_lut.len() == 1);
assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VarianceOrigin::Lut);
assert_f64_eq(square(weight) as f64, summary.pareto_vfs_in_lut[0].lut_vf);
}
let security_level = 128;
let maximum_acceptable_error_probability = _4_SIGMA;
let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
.glwe_pbs_constrained
.log2_polynomial_size
.as_vec();
let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
let state = optimize::<u64>(
&dag,
security_level,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
);
let state_ref = atomic_pattern::optimize_one::<u64>(
1,
precision,
security_level,
weight as f64,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
None,
);
assert_eq!(
state.best_solution.is_some(),
state_ref.best_solution.is_some()
);
if state.best_solution.is_none() {
return;
}
let sol = state.best_solution.unwrap();
let mut sol_ref = state_ref.best_solution.unwrap();
sol_ref.complexity *= 2.0 /* number of luts */;
assert!(sol.same(sol_ref));
}
fn no_lut_vs_lut(precision: u64) {
let mut dag_lut = unparametrized::OperationDag::new();
let input1 = dag_lut.add_input(precision as u8, Shape::number());
let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN);
let mut dag_no_lut = unparametrized::OperationDag::new();
let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());
let security_level = 128;
let maximum_acceptable_error_probability = _4_SIGMA;
let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
.glwe_pbs_constrained
.log2_polynomial_size
.as_vec();
let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
let opt = |dag: &unparametrized::OperationDag| {
optimize::<u64>(
dag,
security_level,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
)
};
let state_no_lut = opt(&dag_no_lut);
let state_lut = opt(&dag_lut);
assert_eq!(
state_no_lut.best_solution.is_some(),
state_lut.best_solution.is_some()
);
if state_lut.best_solution.is_none() {
return;
}
let sol_no_lut = state_no_lut.best_solution.unwrap();
let sol_lut = state_lut.best_solution.unwrap();
assert!(sol_no_lut.complexity < sol_lut.complexity);
}
#[test]
fn test_lut_vs_no_lut() {
for precision in 1..=8 {
no_lut_vs_lut(precision);
}
}
fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision: u64, weight: u64) {
let weight = &Weights::number(weight);
let mut dag_1 = unparametrized::OperationDag::new();
{
let input1 = dag_1.add_input(precision as u8, Shape::number());
let scaled_input1 = dag_1.add_dot([input1], weight);
let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN);
let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN);
}
let mut dag_2 = unparametrized::OperationDag::new();
{
let input1 = dag_2.add_input(precision as u8, Shape::number());
let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN);
let scaled_lut1 = dag_2.add_dot([lut1], weight);
let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN);
}
let security_level = 128;
let maximum_acceptable_error_probability = _4_SIGMA;
let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
.glwe_pbs_constrained
.log2_polynomial_size
.as_vec();
let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
let opt = |dag: &unparametrized::OperationDag| {
optimize::<u64>(
dag,
security_level,
maximum_acceptable_error_probability,
&glwe_log_polynomial_sizes,
&glwe_dimensions,
&internal_lwe_dimensions,
)
};
let state_1 = opt(&dag_1);
let state_2 = opt(&dag_2);
if state_1.best_solution.is_none() {
assert!(state_2.best_solution.is_none());
return;
}
let sol_1 = state_1.best_solution.unwrap();
let sol_2 = state_2.best_solution.unwrap();
assert!(sol_1.complexity < sol_2.complexity || sol_1.p_error < sol_2.p_error);
}
#[test]
fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
for log_weight in 1..=16 {
let weight = 1 << log_weight;
for precision in 5..=9 {
lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
}
}
}
}

View File

@@ -0,0 +1,109 @@
use derive_more::{Add, AddAssign, Sum};
/**
* A variance that is represented as a linear combination of base variances.
* Only the linear coefficient are known.
* The base variances are unknown.
* Each linear coefficients is a variance factor.
*
* Only 2 base variances are possible in the solo key setup:
* - from input,
* - or from lut output.
*
* We only kown that the first one is lower or equal to the second one.
* Each linear coefficient is a variance factor.
* There are homogenious to squared weight (or summed square weights or squared norm2).
*/
#[derive(Clone, Copy, Add, AddAssign, Sum, Debug, PartialEq, PartialOrd)]
pub struct SymbolicVariance {
pub lut_vf: f64,
pub input_vf: f64,
// variance = vf.lut_vf * lut_out_noise
// + vf.input_vf * input_out_noise
// E.g. variance(dot([lut, input], [3, 4])) = VariancesFactors {lut_vf:9, input_vf: 16}
// NOTE: lut_base_noise is the first field since it has higher impact,
// see pareto sorting and dominate_or_equal
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VarianceOrigin {
Input,
Lut,
Mixed,
}
impl std::ops::Mul<f64> for SymbolicVariance {
type Output = Self;
fn mul(self, sq_weight: f64) -> Self {
Self {
input_vf: self.input_vf * sq_weight,
lut_vf: self.lut_vf * sq_weight,
}
}
}
impl std::ops::Mul<u64> for SymbolicVariance {
type Output = Self;
fn mul(self, sq_weight: u64) -> Self {
self * sq_weight as f64
}
}
impl SymbolicVariance {
pub const ZERO: Self = Self {
input_vf: 0.0,
lut_vf: 0.0,
};
pub const INPUT: Self = Self {
input_vf: 1.0,
lut_vf: 0.0,
};
pub const LUT: Self = Self {
input_vf: 0.0,
lut_vf: 1.0,
};
pub fn origin(&self) -> VarianceOrigin {
if self.lut_vf == 0.0 {
VarianceOrigin::Input
} else if self.input_vf == 0.0 {
VarianceOrigin::Lut
} else {
VarianceOrigin::Mixed
}
}
pub fn manp_to_variance_factor(manp: f64) -> f64 {
manp
}
pub fn dominate_or_equal(&self, other: &Self) -> bool {
let extra_other_minimal_base_noise = 0.0_f64.max(other.input_vf - self.input_vf);
other.lut_vf + extra_other_minimal_base_noise <= self.lut_vf
}
pub fn eval(&self, minimal_base_noise: f64, lut_base_noise: f64) -> f64 {
minimal_base_noise * self.input_vf + lut_base_noise * self.lut_vf
}
pub fn reduce_to_pareto_front(mut vfs: Vec<Self>) -> Vec<Self> {
if vfs.is_empty() {
return vec![];
}
vfs.sort_by(
// bigger first
|a, b| b.partial_cmp(a).unwrap(),
);
// Due to the special domination nature, this can be done in one pass
let mut dominator = vfs[0];
let mut pareto = vec![dominator];
for &vf in vfs.iter().skip(1) {
if dominator.dominate_or_equal(&vf) {
continue;
}
dominator = vf;
pareto.push(vf);
}
pareto
}
}

View File

@@ -1 +1,2 @@
pub mod atomic_pattern;
pub mod dag;

View File

@@ -0,0 +1,15 @@
use std::ops::Mul;
pub fn square<V>(v: V) -> V
where
V: Mul<V> + Mul<Output = V> + Copy,
{
v * v
}
pub fn square_ref<V>(v: &V) -> V
where
V: Mul<V> + Mul<Output = V> + Copy,
{
square(*v)
}