Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-08 19:44:57 -05:00)

Commit: feat: dag + solo key optimization
@@ -1,7 +1,7 @@
-use concrete_optimizer::graph::operator::{
-    self, FunctionTable, LevelledComplexity, OperatorIndex, Shape,
+use concrete_optimizer::dag::operator::{
+    self, FunctionTable, LevelledComplexity, OperatorIndex, Precision, Shape,
 };
-use concrete_optimizer::graph::unparametrized;
+use concrete_optimizer::dag::unparametrized;

 fn no_solution() -> ffi::Solution {
     ffi::Solution {
@@ -63,7 +63,7 @@ fn empty() -> Box<OperationDag> {
 }

 impl OperationDag {
-    fn add_input(&mut self, out_precision: u8, out_shape: &[u64]) -> ffi::OperatorIndex {
+    fn add_input(&mut self, out_precision: Precision, out_shape: &[u64]) -> ffi::OperatorIndex {
         let out_shape = Shape {
             dimensions_size: out_shape.to_owned(),
         };
@@ -87,7 +87,7 @@ impl OperationDag {
     ) -> ffi::OperatorIndex {
         let inputs: Vec<OperatorIndex> = inputs.iter().copied().map(Into::into).collect();

-        self.0.add_dot(&inputs, &weights.0).into()
+        self.0.add_dot(inputs, weights.0).into()
     }

     fn add_levelled_op(
@@ -111,7 +111,7 @@ impl OperationDag {
         };

         self.0
-            .add_levelled_op(&inputs, complexity, manp, out_shape, comment)
+            .add_levelled_op(inputs, complexity, manp, out_shape, comment)
             .into()
     }
 }

concrete-optimizer/src/dag/operator/dot_kind.rs (new file, 84 lines)
@@ -0,0 +1,84 @@
use super::{ClearTensor, Shape};

#[derive(PartialEq, Eq, Debug)]
pub enum DotKind {
    // inputs = [x,y,z], weights = [a,b,c], = x*a + y*b + z*c
    Simple,
    // inputs = [[x, y, z]], weights = [a,b,c], = same
    Tensor,
    // inputs = [[x], [y], [z]], weights = [[a],[b],[c]], = same
    CompatibleTensor,
    // inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
    Broadcast,
    Unsupported,
}

pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind {
    let inputs_shape = Shape::duplicated(nb_inputs, input_shape);
    if input_shape.is_number() && inputs_shape == weights.shape {
        DotKind::Simple
    } else if nb_inputs == 1 && *input_shape == weights.shape {
        DotKind::Tensor
    } else if inputs_shape == weights.shape {
        DotKind::CompatibleTensor
    } else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
        DotKind::Broadcast
    } else {
        DotKind::Unsupported
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::operator::{Shape, Weights};

    #[test]
    fn test_simple() {
        assert_eq!(
            dot_kind(2, &Shape::number(), &Weights::vector([1, 2])),
            DotKind::Simple
        );
    }

    #[test]
    fn test_tensor() {
        assert_eq!(
            dot_kind(1, &Shape::vector(2), &Weights::vector([1, 2])),
            DotKind::Tensor
        );
    }

    #[test]
    fn test_broadcast() {
        let s2x2 = Shape {
            dimensions_size: vec![2, 2],
        };
        assert_eq!(
            dot_kind(1, &s2x2, &Weights::vector([1, 2])),
            DotKind::Broadcast
        );
    }

    #[test]
    fn test_compatible() {
        let weights = ClearTensor {
            shape: Shape {
                dimensions_size: vec![2, 1],
            },
            values: vec![1, 2],
        };
        assert_eq!(
            dot_kind(2, &Shape::vector(1), &weights),
            DotKind::CompatibleTensor
        );
    }

    #[test]
    fn test_unsupported() {
        assert_eq!(
            dot_kind(3, &Shape::number(), &Weights::vector([1, 2])),
            DotKind::Unsupported
        );
    }
}
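As a quick orientation for the classifier above, a minimal sketch of calling dot_kind through the crate's public re-exports, mirroring the unit tests (assumes the concrete-optimizer crate as of this commit):

use concrete_optimizer::dag::operator::{dot_kind, DotKind, Shape, Weights};

fn main() {
    // Two scalar inputs and two weights: a plain dot product.
    let simple = dot_kind(2, &Shape::number(), &Weights::vector([1, 2]));
    assert_eq!(simple, DotKind::Simple);

    // A single length-2 vector input against the same weights.
    let tensor = dot_kind(1, &Shape::vector(2), &Weights::vector([1, 2]));
    assert_eq!(tensor, DotKind::Tensor);

    // Mismatched shapes fall through to Unsupported.
    let bad = dot_kind(3, &Shape::number(), &Weights::vector([1, 2]));
    assert_eq!(bad, DotKind::Unsupported);
}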
@@ -1,6 +1,8 @@
 #![allow(clippy::module_inception)]
+pub mod dot_kind;
 pub mod operator;
 pub mod tensor;

+pub use self::dot_kind::*;
 pub use self::operator::*;
 pub use self::tensor::*;
@@ -1,4 +1,4 @@
-use crate::graph::operator::tensor::{ClearTensor, Shape};
+use crate::dag::operator::tensor::{ClearTensor, Shape};
 use derive_more::{Add, AddAssign};

 pub type Weights = ClearTensor;
@@ -50,11 +50,13 @@ impl std::ops::Mul<u64> for LevelledComplexity {
         }
     }
 }
+pub type Precision = u8;
+pub const MIN_PRECISION: Precision = 1;

 #[derive(Clone, PartialEq, Debug)]
 pub enum Operator<InputExtraData, LutExtraData, DotExtraData, LevelledOpExtraData> {
     Input {
-        out_precision: u8,
+        out_precision: Precision,
         out_shape: Shape,
         extra_data: InputExtraData,
     },
@@ -1,10 +1,17 @@
 use delegate::delegate;

+use crate::utils::square_ref;
+
 #[derive(Clone, PartialEq, Eq, Debug)]
 pub struct Shape {
     pub dimensions_size: Vec<u64>,
 }

 impl Shape {
+    pub fn first_dim_size(&self) -> u64 {
+        self.dimensions_size[0]
+    }
+
     pub fn rank(&self) -> usize {
         self.dimensions_size.len()
     }
@@ -43,6 +50,12 @@ impl Shape {
         dimensions_size.extend_from_slice(&other.dimensions_size);
         Self { dimensions_size }
     }
+
+    pub fn erase_first_dim(&self) -> Self {
+        Self {
+            dimensions_size: self.dimensions_size[1..].to_vec(),
+        }
+    }
 }

 #[derive(Clone, PartialEq, Eq, Debug)]
@@ -51,11 +64,6 @@ pub struct ClearTensor {
     pub values: Vec<u64>,
 }

-#[allow(clippy::trivially_copy_pass_by_ref)]
-fn square(v: &u64) -> u64 {
-    v * v
-}
-
 impl ClearTensor {
     pub fn number(value: u64) -> Self {
         Self {
@@ -64,10 +72,11 @@ impl ClearTensor {
         }
     }

-    pub fn vector(values: &[u64]) -> Self {
+    pub fn vector(values: impl Into<Vec<u64>>) -> Self {
+        let values = values.into();
         Self {
             shape: Shape::vector(values.len() as u64),
-            values: values.to_vec(),
+            values,
         }
     }
@@ -81,6 +90,27 @@ impl ClearTensor {
     }

     pub fn square_norm2(&self) -> u64 {
-        self.values.iter().map(square).sum()
+        self.values.iter().map(square_ref).sum()
     }
 }
+
+// helps using shared shapes
+impl From<&Self> for Shape {
+    fn from(item: &Self) -> Self {
+        item.clone()
+    }
+}
+
+// helps using shared weights
+impl From<&Self> for ClearTensor {
+    fn from(item: &Self) -> Self {
+        item.clone()
+    }
+}
+
+// helps using array as weights
+impl<const N: usize> From<[u64; N]> for ClearTensor {
+    fn from(item: [u64; N]) -> Self {
+        Self::vector(item)
+    }
+}
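The three From impls above are what make the new call sites ergonomic; a minimal sketch of the conversions they enable (assuming the crate as of this commit):

use concrete_optimizer::dag::operator::{ClearTensor, Shape, Weights};

fn main() {
    // [u64; N] -> ClearTensor through the const-generic From impl.
    let from_array: Weights = [1, 2, 3].into();
    assert_eq!(from_array.shape, Shape::vector(3));

    // &ClearTensor -> ClearTensor (a clone), so shared weights can be reused.
    let shared = Weights::vector([4, 5]);
    let copy: Weights = (&shared).into();
    assert_eq!(copy, shared);

    // vector() now accepts impl Into<Vec<u64>>: arrays, vecs, ...
    let v = ClearTensor::vector(vec![6, 7, 8]);
    assert_eq!(v.square_norm2(), 6 * 6 + 7 * 7 + 8 * 8);
}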
@@ -1,5 +1,5 @@
+use crate::dag::parameter_indexed::OperatorParameterIndexed;
 use crate::global_parameters::{ParameterRanges, ParameterToOperation};
-use crate::graph::parameter_indexed::OperatorParameterIndexed;

 #[allow(dead_code)]
 pub struct OperationDag {
@@ -1,5 +1,5 @@
-use crate::graph::operator::{
-    FunctionTable, LevelledComplexity, Operator, OperatorIndex, Shape, Weights,
+use crate::dag::operator::{
+    FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
 };

 pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
@@ -21,7 +21,12 @@ impl OperationDag {
         OperatorIndex { i }
     }

-    pub fn add_input(&mut self, out_precision: u8, out_shape: Shape) -> OperatorIndex {
+    pub fn add_input(
+        &mut self,
+        out_precision: Precision,
+        out_shape: impl Into<Shape>,
+    ) -> OperatorIndex {
+        let out_shape = out_shape.into();
         self.add_operator(Operator::Input {
             out_precision,
             out_shape,
@@ -37,23 +42,30 @@
         })
     }

-    pub fn add_dot(&mut self, inputs: &[OperatorIndex], weights: &Weights) -> OperatorIndex {
+    pub fn add_dot(
+        &mut self,
+        inputs: impl Into<Vec<OperatorIndex>>,
+        weights: impl Into<Weights>,
+    ) -> OperatorIndex {
+        let inputs = inputs.into();
+        let weights = weights.into();
         self.add_operator(Operator::Dot {
-            inputs: inputs.to_vec(),
-            weights: weights.clone(),
+            inputs,
+            weights,
             extra_data: (),
         })
     }

     pub fn add_levelled_op(
         &mut self,
-        inputs: &[OperatorIndex],
+        inputs: impl Into<Vec<OperatorIndex>>,
         complexity: LevelledComplexity,
         manp: f64,
-        out_shape: Shape,
+        out_shape: impl Into<Shape>,
         comment: &str,
     ) -> OperatorIndex {
-        let inputs = inputs.to_vec();
+        let inputs = inputs.into();
+        let out_shape = out_shape.into();
         let comment = comment.to_string();
         let op = Operator::LevelledOp {
             inputs,
@@ -75,7 +87,7 @@ impl OperationDag {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::graph::operator::Shape;
+    use crate::dag::operator::Shape;

     #[test]
     fn graph_creation() {
@@ -86,14 +98,14 @@ mod tests {
         let input2 = graph.add_input(2, Shape::number());

         let cpx_add = LevelledComplexity::ADDITION;
-        let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
+        let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");

         let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);

         let concat =
-            graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
+            graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");

-        let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
+        let dot = graph.add_dot([concat], [1, 2]);

         let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
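With the impl Into<...> parameters, both terse and explicit call styles compile; a minimal sketch (assuming the crate as of this commit):

use concrete_optimizer::dag::operator::{Shape, Weights};
use concrete_optimizer::dag::unparametrized::OperationDag;

fn main() {
    let mut graph = OperationDag::new();
    let input1 = graph.add_input(1, Shape::number());
    // Arrays convert into Vec<OperatorIndex> and Weights directly...
    let dot1 = graph.add_dot([input1, input1], [1, 2]);
    // ...and the explicit forms still go through the same Into bounds.
    let _dot2 = graph.add_dot(vec![dot1], Weights::vector([3, 4]));
}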
@@ -1,5 +1,5 @@
+use crate::dag::parameter_indexed::OperatorParameterIndexed;
 use crate::global_parameters::{ParameterToOperation, ParameterValues};
-use crate::graph::parameter_indexed::OperatorParameterIndexed;

 #[allow(dead_code)]
 pub struct OperationDag {
@@ -1,11 +1,11 @@
 use std::collections::HashSet;

-use crate::graph::operator::{Operator, OperatorIndex};
-use crate::graph::parameter_indexed::{
+use crate::dag::operator::{Operator, OperatorIndex};
+use crate::dag::parameter_indexed::{
     InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
 };
-use crate::graph::unparametrized::UnparameterizedOperator;
-use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
+use crate::dag::unparametrized::UnparameterizedOperator;
+use crate::dag::{parameter_indexed, range_parametrized, unparametrized};
 use crate::parameters::{
     BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters,
     KsDecompositionParameterRanges, KsDecompositionParameters,
@@ -244,7 +244,7 @@ pub fn domains_to_ranges(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::graph::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
+    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};

     #[test]
     fn test_maximal_unify() {
@@ -255,14 +255,13 @@ mod tests {
         let input2 = graph.add_input(2, Shape::number());

         let cpx_add = LevelledComplexity::ADDITION;
-        let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
+        let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");

         let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);

-        let concat =
-            graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
+        let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");

-        let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
+        let dot = graph.add_dot([concat], [1, 2]);

         let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
@@ -336,7 +335,7 @@ mod tests {

         let cpx_add = LevelledComplexity::ADDITION;
         let concat =
-            graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
+            graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");

         let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);
@@ -1,27 +1,29 @@
 #![warn(clippy::nursery)]
 #![warn(clippy::pedantic)]
 #![warn(clippy::style)]
+#![allow(clippy::cast_lossless)]
 #![allow(clippy::cast_precision_loss)] // u64 to f64
 #![allow(clippy::cast_possible_truncation)] // u64 to usize
 #![allow(clippy::inline_always)] // needed by delegate
+#![allow(clippy::match_wildcard_for_single_variants)]
 #![allow(clippy::missing_panics_doc)]
 #![allow(clippy::missing_const_for_fn)]
 #![allow(clippy::module_name_repetitions)]
 #![allow(clippy::must_use_candidate)]
 #![allow(clippy::return_self_not_must_use)]
 #![allow(clippy::similar_names)]
 #![allow(clippy::suboptimal_flops)]
 #![allow(clippy::too_many_arguments)]
-#![allow(clippy::match_wildcard_for_single_variants)]
-#![allow(clippy::cast_lossless)]
 #![warn(unused_results)]

 pub mod computing_cost;

+pub mod dag;
 pub mod global_parameters;
-pub mod graph;
 pub mod noise_estimator;
 pub mod optimization;
 pub mod parameters;
 pub mod pareto;
 pub mod security;
 pub mod utils;
 pub mod weight;
@@ -1,3 +1,8 @@
 use concrete_commons::dispersion::DispersionParameter;

+use super::utils;
+use crate::utils::square;
+
 pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 {
     // https://en.wikipedia.org/wiki/Error_function#Applications
     let p_in = 1.0 - p_error;
@@ -9,6 +14,31 @@ pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
     1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
 }

+const LEFT_PADDING_BITS: u64 = 1;
+const RIGHT_PADDING_BITS: u64 = 1;
+
+pub fn fatal_noise_limit(precision: u64, ciphertext_modulus_log: u64) -> f64 {
+    let no_noise_bits = LEFT_PADDING_BITS + precision + RIGHT_PADDING_BITS;
+    let noise_bits = ciphertext_modulus_log - no_noise_bits;
+    2_f64.powi(noise_bits as i32)
+}
+
+pub fn variance_max(
+    precision: u64,
+    ciphertext_modulus_log: u64,
+    maximum_acceptable_error_probability: f64,
+) -> f64 {
+    let fatal_noise_limit = fatal_noise_limit(precision, ciphertext_modulus_log);
+    // We want safe_sigma such that:
+    // P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
+    // <=> P(x not in [-+fatal_noise_limit/safe_sigma] | σ = 1) = p_error
+    // <=> P(x not in [-+kappa] | σ = 1) = p_error, with safe_sigma = fatal_noise_limit / kappa
+    let kappa = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
+    let safe_sigma = fatal_noise_limit / kappa;
+    let modular_variance = square(safe_sigma);
+    utils::from_modular_variance(modular_variance, ciphertext_modulus_log).get_variance()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
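To make the bound concrete, here is a hedged, std-only rendition of the arithmetic in fatal_noise_limit and variance_max; the kappa value is illustrative, and the modular-variance conversion (dividing by q squared) is an assumption matching concrete_commons' convention, not code from this diff:

fn main() {
    let precision: u64 = 8;
    let ciphertext_modulus_log: u64 = 64; // q = 2^64

    // One padding bit on each side of the message, as in fatal_noise_limit().
    let no_noise_bits = 1 + precision + 1;
    let noise_bits = ciphertext_modulus_log - no_noise_bits; // 54 here
    let fatal_noise_limit = 2_f64.powi(noise_bits as i32);

    // kappa is the sigma scale with P(|x| > kappa * sigma) = p_error;
    // 4.0 is an illustrative value, normally computed by
    // sigma_scale_of_error_probability(p_error).
    let kappa = 4.0_f64;
    let safe_sigma = fatal_noise_limit / kappa;
    let modular_variance = safe_sigma * safe_sigma;

    // Assumed conversion: variance = modular_variance / q^2.
    let variance = modular_variance / 2_f64.powi(ciphertext_modulus_log as i32).powi(2);
    println!("fatal_noise_limit = 2^{noise_bits}, safe variance = {variance:e}");
}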
@@ -1,9 +1,8 @@
 use crate::computing_cost::operators::atomic_pattern as complexity_atomic_pattern;
 use crate::computing_cost::operators::keyswitch_lwe::KeySwitchLWEComplexity;
 use crate::computing_cost::operators::pbs::PbsComplexity;
-use crate::noise_estimator::error::{
-    error_probability_of_sigma_scale, sigma_scale_of_error_probability,
-};
+use crate::noise_estimator::error;

 use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
 use crate::parameters::{
     AtomicPatternParameters, BrDecompositionParameters, GlweParameters, KeyswitchParameters,
@@ -11,14 +10,11 @@ use crate::parameters::{
 };
 use crate::pareto;
 use crate::security;
+use crate::utils::square;
 use complexity_atomic_pattern::{AtomicPatternComplexity, DEFAULT as DEFAULT_COMPLEXITY};
 use concrete_commons::dispersion::{DispersionParameter, Variance};
 use concrete_commons::numeric::UnsignedInteger;

-fn square(v: f64) -> f64 {
-    v * v
-}
-
 /* enable to debug */
 const CHECKS: bool = false;
 /* disable to debug */
@@ -27,7 +23,7 @@ const CUTS: bool = true; // 80ms
 const PARETO_CUTS: bool = true; // 75ms
 const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true; // 70ms

-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub struct Solution {
     pub input_lwe_dimension: u64,              //n_big
     pub internal_ks_output_lwe_dimension: u64, //n_small
@@ -38,27 +34,28 @@ pub struct Solution {
     pub br_decomposition_level_count: u64, //l(BR)
     pub br_decomposition_base_log: u64,    //b(BR)
     pub complexity: f64,
+    pub lut_complexity: f64,
     pub noise_max: f64,
     pub p_error: f64, // error probability
 }

-// Constants during optimization of decompositions
-struct OptimizationDecompositionsConsts {
-    kappa: f64,
-    sum_size: u64,
-    security_level: u64,
-    noise_factor: f64,
-    ciphertext_modulus_log: u64,
-    keyswitch_decompositions: Vec<KsDecompositionParameters>,
-    blind_rotate_decompositions: Vec<BrDecompositionParameters>,
-    variance_max: f64,
+// Constants during optimisation of decompositions
+pub(crate) struct OptimizationDecompositionsConsts {
+    pub kappa: f64,
+    pub sum_size: u64,
+    pub security_level: u64,
+    pub noise_factor: f64,
+    pub ciphertext_modulus_log: u64,
+    pub keyswitch_decompositions: Vec<KsDecompositionParameters>,
+    pub blind_rotate_decompositions: Vec<BrDecompositionParameters>,
+    pub safe_variance: f64,
 }

 #[derive(Clone, Copy)]
-struct ComplexityNoise {
-    index: usize,
-    complexity: f64,
-    noise: f64,
+pub(crate) struct ComplexityNoise {
+    pub index: usize,
+    pub complexity: f64,
+    pub noise: f64,
 }

 impl ComplexityNoise {
@@ -69,7 +66,7 @@ impl ComplexityNoise {
     };
 }

-fn blind_rotate_quantities<W: UnsignedInteger>(
+pub(crate) fn pareto_blind_rotate<W: UnsignedInteger>(
     consts: &OptimizationDecompositionsConsts,
     internal_dim: u64,
     glwe_params: GlweParameters,
@@ -107,11 +104,11 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
             variance_bsk,
         );

-        let noise_in = base_noise.get_variance() * square(consts.noise_factor);
-        if cut_noise < noise_in && CUTS {
+        let noise_out = base_noise.get_variance();
+        if cut_noise < noise_out && CUTS {
             continue; // noise is decreasing
         }
-        if decreasing_variance < noise_in && PARETO_CUTS {
+        if decreasing_variance < noise_out && PARETO_CUTS {
             // the current case is dominated
             continue;
         }
@@ -124,14 +121,14 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
         quantities[size] = ComplexityNoise {
             index: i_br,
             complexity: complexity_pbs,
-            noise: noise_in,
+            noise: noise_out,
         };
         assert!(
             0.0 <= delta_complexity,
             "blind_rotate_decompositions should be by increasing complexity"
         );
         increasing_complexity = complexity_pbs;
-        decreasing_variance = noise_in;
+        decreasing_variance = noise_out;
         size += 1;
     }
     assert!(!PARETO_CUTS || size < 64);
@@ -139,7 +136,7 @@ fn blind_rotate_quantities<W: UnsignedInteger>(
     quantities
 }

-fn keyswitch_quantities<W: UnsignedInteger>(
+pub(crate) fn pareto_keyswitch<W: UnsignedInteger>(
     consts: &OptimizationDecompositionsConsts,
     in_dim: u64,
     internal_dim: u64,
@@ -226,8 +223,8 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
         glwe_poly_size,
     )
     .get_variance();
-    let variance_max = consts.variance_max;
-    if CUTS && noise_modulus_switching > variance_max {
+    let safe_variance = consts.safe_variance;
+    if CUTS && noise_modulus_switching > safe_variance {
         return;
     }
@@ -236,9 +233,9 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(

     let complexity_multisum = (consts.sum_size * input_lwe_dimension) as f64;
     let mut cut_complexity = best_complexity - complexity_multisum;
-    let mut cut_noise = variance_max - noise_modulus_switching;
+    let mut cut_noise = safe_variance - noise_modulus_switching;
     let br_quantities =
-        blind_rotate_quantities::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
+        pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
     if br_quantities.is_empty() {
         return;
     }
@@ -246,7 +243,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
         cut_noise -= br_quantities[br_quantities.len() - 1].noise;
         cut_complexity -= br_quantities[0].complexity;
     }
-    let ks_quantities = keyswitch_quantities::<W>(
+    let ks_quantities = pareto_keyswitch::<W>(
         consts,
         input_lwe_dimension,
         internal_dim,
@@ -259,11 +256,12 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(

     let i_max_ks = ks_quantities.len() - 1;
     let mut i_current_max_ks = i_max_ks;
+    let square_noise_factor = square(consts.noise_factor);
     for br_quantity in br_quantities {
         // increasing complexity, decreasing variance
-        let noise_in = br_quantity.noise;
+        let noise_in = br_quantity.noise * square_noise_factor;
         let noise_max = noise_in + noise_modulus_switching;
-        if noise_max > variance_max && CUTS {
+        if noise_max > safe_variance && CUTS {
             continue;
         }
         let complexity_pbs = br_quantity.complexity;
@@ -298,7 +296,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
             );
         }

-        if noise_max > variance_max {
+        if noise_max > safe_variance {
             if CROSS_PARETO_CUTS {
                 // the pareto of 2 added pareto is scanned linearly
                 // but with all cuts, pre-computing => no gain
@@ -316,9 +314,9 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(

         // feasible and at least as good complexity
         if complexity < best_complexity || noise_max < best_variance {
-            let sigma = Variance(variance_max).get_standard_dev() * consts.kappa;
+            let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
             let sigma_scale = sigma / Variance(noise_max).get_standard_dev();
-            let p_error = error_probability_of_sigma_scale(sigma_scale);
+            let p_error = error::error_probability_of_sigma_scale(sigma_scale);

             let i_br = br_quantity.index;
             let i_ks = ks_quantity.index;
@@ -344,6 +342,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
                 br_decomposition_base_log: br_b,
                 noise_max,
                 complexity,
+                lut_complexity: complexity_keyswitch + complexity_pbs,
                 p_error,
             });
         }
@@ -351,7 +350,7 @@ fn update_state_with_best_decompositions<W: UnsignedInteger>(
     } // br ks
 }

-// This function provides reference values with unoptimized code, until we have non regression tests
+// This function provides reference values with unoptimised code, until we have non regression tests
 #[allow(clippy::float_cmp)]
 #[allow(clippy::too_many_lines)]
 fn assert_checks<W: UnsignedInteger>(
@@ -367,7 +366,7 @@ fn assert_checks<W: UnsignedInteger>(
 ) {
     let i_ks = ks_c_n.index;
     let i_br = br_c_n.index;
-    let noise_in = br_c_n.noise;
+    let noise_out = br_c_n.noise;
     let noise_keyswitch = ks_c_n.noise;
     let complexity_keyswitch = ks_c_n.complexity;
     let complexity_pbs = br_c_n.complexity;
@@ -398,7 +397,7 @@ fn assert_checks<W: UnsignedInteger>(
         .complexity(pbs_parameters, ciphertext_modulus_log);

     assert_eq!(complexity_pbs, complexity_pbs_);
-    assert_eq!(noise_in, noise_in_);
+    assert_eq!(noise_out * square(consts.noise_factor), noise_in_);
     let variance_ksk =
         noise_atomic_pattern::variance_ksk(internal_dim, ciphertext_modulus_log, security_level);
@@ -429,7 +428,7 @@ fn assert_checks<W: UnsignedInteger>(
     };

     let check_max_noise = noise_atomic_pattern::maximal_noise::<Variance, W>(
-        Variance(noise_in),
+        Variance(noise_in_),
         atomic_pattern_parameters,
         ciphertext_modulus_log,
         security_level,
@@ -452,8 +451,6 @@ fn assert_checks<W: UnsignedInteger>(
     assert!(diff_complexity < 0.0001);
 }

-const BITS_CARRY: u64 = 1;
-const BITS_PADDING_WITHOUT_NOISE: u64 = 1;
 const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;

 #[allow(clippy::too_many_lines)]
@@ -479,16 +476,12 @@ pub fn optimize_one<W: UnsignedInteger>(
     // the blind rotate decomposition

     let ciphertext_modulus_log = W::BITS as u64;

-    let no_noise_bits = BITS_CARRY + precision + BITS_PADDING_WITHOUT_NOISE;
-    let noise_bits = ciphertext_modulus_log - no_noise_bits;
-    let fatal_noise_limit = (1_u64 << noise_bits) as f64;
-
-    // Now we search for P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
-    // P(x not in [-+kappa] | σ = 1) = p_error
-    let kappa: f64 = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
-    let safe_sigma = fatal_noise_limit / kappa;
-    let variance_max = Variance::from_modular_variance::<W>(square(safe_sigma));
+    let safe_variance = error::variance_max(
+        precision,
+        ciphertext_modulus_log,
+        maximum_acceptable_error_probability,
+    );
+    let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);

     let consts = OptimizationDecompositionsConsts {
         kappa,
@@ -502,7 +495,7 @@ pub fn optimize_one<W: UnsignedInteger>(
         blind_rotate_decompositions: pareto::BR_BL
             .map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
             .to_vec(),
-        variance_max: variance_max.get_variance(),
+        safe_variance,
     };

     let mut state = OptimizationState {
@@ -524,7 +517,7 @@ pub fn optimize_one<W: UnsignedInteger>(
             glwe_poly_size,
         )
         .get_variance()
-            > consts.variance_max
+            > consts.safe_variance
     };

     let skip = |glwe_dim, glwe_poly_size| match restart_at {
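The CROSS_PARETO_CUTS walk above pairs two pareto fronts, each ordered by increasing complexity and decreasing noise, and scans the keyswitch front from its least-noisy end while remembering where feasibility stopped. A self-contained sketch of that scan pattern with illustrative data (not the crate's types):

#[derive(Clone, Copy)]
struct ComplexityNoise {
    complexity: f64,
    noise: f64,
}

// Both fronts are ordered by increasing complexity / decreasing noise.
// For each blind-rotate point, scan keyswitch points from the least noisy
// (index i_current_max_ks) downwards; the first infeasible point cuts the
// scan, since all remaining points are noisier still.
fn best_feasible(br: &[ComplexityNoise], ks: &[ComplexityNoise], noise_budget: f64) -> Option<f64> {
    let mut best: Option<f64> = None;
    let mut i_current_max_ks = ks.len() - 1;
    for b in br {
        for i in (0..=i_current_max_ks).rev() {
            let k = ks[i];
            if b.noise + k.noise > noise_budget {
                // Remember the boundary for the next (quieter) br point.
                i_current_max_ks = usize::min(i + 1, ks.len() - 1);
                break;
            }
            let cost = b.complexity + k.complexity;
            if best.map_or(true, |c| cost < c) {
                best = Some(cost);
            }
        }
    }
    best
}

fn main() {
    let br = [
        ComplexityNoise { complexity: 1.0, noise: 8.0 },
        ComplexityNoise { complexity: 2.0, noise: 4.0 },
    ];
    let ks = [
        ComplexityNoise { complexity: 1.0, noise: 6.0 },
        ComplexityNoise { complexity: 3.0, noise: 2.0 },
    ];
    // br[0]+ks[1] is feasible at cost 4; br[1]+ks[0] is feasible at cost 3.
    assert_eq!(best_feasible(&br, &ks, 10.0), Some(3.0));
}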

concrete-optimizer/src/optimization/dag/mod.rs (new file, 1 line)

@@ -0,0 +1 @@
pub mod solo_key;

concrete-optimizer/src/optimization/dag/solo_key/analyze.rs (new file, 587 lines)

@@ -0,0 +1,587 @@
use super::symbolic_variance::{SymbolicVariance, VarianceOrigin};
use crate::dag::operator::{
    dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape,
};
use crate::dag::unparametrized;
use crate::utils::square;

// private short convention
use DotKind as DK;
use VarianceOrigin as VO;
type Op = unparametrized::UnparameterizedOperator;

fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property {
    &properties[inputs[0].i]
}

fn assert_all_same<Property: PartialEq + std::fmt::Debug>(
    inputs: &[OperatorIndex],
    properties: &[Property],
) {
    let first = first(inputs, properties);
    for input in inputs.iter().skip(1) {
        assert_eq!(first, &properties[input.i]);
    }
}

fn assert_inputs_uniform_precisions(
    op: &unparametrized::UnparameterizedOperator,
    out_precisions: &[Precision],
) {
    if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
        assert_all_same(inputs, out_precisions);
    }
}

fn assert_dot_uniform_inputs_shape(
    op: &unparametrized::UnparameterizedOperator,
    out_shapes: &[Shape],
) {
    if let Op::Dot { inputs, .. } = op {
        assert_all_same(inputs, out_shapes);
    }
}

fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
    if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
        assert!(!inputs.is_empty());
    }
}

fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
    for op in &dag.operators {
        assert_non_empty_inputs(op);
    }
}

fn assert_valid_variances(dag: &OperationDag) {
    for &out_variance in &dag.out_variances {
        assert!(
            SymbolicVariance::ZERO == out_variance // Special case of multiply by 0
                || 1.0 <= out_variance.input_vf
                || 1.0 <= out_variance.lut_vf
        );
    }
}

fn assert_properties_correctness(dag: &OperationDag) {
    for op in &dag.operators {
        assert_inputs_uniform_precisions(op, &dag.out_precisions);
        assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
    }
    assert_valid_variances(dag);
}

fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) -> VarianceOrigin {
    let first_origin = first(inputs, out_variances).origin();
    for input in inputs.iter().skip(1) {
        let item = &out_variances[input.i];
        if first_origin != item.origin() {
            return VO::Mixed;
        }
    }
    first_origin
}
#[derive(Clone, Debug)]
pub struct OperationDag {
    pub operators: Vec<Op>,
    // Collect all operators output shape
    pub out_shapes: Vec<Shape>,
    // Collect all operators output precision
    pub out_precisions: Vec<Precision>,
    // Collect all operators output variances
    pub out_variances: Vec<SymbolicVariance>,
    pub nb_luts: u64,
    // The full dag levelled complexity
    pub levelled_complexity: LevelledComplexity,
    // Global summaries of worst noise cases
    pub noise_summary: NoiseSummary,
}

#[derive(Clone, Debug)]
pub struct NoiseSummary {
    // All final variance factors not entering a lut (usually a final levelledOp)
    pub pareto_vfs_final: Vec<SymbolicVariance>,
    // All variance factors entering a lut
    pub pareto_vfs_in_lut: Vec<SymbolicVariance>,
}
impl OperationDag {
    pub fn peek_variance(
        &self,
        input_noise_out: f64,
        blind_rotate_noise_out: f64,
        noise_keyswitch: f64,
        noise_modulus_switching: f64,
    ) -> f64 {
        peek_variance(
            self,
            input_noise_out,
            blind_rotate_noise_out,
            noise_keyswitch,
            noise_modulus_switching,
        )
    }

    pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
        let luts_cost = one_lut_cost * (self.nb_luts as f64);
        let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
        luts_cost + levelled_cost
    }
}

fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
    match op {
        Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(),
        Op::Lut { input, .. } => out_shapes[input.i].clone(),
        Op::Dot {
            inputs, weights, ..
        } => {
            if inputs.is_empty() {
                return Shape::number();
            }
            let input_shape = first(inputs, out_shapes);
            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
            match kind {
                DK::Simple | DK::Tensor => Shape::number(),
                DK::CompatibleTensor => weights.shape.clone(),
                DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
                DK::Unsupported { .. } => panic!("Unsupported"),
            }
        }
    }
}

fn out_shapes(dag: &unparametrized::OperationDag) -> Vec<Shape> {
    let nb_ops = dag.operators.len();
    let mut out_shapes = Vec::<Shape>::with_capacity(nb_ops);
    for op in &dag.operators {
        let shape = out_shape(op, &mut out_shapes);
        out_shapes.push(shape);
    }
    out_shapes
}

fn out_precision(
    op: &unparametrized::UnparameterizedOperator,
    out_precisions: &mut [Precision],
) -> Precision {
    match op {
        Op::Input { out_precision, .. } => *out_precision,
        Op::Lut { input, .. } => out_precisions[input.i],
        Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
    }
}

fn out_precisions(dag: &unparametrized::OperationDag) -> Vec<Precision> {
    let nb_ops = dag.operators.len();
    let mut out_precisions = Vec::<Precision>::with_capacity(nb_ops);
    for op in &dag.operators {
        let precision = out_precision(op, &mut out_precisions);
        out_precisions.push(precision);
    }
    out_precisions
}

fn out_variance(
    op: &unparametrized::UnparameterizedOperator,
    out_shapes: &[Shape],
    out_variances: &mut [SymbolicVariance],
) -> SymbolicVariance {
    // Maintain a linear combination of input_variance and lut_out_variance
    // TODO: track each element instead of the container
    match op {
        Op::Input { .. } => SymbolicVariance::INPUT,
        Op::Lut { .. } => SymbolicVariance::LUT,
        Op::LevelledOp { inputs, manp, .. } => {
            let variance_factor = SymbolicVariance::manp_to_variance_factor(*manp);
            let origin = match variance_origin(inputs, out_variances) {
                VO::Input => SymbolicVariance::INPUT,
                VO::Lut | VO::Mixed /* Mixed: assume the worst */
                    => SymbolicVariance::LUT
            };
            origin * variance_factor
        }
        Op::Dot {
            inputs, weights, ..
        } => {
            let input_shape = first(inputs, out_shapes);
            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
            match kind {
                DK::Simple | DK::Tensor => {
                    let first_input = inputs[0];
                    let mut out_variance = SymbolicVariance::ZERO;
                    for (j, &weight) in weights.values.iter().enumerate() {
                        let k = if kind == DK::Simple {
                            inputs[j].i
                        } else {
                            first_input.i
                        };
                        out_variance += out_variances[k] * square(weight);
                    }
                    out_variance
                }
                DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
                DK::Unsupported { .. } => panic!("Unsupported"),
            }
        }
    }
}

fn out_variances(
    dag: &unparametrized::OperationDag,
    out_shapes: &[Shape],
) -> Vec<SymbolicVariance> {
    let nb_ops = dag.operators.len();
    let mut out_variances = Vec::with_capacity(nb_ops);
    for op in &dag.operators {
        let vf = out_variance(op, out_shapes, &mut out_variances);
        out_variances.push(vf);
    }
    out_variances
}

fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec<bool> {
    let nb_ops = dag.operators.len();
    let mut extra_values_to_check = vec![true; nb_ops];
    for op in &dag.operators {
        match op {
            Op::Input { .. } => (),
            Op::Lut { input, .. } => {
                extra_values_to_check[input.i] = false;
            }
            Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
                for input in inputs {
                    extra_values_to_check[input.i] = false;
                }
            }
        }
    }
    extra_values_to_check
}

fn extra_final_variances(
    dag: &unparametrized::OperationDag,
    out_variances: &[SymbolicVariance],
) -> Vec<SymbolicVariance> {
    extra_final_values_to_check(dag)
        .iter()
        .enumerate()
        .filter_map(|(i, &is_final)| {
            if is_final {
                Some(out_variances[i])
            } else {
                None
            }
        })
        .collect()
}

fn in_luts_variance(
    dag: &unparametrized::OperationDag,
    out_variances: &[SymbolicVariance],
) -> Vec<SymbolicVariance> {
    let only_luts = |op| {
        if let &Op::Lut { input, .. } = op {
            Some(out_variances[input.i])
        } else {
            None
        }
    };
    dag.operators.iter().filter_map(only_luts).collect()
}

fn op_levelled_complexity(
    op: &unparametrized::UnparameterizedOperator,
    out_shapes: &[Shape],
) -> LevelledComplexity {
    match op {
        Op::Dot {
            inputs, weights, ..
        } => {
            let input_shape = first(inputs, out_shapes);
            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
            match kind {
                DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(),
                DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
                DK::Unsupported { .. } => panic!("Unsupported"),
            }
        }
        Op::LevelledOp { complexity, .. } => *complexity,
        Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO,
    }
}

fn levelled_complexity(
    dag: &unparametrized::OperationDag,
    out_shapes: &[Shape],
) -> LevelledComplexity {
    let mut levelled_complexity = LevelledComplexity::ZERO;
    for op in &dag.operators {
        levelled_complexity += op_levelled_complexity(op, out_shapes);
    }
    levelled_complexity
}

fn max_update(current: &mut f64, candidate: f64) {
    if candidate > *current {
        *current = candidate;
    }
}

fn noise_summary(
    final_variances: Vec<SymbolicVariance>,
    in_luts_variance: Vec<SymbolicVariance>,
) -> NoiseSummary {
    let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(final_variances);
    let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance);
    NoiseSummary {
        pareto_vfs_final,
        pareto_vfs_in_lut,
    }
}

pub fn analyze(dag: &unparametrized::OperationDag) -> OperationDag {
    assert_dag_correctness(dag);
    let out_shapes = out_shapes(dag);
    let out_precisions = out_precisions(dag);
    let out_variances = out_variances(dag, &out_shapes);
    let in_luts_variance = in_luts_variance(dag, &out_variances);
    let nb_luts = in_luts_variance.len() as u64;
    let extra_final_variances = extra_final_variances(dag, &out_variances);
    let levelled_complexity = levelled_complexity(dag, &out_shapes);
    let noise_summary = noise_summary(extra_final_variances, in_luts_variance);
    let result = OperationDag {
        operators: dag.operators.clone(),
        out_shapes,
        out_precisions,
        out_variances,
        nb_luts,
        levelled_complexity,
        noise_summary,
    };
    assert_properties_correctness(&result);
    result
}

// Compute the maximum attained variance for the full dag
// TODO: take a noise summary => peek_error or global error
fn peek_variance(
    dag: &OperationDag,
    input_noise_out: f64,
    blind_rotate_noise_out: f64,
    noise_keyswitch: f64,
    noise_modulus_switching: f64,
) -> f64 {
    assert!(input_noise_out < blind_rotate_noise_out);
    let mut variance_peek_final = 0.0; // updated by the loop
    for vf in &dag.noise_summary.pareto_vfs_final {
        max_update(
            &mut variance_peek_final,
            vf.eval(input_noise_out, blind_rotate_noise_out),
        );
    }

    let mut variance_peek_in_lut = 0.0; // updated by the loop
    for vf in &dag.noise_summary.pareto_vfs_in_lut {
        max_update(
            &mut variance_peek_in_lut,
            vf.eval(input_noise_out, blind_rotate_noise_out),
        );
    }
    let peek_in_lut = variance_peek_in_lut + noise_keyswitch + noise_modulus_switching;
    peek_in_lut.max(variance_peek_final)
}
#[cfg(test)]
mod tests {

    use super::*;
    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
    use crate::dag::unparametrized;
    use crate::utils::square;

    fn assert_f64_eq(v: f64, expected: f64) {
        approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_1_input() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(1, Shape::number());
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
        assert_eq!(analysis.out_shapes[input1.i], Shape::number());
        assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO);
        assert_eq!(analysis.out_precisions[input1.i], 1);
        assert_f64_eq(complexity_cost, 0.0);
        assert!(analysis.nb_luts == 0);
        let summary = analysis.noise_summary;
        assert!(summary.pareto_vfs_final.len() == 1);
        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 1.0);
        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
        assert!(summary.pareto_vfs_in_lut.is_empty());
    }

    #[test]
    fn test_1_lut() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(8, Shape::number());
        let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
        assert!(analysis.out_shapes[lut1.i] == Shape::number());
        assert!(analysis.levelled_complexity == LevelledComplexity::ZERO);
        assert_eq!(analysis.out_precisions[lut1.i], 8);
        assert_f64_eq(one_lut_cost, complexity_cost);
        let summary = analysis.noise_summary;
        assert!(summary.pareto_vfs_final.len() == 1);
        assert!(summary.pareto_vfs_in_lut.len() == 1);
        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 0.0);
        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
        assert_f64_eq(summary.pareto_vfs_in_lut[0].input_vf, 1.0);
        assert_f64_eq(summary.pareto_vfs_in_lut[0].lut_vf, 0.0);
    }

    #[test]
    fn test_1_dot() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(1, Shape::number());
        let weights = Weights::vector([1, 2]);
        let norm2: f64 = 1.0 * 1.0 + 2.0 * 2.0;
        let dot = graph.add_dot([input1, input1], weights);
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        let expected_var = SymbolicVariance {
            input_vf: norm2,
            lut_vf: 0.0,
        };
        assert!(analysis.out_variances[dot.i] == expected_var);
        assert!(analysis.out_shapes[dot.i] == Shape::number());
        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2);
        assert_eq!(analysis.out_precisions[dot.i], 1);
        let expected_dot_cost = (2 * lwe_dim) as f64;
        assert_f64_eq(expected_dot_cost, complexity_cost);
        let summary = analysis.noise_summary;
        assert!(summary.pareto_vfs_in_lut.is_empty());
        assert!(summary.pareto_vfs_final.len() == 1);
        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
    }

    #[test]
    fn test_1_dot_levelled() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(3, Shape::number());
        let cpx_dot = LevelledComplexity::ADDITION;
        let weights = Weights::vector([1, 2]);
        let manp = 1.0 * 1.0 + 2.0 * 2_f64;
        let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot");
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        assert!(analysis.out_variances[dot.i].origin() == VO::Input);
        assert_eq!(analysis.out_precisions[dot.i], 3);
        let expected_square_norm2 = weights.square_norm2() as f64;
        let actual_square_norm2 = analysis.out_variances[dot.i].input_vf;
        // The result is not exact because manp is computed through log2()
        assert_f64_eq(actual_square_norm2, expected_square_norm2);
        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION);
        assert_f64_eq(lwe_dim as f64, complexity_cost);
        let summary = analysis.noise_summary;
        assert!(summary.pareto_vfs_in_lut.is_empty());
        assert!(summary.pareto_vfs_final.len() == 1);
        assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Input);
        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
    }

    #[test]
    fn test_dot_tensorized_lut_dot_lut() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(1, Shape::vector(2));
        let weights = &Weights::vector([1, 2]);
        let dot1 = graph.add_dot([input1], weights);
        let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
        let dot2 = graph.add_dot([lut1, lut1], weights);
        let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN);
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        let expected_var_dot1 = SymbolicVariance {
            input_vf: weights.square_norm2() as f64,
            lut_vf: 0.0,
        };
        let expected_var_lut1 = SymbolicVariance {
            input_vf: 0.0,
            lut_vf: 1.0,
        };
        let expected_var_dot2 = SymbolicVariance {
            input_vf: 0.0,
            lut_vf: weights.square_norm2() as f64,
        };
        let expected_var_lut2 = SymbolicVariance {
            input_vf: 0.0,
            lut_vf: 1.0,
        };
        assert!(analysis.out_variances[dot1.i] == expected_var_dot1);
        assert!(analysis.out_variances[lut1.i] == expected_var_lut1);
        assert!(analysis.out_variances[dot2.i] == expected_var_dot2);
        assert!(analysis.out_variances[lut2.i] == expected_var_lut2);
        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 4);
        let expected_cost = (lwe_dim * 4) as f64 + 2.0 * one_lut_cost;
        assert_f64_eq(expected_cost, complexity_cost);
        let summary = analysis.noise_summary;
        assert_eq!(summary.pareto_vfs_final.len(), 1);
        assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Lut);
        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
        assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
        assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Lut);
        assert_f64_eq(
            summary.pareto_vfs_in_lut[0].lut_vf,
            weights.square_norm2() as f64,
        );
    }

    #[test]
    fn test_lut_dot_mixed_lut() {
        let mut graph = unparametrized::OperationDag::new();
        let input1 = graph.add_input(1, Shape::number());
        let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
        let weights = &Weights::vector([2, 3]);
        let dot1 = graph.add_dot([input1, lut1], weights);
        let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
        let analysis = analyze(&graph);
        let one_lut_cost = 100.0;
        let lwe_dim = 1024;
        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);

        let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost;
        assert_f64_eq(expected_cost, complexity_cost);
        let expected_mixed = SymbolicVariance {
            input_vf: square(weights.values[0] as f64),
            lut_vf: square(weights.values[1] as f64),
        };
        let summary = analysis.noise_summary;
        assert_eq!(summary.pareto_vfs_final.len(), 1);
        assert_eq!(summary.pareto_vfs_final[0], SymbolicVariance::LUT);
        assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
        assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Mixed);
        assert_eq!(summary.pareto_vfs_in_lut[0], expected_mixed);
    }
}
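symbolic_variance.rs is not part of this excerpt; the sketch below assumes SymbolicVariance evaluates as the linear form input_vf·(input variance) + lut_vf·(post-blind-rotate variance), which is consistent with how peek_variance and the tests above use it (a hypothetical mirror, not the crate's code):

#[derive(Clone, Copy, Debug, PartialEq)]
struct SymbolicVariance {
    input_vf: f64, // factor on the fresh-input variance
    lut_vf: f64,   // factor on the post-blind-rotate variance
}

impl SymbolicVariance {
    fn eval(self, input_noise_out: f64, blind_rotate_noise_out: f64) -> f64 {
        self.input_vf * input_noise_out + self.lut_vf * blind_rotate_noise_out
    }
}

fn main() {
    // dot([input, lut], [2, 3]) accumulates squared weights per origin,
    // as in test_lut_dot_mixed_lut: input_vf = 2^2, lut_vf = 3^2.
    let mixed = SymbolicVariance { input_vf: 4.0, lut_vf: 9.0 };
    let (input_noise, br_noise) = (1e-16_f64, 4e-16_f64);
    // Worst-case variance entering the next lut, before the keyswitch
    // and modulus-switching terms are added on top.
    assert_eq!(mixed.eval(input_noise, br_noise), 4.0 * 1e-16 + 9.0 * 4e-16);
}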

concrete-optimizer/src/optimization/dag/solo_key/mod.rs (new file, 3 lines)

@@ -0,0 +1,3 @@
pub(crate) mod analyze;
pub mod optimize;
pub(crate) mod symbolic_variance;

concrete-optimizer/src/optimization/dag/solo_key/optimize.rs (new file, 590 lines)

@@ -0,0 +1,590 @@
use concrete_commons::dispersion::{DispersionParameter, Variance};
|
||||
use concrete_commons::numeric::UnsignedInteger;
|
||||
|
||||
use crate::dag::operator::LevelledComplexity;
|
||||
use crate::dag::unparametrized;
|
||||
use crate::noise_estimator::error;
|
||||
use crate::noise_estimator::error::error_probability_of_sigma_scale;
|
||||
use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
|
||||
|
||||
use crate::optimization::atomic_pattern::{
|
||||
pareto_blind_rotate, pareto_keyswitch, OptimizationDecompositionsConsts, OptimizationState,
|
||||
Solution,
|
||||
};
|
||||
|
||||
use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositionParameters};
|
||||
use crate::pareto;
|
||||
use crate::security::glwe::minimal_variance;
|
||||
use crate::utils::square;
|
||||
|
||||
use super::analyze;
|
||||
|
||||
const CUTS: bool = true;
|
||||
const PARETO_CUTS: bool = true;
|
||||
const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn update_best_solution_with_best_decompositions<W: UnsignedInteger>(
|
||||
state: &mut OptimizationState,
|
||||
consts: &OptimizationDecompositionsConsts,
|
||||
dag: &analyze::OperationDag,
|
||||
internal_dim: u64,
|
||||
glwe_params: GlweParameters,
|
||||
noise_modulus_switching: f64,
|
||||
) {
|
||||
let safe_variance = consts.safe_variance;
|
||||
let glwe_poly_size = glwe_params.polynomial_size();
|
||||
let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
|
||||
|
||||
let mut best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
|
||||
let mut best_lut_complexity = state
|
||||
.best_solution
|
||||
.map_or(f64::INFINITY, |s| s.lut_complexity);
|
||||
let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
|
||||
|
||||
let mut cut_complexity =
|
||||
(best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64);
|
||||
let mut cut_noise = safe_variance - noise_modulus_switching;
|
||||
|
||||
if dag.nb_luts == 0 {
|
||||
cut_noise = f64::INFINITY;
|
||||
cut_complexity = f64::INFINITY;
|
||||
}
|
||||
|
||||
let br_pareto =
|
||||
pareto_blind_rotate::<W>(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
|
||||
if br_pareto.is_empty() {
|
||||
return;
|
||||
}
|
||||
if PARETO_CUTS {
|
||||
cut_noise -= br_pareto[br_pareto.len() - 1].noise;
|
||||
cut_complexity -= br_pareto[0].complexity;
|
||||
}
|
||||
|
||||
let ks_pareto = pareto_keyswitch::<W>(
|
||||
consts,
|
||||
input_lwe_dimension,
|
||||
internal_dim,
|
||||
cut_complexity,
|
||||
cut_noise,
|
||||
);
|
||||
if ks_pareto.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let i_max_ks = ks_pareto.len() - 1;
|
||||
let mut i_current_max_ks = i_max_ks;
|
||||
let input_noise_out = minimal_variance(
|
||||
glwe_params,
|
||||
        consts.ciphertext_modulus_log,
        consts.security_level,
    )
    .get_variance();

    let mut best_br_i = 0;
    let mut best_ks_i = 0;
    let mut update_best_solution = false;

    for br_quantity in br_pareto {
        // increasing complexity, decreasing variance
        let peek_variance = dag.peek_variance(
            input_noise_out,
            br_quantity.noise,
            0.0,
            noise_modulus_switching,
        );
        if peek_variance > safe_variance && CUTS {
            continue;
        }
        let one_pbs_cost = br_quantity.complexity;
        let complexity = dag.complexity_cost(input_lwe_dimension, one_pbs_cost);
        if complexity > best_complexity {
            // As best_complexity can evolve, this cut is complementary to the
            // blind_rotate_quantities cuts.
            if PARETO_CUTS {
                break;
            } else if CUTS {
                continue;
            }
        }
        for i_ks_pareto in (0..=i_current_max_ks).rev() {
            // increasing variance, decreasing complexity
            let ks_quantity = ks_pareto[i_ks_pareto];
            let peek_variance = dag.peek_variance(
                input_noise_out,
                br_quantity.noise,
                ks_quantity.noise,
                noise_modulus_switching,
            );
            // let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching;
            if peek_variance > safe_variance {
                if CROSS_PARETO_CUTS {
                    // The pareto front of the sum of the 2 pareto fronts is scanned linearly;
                    // with all cuts enabled, pre-computing it brings no gain.
                    i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks);
                    break;
                    // this bound stays compatible with the next i_br, at worst complexity
                } else if PARETO_CUTS {
                    // increasing variance => we can skip all remaining
                    break;
                }
                continue;
            }
            let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
            let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);

            let better_complexity = complexity < best_complexity;
            #[allow(clippy::float_cmp)]
            let same_complexity_with_less_errors =
                complexity == best_complexity && peek_variance < best_variance;
            if better_complexity || same_complexity_with_less_errors {
                best_lut_complexity = one_lut_cost;
                best_complexity = complexity;
                best_variance = peek_variance;
                best_br_i = br_quantity.index;
                best_ks_i = ks_quantity.index;
                update_best_solution = true;
            }
        }
    } // br ks

    if update_best_solution {
        let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
        let sigma_scale = sigma / Variance(best_variance).get_standard_dev();
        let p_error = error_probability_of_sigma_scale(sigma_scale);
        let BrDecompositionParameters {
            level: br_l,
            log2_base: br_b,
        } = consts.blind_rotate_decompositions[best_br_i];
        let KsDecompositionParameters {
            level: ks_l,
            log2_base: ks_b,
        } = consts.keyswitch_decompositions[best_ks_i];
        state.best_solution = Some(Solution {
            input_lwe_dimension,
            internal_ks_output_lwe_dimension: internal_dim,
            ks_decomposition_level_count: ks_l,
            ks_decomposition_base_log: ks_b,
            glwe_polynomial_size: glwe_params.polynomial_size(),
            glwe_dimension: glwe_params.glwe_dimension,
            br_decomposition_level_count: br_l,
            br_decomposition_base_log: br_b,
            noise_max: best_variance,
            complexity: best_complexity,
            lut_complexity: best_lut_complexity,
            p_error,
        });
    }
}
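
// Sketch (illustration only, not from this commit): `sigma_scale` above is the
// target `kappa` rescaled by how much the found variance beats the safe bound,
// since a standard deviation is the square root of a variance. With the
// 4-sigma target used in the tests below (kappa ~= 4), finding a
// `best_variance` of `safe_variance / 4` doubles the scale to ~8, hence a much
// smaller reported `p_error`.
#[allow(dead_code)]
fn sigma_scale_sketch(kappa: f64, safe_variance: f64, best_variance: f64) -> f64 {
    kappa * (safe_variance / best_variance).sqrt()
}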

const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;

#[allow(clippy::too_many_lines)]
pub fn optimize<W: UnsignedInteger>(
    dag: &unparametrized::OperationDag,
    security_level: u64,
    maximum_acceptable_error_probability: f64,
    glwe_log_polynomial_sizes: &[u64],
    glwe_dimensions: &[u64],
    internal_lwe_dimensions: &[u64],
) -> OptimizationState {
    let ciphertext_modulus_log = W::BITS as u64;
    let dag = analyze::analyze(dag);

    let &max_precision = dag.out_precisions.iter().max().unwrap();

    let safe_variance = error::variance_max(
        max_precision as u64,
        ciphertext_modulus_log,
        maximum_acceptable_error_probability,
    );
    let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);

    let consts = OptimizationDecompositionsConsts {
        kappa,
        sum_size: 0, // superseded by dag.complexity_cost
        security_level,
        noise_factor: f64::NAN, // superseded by dag.lut_variance_max
        ciphertext_modulus_log,
        keyswitch_decompositions: pareto::KS_BL
            .map(|(log2_base, level)| KsDecompositionParameters { level, log2_base })
            .to_vec(),
        blind_rotate_decompositions: pareto::BR_BL
            .map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
            .to_vec(),
        safe_variance,
    };

    let mut state = OptimizationState {
        best_solution: None,
        count_domain: glwe_dimensions.len()
            * glwe_log_polynomial_sizes.len()
            * internal_lwe_dimensions.len()
            * consts.keyswitch_decompositions.len()
            * consts.blind_rotate_decompositions.len(),
    };

    let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
        noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::<W>(
            internal_lwe_dimensions,
            glwe_poly_size,
        )
        .get_variance()
    };

    for &glwe_dim in glwe_dimensions {
        for &glwe_log_poly_size in glwe_log_polynomial_sizes {
            let glwe_poly_size = 1 << glwe_log_poly_size;
            let glwe_params = GlweParameters {
                log2_polynomial_size: glwe_log_poly_size,
                glwe_dimension: glwe_dim,
            };
            for &internal_dim in internal_lwe_dimensions {
                let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim);
                if CUTS && noise_modulus_switching > consts.safe_variance {
                    // assume this noise is increasing with internal_dim
                    break;
                }
                update_best_solution_with_best_decompositions::<W>(
                    &mut state,
                    &consts,
                    &dag,
                    internal_dim,
                    glwe_params,
                    noise_modulus_switching,
                );
                if dag.nb_luts == 0 && state.best_solution.is_some() {
                    return state;
                }
            }
        }
    }

    if let Some(sol) = state.best_solution {
        assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
        assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA);
    }

    state
}

pub fn optimize_v0<W: UnsignedInteger>(
    sum_size: u64,
    precision: u64,
    security_level: u64,
    noise_factor: f64,
    maximum_acceptable_error_probability: f64,
    glwe_log_polynomial_sizes: &[u64],
    glwe_dimensions: &[u64],
    internal_lwe_dimensions: &[u64],
) -> OptimizationState {
    use crate::dag::operator::{FunctionTable, Shape};
    let same_scale_manp = 0.0;
    let manp = square(noise_factor);
    let out_shape = &Shape::number();
    let complexity = LevelledComplexity::ADDITION * sum_size;
    let comment = "dot";
    let mut dag = unparametrized::OperationDag::new();
    let input1 = dag.add_input(precision as u8, out_shape);
    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
    let mut state = optimize::<W>(
        &dag,
        security_level,
        maximum_acceptable_error_probability,
        glwe_log_polynomial_sizes,
        glwe_dimensions,
        internal_lwe_dimensions,
    );
    if let Some(sol) = &mut state.best_solution {
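        // the dag above contains 2 luts, while the reference v0 cost counts a
        // single one (the tests compensate the other way around, see
        // `sol_ref.complexity *= 2.0` below)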
        sol.complexity /= 2.0;
    }
    state
}

#[cfg(test)]
mod tests {
    use std::time::Instant;

    use crate::dag::operator::{FunctionTable, Shape, Weights};
    use crate::global_parameters::DEFAUT_DOMAINS;
    use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
    use crate::utils::square;

    use super::*;
    use crate::optimization::atomic_pattern;

    fn small_relative_diff(v1: f64, v2: f64) -> bool {
        f64::abs(v1 - v2) / f64::max(v1, v2) <= f64::EPSILON
    }

    impl Solution {
        fn same(&self, other: Self) -> bool {
            let mut other = other;
            if small_relative_diff(self.noise_max, other.noise_max)
                && small_relative_diff(self.p_error, other.p_error)
            {
                other.noise_max = self.noise_max;
                other.p_error = self.p_error;
            }
            self == &other
        }
    }

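    // probability of a sample falling outside ±4 standard deviations of a gaussian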
    const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;

    struct Times {
        worst_time: u128,
        dag_time: u128,
    }

    fn assert_f64_eq(v: f64, expected: f64) {
        approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
    }

    #[test]
    fn test_v0_parameter_ref() {
        let mut times = Times {
            worst_time: 0,
            dag_time: 0,
        };
        for log_weight in 0..=16 {
            let weight = 1 << log_weight;
            for precision in 1..=9 {
                v0_parameter_ref(precision, weight, &mut times);
            }
        }
        assert!(times.worst_time * 2 > times.dag_time);
    }

    fn v0_parameter_ref(precision: u64, weight: u64, times: &mut Times) {
        let security_level = 128;
        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
            .glwe_pbs_constrained
            .log2_polynomial_size
            .as_vec();
        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
        let sum_size = 1;
        let maximum_acceptable_error_probability = _4_SIGMA;

        let chrono = Instant::now();
        let state = optimize_v0::<u64>(
            sum_size,
            precision,
            security_level,
            weight as f64,
            maximum_acceptable_error_probability,
            &glwe_log_polynomial_sizes,
            &glwe_dimensions,
            &internal_lwe_dimensions,
        );
        times.dag_time += chrono.elapsed().as_nanos();
        let chrono = Instant::now();
        let state_ref = atomic_pattern::optimize_one::<u64>(
            sum_size,
            precision,
            security_level,
            weight as f64,
            maximum_acceptable_error_probability,
            &glwe_log_polynomial_sizes,
            &glwe_dimensions,
            &internal_lwe_dimensions,
            None,
        );
        times.worst_time += chrono.elapsed().as_nanos();
        assert_eq!(
            state.best_solution.is_some(),
            state_ref.best_solution.is_some()
        );
        if state.best_solution.is_none() {
            return;
        }
        let sol = state.best_solution.unwrap();
        let sol_ref = state_ref.best_solution.unwrap();
        assert!(sol.same(sol_ref));
    }

    #[test]
    fn test_v0_parameter_ref_with_dot() {
        for log_weight in 0..=16 {
            let weight = 1 << log_weight;
            for precision in 1..=9 {
                v0_parameter_ref_with_dot(precision, weight);
            }
        }
    }

    fn v0_parameter_ref_with_dot(precision: u64, weight: u64) {
        let mut dag = unparametrized::OperationDag::new();
        {
            let input1 = dag.add_input(precision as u8, Shape::number());
            let dot1 = dag.add_dot([input1], [1]);
            let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
            let dot2 = dag.add_dot([lut1], [weight]);
            let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
        }
        {
            let dag2 = analyze::analyze(&dag);
            let summary = dag2.noise_summary;
            assert_eq!(summary.pareto_vfs_final.len(), 1);
            assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
            assert_eq!(summary.pareto_vfs_final[0].origin(), VarianceOrigin::Lut);
            assert_f64_eq(1.0, summary.pareto_vfs_final[0].lut_vf);
            assert!(summary.pareto_vfs_in_lut.len() == 1);
            assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VarianceOrigin::Lut);
            assert_f64_eq(square(weight) as f64, summary.pareto_vfs_in_lut[0].lut_vf);
        }

        let security_level = 128;
        let maximum_acceptable_error_probability = _4_SIGMA;
        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
            .glwe_pbs_constrained
            .log2_polynomial_size
            .as_vec();
        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
        let state = optimize::<u64>(
            &dag,
            security_level,
            maximum_acceptable_error_probability,
            &glwe_log_polynomial_sizes,
            &glwe_dimensions,
            &internal_lwe_dimensions,
        );
        let state_ref = atomic_pattern::optimize_one::<u64>(
            1,
            precision,
            security_level,
            weight as f64,
            maximum_acceptable_error_probability,
            &glwe_log_polynomial_sizes,
            &glwe_dimensions,
            &internal_lwe_dimensions,
            None,
        );
        assert_eq!(
            state.best_solution.is_some(),
            state_ref.best_solution.is_some()
        );
        if state.best_solution.is_none() {
            return;
        }
        let sol = state.best_solution.unwrap();
        let mut sol_ref = state_ref.best_solution.unwrap();
        sol_ref.complexity *= 2.0 /* number of luts */;
        assert!(sol.same(sol_ref));
    }

    fn no_lut_vs_lut(precision: u64) {
        let mut dag_lut = unparametrized::OperationDag::new();
        let input1 = dag_lut.add_input(precision as u8, Shape::number());
        let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN);

        let mut dag_no_lut = unparametrized::OperationDag::new();
        let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());

        let security_level = 128;
        let maximum_acceptable_error_probability = _4_SIGMA;
        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
            .glwe_pbs_constrained
            .log2_polynomial_size
            .as_vec();
        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();

        let opt = |dag: &unparametrized::OperationDag| {
            optimize::<u64>(
                dag,
                security_level,
                maximum_acceptable_error_probability,
                &glwe_log_polynomial_sizes,
                &glwe_dimensions,
                &internal_lwe_dimensions,
            )
        };

        let state_no_lut = opt(&dag_no_lut);
        let state_lut = opt(&dag_lut);
        assert_eq!(
            state_no_lut.best_solution.is_some(),
            state_lut.best_solution.is_some()
        );

        if state_lut.best_solution.is_none() {
            return;
        }

        let sol_no_lut = state_no_lut.best_solution.unwrap();
        let sol_lut = state_lut.best_solution.unwrap();
        assert!(sol_no_lut.complexity < sol_lut.complexity);
    }

    #[test]
    fn test_lut_vs_no_lut() {
        for precision in 1..=8 {
            no_lut_vs_lut(precision);
        }
    }

    fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision: u64, weight: u64) {
        let weight = &Weights::number(weight);

        let mut dag_1 = unparametrized::OperationDag::new();
        {
            let input1 = dag_1.add_input(precision as u8, Shape::number());
            let scaled_input1 = dag_1.add_dot([input1], weight);
            let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN);
            let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN);
        }

        let mut dag_2 = unparametrized::OperationDag::new();
        {
            let input1 = dag_2.add_input(precision as u8, Shape::number());
            let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN);
            let scaled_lut1 = dag_2.add_dot([lut1], weight);
            let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN);
        }

        let security_level = 128;
        let maximum_acceptable_error_probability = _4_SIGMA;
        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
            .glwe_pbs_constrained
            .log2_polynomial_size
            .as_vec();
        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();

        let opt = |dag: &unparametrized::OperationDag| {
            optimize::<u64>(
                dag,
                security_level,
                maximum_acceptable_error_probability,
                &glwe_log_polynomial_sizes,
                &glwe_dimensions,
                &internal_lwe_dimensions,
            )
        };

        let state_1 = opt(&dag_1);
        let state_2 = opt(&dag_2);

        if state_1.best_solution.is_none() {
            assert!(state_2.best_solution.is_none());
            return;
        }
        let sol_1 = state_1.best_solution.unwrap();
        let sol_2 = state_2.best_solution.unwrap();
        assert!(sol_1.complexity < sol_2.complexity || sol_1.p_error < sol_2.p_error);
    }

    #[test]
    fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
        for log_weight in 1..=16 {
            let weight = 1 << log_weight;
            for precision in 5..=9 {
                lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
            }
        }
    }
}

109
concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs
Normal file
@@ -0,0 +1,109 @@
use derive_more::{Add, AddAssign, Sum};
/**
 * A variance represented as a linear combination of base variances.
 * Only the linear coefficients are known.
 * The base variances are unknown.
 * Each linear coefficient is a variance factor.
 *
 * Only 2 base variances are possible in the solo key setup:
 * - from input,
 * - or from lut output.
 *
 * We only know that the first one is lower than or equal to the second one.
 * Variance factors are homogeneous to squared weights (summed squared weights,
 * i.e. a squared norm-2).
 */
#[derive(Clone, Copy, Add, AddAssign, Sum, Debug, PartialEq, PartialOrd)]
pub struct SymbolicVariance {
    pub lut_vf: f64,
    pub input_vf: f64,
    // variance = vf.lut_vf * lut_out_noise
    //          + vf.input_vf * input_out_noise
    // E.g. variance(dot([lut, input], [3, 4])) = SymbolicVariance { lut_vf: 9.0, input_vf: 16.0 }

    // NOTE: lut_vf is the first field since it has the higher impact,
    // see pareto sorting and dominate_or_equal
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum VarianceOrigin {
    Input,
    Lut,
    Mixed,
}

impl std::ops::Mul<f64> for SymbolicVariance {
    type Output = Self;
    fn mul(self, sq_weight: f64) -> Self {
        Self {
            input_vf: self.input_vf * sq_weight,
            lut_vf: self.lut_vf * sq_weight,
        }
    }
}

impl std::ops::Mul<u64> for SymbolicVariance {
    type Output = Self;
    fn mul(self, sq_weight: u64) -> Self {
        self * sq_weight as f64
    }
}

impl SymbolicVariance {
    pub const ZERO: Self = Self {
        input_vf: 0.0,
        lut_vf: 0.0,
    };
    pub const INPUT: Self = Self {
        input_vf: 1.0,
        lut_vf: 0.0,
    };
    pub const LUT: Self = Self {
        input_vf: 0.0,
        lut_vf: 1.0,
    };

    pub fn origin(&self) -> VarianceOrigin {
        if self.lut_vf == 0.0 {
            VarianceOrigin::Input
        } else if self.input_vf == 0.0 {
            VarianceOrigin::Lut
        } else {
            VarianceOrigin::Mixed
        }
    }

    pub fn manp_to_variance_factor(manp: f64) -> f64 {
        manp
    }
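
    // `self` dominates `other` when its variance is at least as large for every
    // admissible pair of base noises. Since input base noise <= lut base noise,
    // the binding cases are input_noise == lut_noise (when `other` has the
    // larger input_vf) and input_noise == 0, which collapse into the single
    // test below.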
    pub fn dominate_or_equal(&self, other: &Self) -> bool {
        let extra_other_minimal_base_noise = 0.0_f64.max(other.input_vf - self.input_vf);
        other.lut_vf + extra_other_minimal_base_noise <= self.lut_vf
    }

    pub fn eval(&self, minimal_base_noise: f64, lut_base_noise: f64) -> f64 {
        minimal_base_noise * self.input_vf + lut_base_noise * self.lut_vf
    }

    pub fn reduce_to_pareto_front(mut vfs: Vec<Self>) -> Vec<Self> {
        if vfs.is_empty() {
            return vec![];
        }
        vfs.sort_by(
            // bigger first
            |a, b| b.partial_cmp(a).unwrap(),
        );
        // Thanks to the structure of the domination relation, the front can be built in one pass
        let mut dominator = vfs[0];
        let mut pareto = vec![dominator];
        for &vf in vfs.iter().skip(1) {
            if dominator.dominate_or_equal(&vf) {
                continue;
            }
            dominator = vf;
            pareto.push(vf);
        }
        pareto
    }
}
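
// Minimal sketch (assumed test, not part of the original commit): under the
// solo-key constraint input base noise <= lut base noise, the LUT factor
// dominates the INPUT factor, so the pareto reduction keeps only LUT.
#[cfg(test)]
mod sketch_tests {
    use super::*;

    #[test]
    fn lut_dominates_input() {
        let front = SymbolicVariance::reduce_to_pareto_front(vec![
            SymbolicVariance::LUT,
            SymbolicVariance::INPUT,
        ]);
        assert_eq!(front.len(), 1);
        assert_eq!(front[0].origin(), VarianceOrigin::Lut);
    }
}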

@@ -1 +1,2 @@
pub mod atomic_pattern;
pub mod dag;

15
concrete-optimizer/src/utils/mod.rs
Normal file
@@ -0,0 +1,15 @@
use std::ops::Mul;

pub fn square<V>(v: V) -> V
where
    V: Mul<Output = V> + Copy,
{
    v * v
}

pub fn square_ref<V>(v: &V) -> V
where
    V: Mul<Output = V> + Copy,
{
    square(*v)
}