diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs
index 517a18b1a..9c50fb220 100644
--- a/concrete-optimizer-cpp/src/concrete-optimizer.rs
+++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs
@@ -1,7 +1,7 @@
-use concrete_optimizer::graph::operator::{
-    self, FunctionTable, LevelledComplexity, OperatorIndex, Shape,
+use concrete_optimizer::dag::operator::{
+    self, FunctionTable, LevelledComplexity, OperatorIndex, Precision, Shape,
 };
-use concrete_optimizer::graph::unparametrized;
+use concrete_optimizer::dag::unparametrized;
 
 fn no_solution() -> ffi::Solution {
     ffi::Solution {
@@ -63,7 +63,7 @@ fn empty() -> Box {
 }
 
 impl OperationDag {
-    fn add_input(&mut self, out_precision: u8, out_shape: &[u64]) -> ffi::OperatorIndex {
+    fn add_input(&mut self, out_precision: Precision, out_shape: &[u64]) -> ffi::OperatorIndex {
         let out_shape = Shape {
             dimensions_size: out_shape.to_owned(),
         };
@@ -87,7 +87,7 @@ impl OperationDag {
     ) -> ffi::OperatorIndex {
         let inputs: Vec = inputs.iter().copied().map(Into::into).collect();
 
-        self.0.add_dot(&inputs, &weights.0).into()
+        self.0.add_dot(inputs, weights.0).into()
     }
 
     fn add_levelled_op(
@@ -111,7 +111,7 @@ impl OperationDag {
         };
 
         self.0
-            .add_levelled_op(&inputs, complexity, manp, out_shape, comment)
+            .add_levelled_op(inputs, complexity, manp, out_shape, comment)
             .into()
     }
 }
diff --git a/concrete-optimizer/src/graph/mod.rs b/concrete-optimizer/src/dag/mod.rs
similarity index 100%
rename from concrete-optimizer/src/graph/mod.rs
rename to concrete-optimizer/src/dag/mod.rs
diff --git a/concrete-optimizer/src/dag/operator/dot_kind.rs b/concrete-optimizer/src/dag/operator/dot_kind.rs
new file mode 100644
index 000000000..1dc47ac6a
--- /dev/null
+++ b/concrete-optimizer/src/dag/operator/dot_kind.rs
@@ -0,0 +1,84 @@
+use super::{ClearTensor, Shape};
+
+#[derive(PartialEq, Eq, Debug)]
+pub enum DotKind {
+    // inputs = [x,y,z], weights = [a,b,c], = x*a + y*b + z*c
+    Simple,
+    // inputs = [[x, y, z]], weights = [a,b,c], = same
+    Tensor,
+    // inputs = [[x], [y], [z]], weights = [[a],[b],[c]], = same
+    CompatibleTensor,
+    // inputs = [[x, y, z], [x, y, z]], weights = [[a,b,c]], = [same, same]
+    Broadcast,
+    Unsupported,
+}
+
+pub fn dot_kind(nb_inputs: u64, input_shape: &Shape, weights: &ClearTensor) -> DotKind {
+    let inputs_shape = Shape::duplicated(nb_inputs, input_shape);
+    if input_shape.is_number() && inputs_shape == weights.shape {
+        DotKind::Simple
+    } else if nb_inputs == 1 && *input_shape == weights.shape {
+        DotKind::Tensor
+    } else if inputs_shape == weights.shape {
+        DotKind::CompatibleTensor
+    } else if nb_inputs == 1 && input_shape.erase_first_dim() == weights.shape {
+        DotKind::Broadcast
+    } else {
+        DotKind::Unsupported
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::dag::operator::{Shape, Weights};
+
+    #[test]
+    fn test_simple() {
+        assert_eq!(
+            dot_kind(2, &Shape::number(), &Weights::vector([1, 2])),
+            DotKind::Simple
+        );
+    }
+
+    #[test]
+    fn test_tensor() {
+        assert_eq!(
+            dot_kind(1, &Shape::vector(2), &Weights::vector([1, 2])),
+            DotKind::Tensor
+        );
+    }
+
+    #[test]
+    fn test_broadcast() {
+        let s2x2 = Shape {
+            dimensions_size: vec![2, 2],
+        };
+        assert_eq!(
+            dot_kind(1, &s2x2, &Weights::vector([1, 2])),
+            DotKind::Broadcast
+        );
+    }
+
+    #[test]
+    fn test_compatible() {
+        let weights = ClearTensor {
+            shape: Shape {
+                dimensions_size: vec![2, 1],
+            },
+            values: vec![1, 2],
+        };
+        assert_eq!(
+            dot_kind(2, &Shape::vector(1), &weights),
+            DotKind::CompatibleTensor
+        );
+    }
+
+    #[test]
+    fn test_unsupported() {
+        assert_eq!(
+            dot_kind(3, &Shape::number(), &Weights::vector([1, 2])),
+            DotKind::Unsupported
+        );
+    }
+}
diff --git a/concrete-optimizer/src/graph/operator/mod.rs b/concrete-optimizer/src/dag/operator/mod.rs
similarity index 73%
rename from concrete-optimizer/src/graph/operator/mod.rs
rename to concrete-optimizer/src/dag/operator/mod.rs
index 27fea75f4..5837495f4 100644
--- a/concrete-optimizer/src/graph/operator/mod.rs
+++ b/concrete-optimizer/src/dag/operator/mod.rs
@@ -1,6 +1,8 @@
 #![allow(clippy::module_inception)]
+pub mod dot_kind;
 pub mod operator;
 pub mod tensor;
 
+pub use self::dot_kind::*;
 pub use self::operator::*;
 pub use self::tensor::*;
diff --git a/concrete-optimizer/src/graph/operator/operator.rs b/concrete-optimizer/src/dag/operator/operator.rs
similarity index 92%
rename from concrete-optimizer/src/graph/operator/operator.rs
rename to concrete-optimizer/src/dag/operator/operator.rs
index e0c23694b..4ca2514ea 100644
--- a/concrete-optimizer/src/graph/operator/operator.rs
+++ b/concrete-optimizer/src/dag/operator/operator.rs
@@ -1,4 +1,4 @@
-use crate::graph::operator::tensor::{ClearTensor, Shape};
+use crate::dag::operator::tensor::{ClearTensor, Shape};
 use derive_more::{Add, AddAssign};
 
 pub type Weights = ClearTensor;
@@ -50,11 +50,13 @@ impl std::ops::Mul for LevelledComplexity {
         }
     }
 }
+
+pub type Precision = u8;
+pub const MIN_PRECISION: Precision = 1;
 
 #[derive(Clone, PartialEq, Debug)]
 pub enum Operator {
     Input {
-        out_precision: u8,
+        out_precision: Precision,
         out_shape: Shape,
         extra_data: InputExtraData,
     },
diff --git a/concrete-optimizer/src/graph/operator/tensor.rs b/concrete-optimizer/src/dag/operator/tensor.rs
similarity index 67%
rename from concrete-optimizer/src/graph/operator/tensor.rs
rename to concrete-optimizer/src/dag/operator/tensor.rs
index 82a6c9628..477b9e13e 100644
--- a/concrete-optimizer/src/graph/operator/tensor.rs
+++ b/concrete-optimizer/src/dag/operator/tensor.rs
@@ -1,10 +1,17 @@
 use delegate::delegate;
+
+use crate::utils::square_ref;
+
 #[derive(Clone, PartialEq, Eq, Debug)]
 pub struct Shape {
     pub dimensions_size: Vec,
 }
 
 impl Shape {
+    pub fn first_dim_size(&self) -> u64 {
+        self.dimensions_size[0]
+    }
+
     pub fn rank(&self) -> usize {
         self.dimensions_size.len()
     }
@@ -43,6 +50,12 @@ impl Shape {
         dimensions_size.extend_from_slice(&other.dimensions_size);
         Self { dimensions_size }
     }
+
+    pub fn erase_first_dim(&self) -> Self {
+        Self {
+            dimensions_size: self.dimensions_size[1..].to_vec(),
+        }
+    }
 }
 
 #[derive(Clone, PartialEq, Eq, Debug)]
@@ -51,11 +64,6 @@ pub struct ClearTensor {
     pub values: Vec,
 }
 
-#[allow(clippy::trivially_copy_pass_by_ref)]
-fn square(v: &u64) -> u64 {
-    v * v
-}
-
 impl ClearTensor {
     pub fn number(value: u64) -> Self {
         Self {
             shape: Shape::number(),
             values: vec![value],
         }
     }
 
-    pub fn vector(values: &[u64]) -> Self {
+    pub fn vector(values: impl Into>) -> Self {
+        let values = values.into();
         Self {
             shape: Shape::vector(values.len() as u64),
-            values: values.to_vec(),
+            values,
         }
     }
 
@@ -81,6 +90,27 @@ impl ClearTensor {
     }
 
     pub fn square_norm2(&self) -> u64 {
-        self.values.iter().map(square).sum()
+        self.values.iter().map(square_ref).sum()
+    }
+}
+
+// helps using shared shapes
+impl From<&Self> for Shape {
+    fn from(item: &Self) -> Self {
+        item.clone()
+    }
+}
+
+// helps using shared weights
+impl From<&Self> for ClearTensor {
+    fn from(item: &Self) -> Self {
+        item.clone()
+    }
+}
+
+// helps using array as weights
+impl From<[u64; N]> for ClearTensor {
+    fn from(item: [u64; N]) -> Self {
+        Self::vector(item)
+    }
+}
diff --git a/concrete-optimizer/src/graph/parameter_indexed.rs b/concrete-optimizer/src/dag/parameter_indexed.rs
similarity index 100%
rename from concrete-optimizer/src/graph/parameter_indexed.rs
rename to concrete-optimizer/src/dag/parameter_indexed.rs
diff --git a/concrete-optimizer/src/graph/range_parametrized.rs b/concrete-optimizer/src/dag/range_parametrized.rs
similarity index 81%
rename from concrete-optimizer/src/graph/range_parametrized.rs
rename to concrete-optimizer/src/dag/range_parametrized.rs
index e673844f6..7ebb24b8b 100644
--- a/concrete-optimizer/src/graph/range_parametrized.rs
+++ b/concrete-optimizer/src/dag/range_parametrized.rs
@@ -1,5 +1,5 @@
+use crate::dag::parameter_indexed::OperatorParameterIndexed;
 use crate::global_parameters::{ParameterRanges, ParameterToOperation};
-use crate::graph::parameter_indexed::OperatorParameterIndexed;
 
 #[allow(dead_code)]
 pub struct OperationDag {
diff --git a/concrete-optimizer/src/graph/unparametrized.rs b/concrete-optimizer/src/dag/unparametrized.rs
similarity index 80%
rename from concrete-optimizer/src/graph/unparametrized.rs
rename to concrete-optimizer/src/dag/unparametrized.rs
index 8ac5a6c54..49f7ef297 100644
--- a/concrete-optimizer/src/graph/unparametrized.rs
+++ b/concrete-optimizer/src/dag/unparametrized.rs
@@ -1,5 +1,5 @@
-use crate::graph::operator::{
-    FunctionTable, LevelledComplexity, Operator, OperatorIndex, Shape, Weights,
+use crate::dag::operator::{
+    FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights,
 };
 
 pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>;
@@ -21,7 +21,12 @@ impl OperationDag {
         OperatorIndex { i }
     }
 
-    pub fn add_input(&mut self, out_precision: u8, out_shape: Shape) -> OperatorIndex {
+    pub fn add_input(
+        &mut self,
+        out_precision: Precision,
+        out_shape: impl Into,
+    ) -> OperatorIndex {
+        let out_shape = out_shape.into();
         self.add_operator(Operator::Input {
             out_precision,
             out_shape,
@@ -37,23 +42,30 @@ impl OperationDag {
         })
     }
 
-    pub fn add_dot(&mut self, inputs: &[OperatorIndex], weights: &Weights) -> OperatorIndex {
+    pub fn add_dot(
+        &mut self,
+        inputs: impl Into>,
+        weights: impl Into,
+    ) -> OperatorIndex {
+        let inputs = inputs.into();
+        let weights = weights.into();
         self.add_operator(Operator::Dot {
-            inputs: inputs.to_vec(),
-            weights: weights.clone(),
+            inputs,
+            weights,
             extra_data: (),
         })
     }
 
     pub fn add_levelled_op(
         &mut self,
-        inputs: &[OperatorIndex],
+        inputs: impl Into>,
         complexity: LevelledComplexity,
         manp: f64,
-        out_shape: Shape,
+        out_shape: impl Into,
         comment: &str,
     ) -> OperatorIndex {
-        let inputs = inputs.to_vec();
+        let inputs = inputs.into();
+        let out_shape = out_shape.into();
         let comment = comment.to_string();
         let op = Operator::LevelledOp {
             inputs,
@@ -75,7 +87,7 @@ impl OperationDag {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::graph::operator::Shape;
+    use crate::dag::operator::Shape;
 
     #[test]
     fn graph_creation() {
@@ -86,14 +98,14 @@ mod tests {
         let input2 = graph.add_input(2, Shape::number());
 
         let cpx_add = LevelledComplexity::ADDITION;
-        let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
+        let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
 
         let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
 
         let concat =
-            graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
+            graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat");
 
-        let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
+        let dot = graph.add_dot([concat], [1, 2]);
 
         let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
 
diff --git a/concrete-optimizer/src/graph/value_parametrized.rs b/concrete-optimizer/src/dag/value_parametrized.rs
similarity index 81%
rename from concrete-optimizer/src/graph/value_parametrized.rs
rename to concrete-optimizer/src/dag/value_parametrized.rs
index a6cf66075..4658b6a62 100644
--- a/concrete-optimizer/src/graph/value_parametrized.rs
+++ b/concrete-optimizer/src/dag/value_parametrized.rs
@@ -1,5 +1,5 @@
+use crate::dag::parameter_indexed::OperatorParameterIndexed;
 use crate::global_parameters::{ParameterToOperation, ParameterValues};
-use crate::graph::parameter_indexed::OperatorParameterIndexed;
 
 #[allow(dead_code)]
 pub struct OperationDag {
diff --git a/concrete-optimizer/src/global_parameters.rs b/concrete-optimizer/src/global_parameters.rs
index 6aae92cba..76f3cb842 100644
--- a/concrete-optimizer/src/global_parameters.rs
+++ b/concrete-optimizer/src/global_parameters.rs
@@ -1,11 +1,11 @@
 use std::collections::HashSet;
 
-use crate::graph::operator::{Operator, OperatorIndex};
-use crate::graph::parameter_indexed::{
+use crate::dag::operator::{Operator, OperatorIndex};
+use crate::dag::parameter_indexed::{
     InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed,
 };
-use crate::graph::unparametrized::UnparameterizedOperator;
-use crate::graph::{parameter_indexed, range_parametrized, unparametrized};
+use crate::dag::unparametrized::UnparameterizedOperator;
+use crate::dag::{parameter_indexed, range_parametrized, unparametrized};
 use crate::parameters::{
     BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges,
     GlweParameters, KsDecompositionParameterRanges, KsDecompositionParameters,
@@ -244,7 +244,7 @@ pub fn domains_to_ranges(
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::graph::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
+    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape};
 
     #[test]
     fn test_maximal_unify() {
@@ -255,14 +255,13 @@ mod tests {
         let input2 = graph.add_input(2, Shape::number());
 
         let cpx_add = LevelledComplexity::ADDITION;
-        let sum1 = graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::number(), "sum");
+        let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum");
 
         let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN);
 
-        let concat =
-            graph.add_levelled_op(&[input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
+        let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat");
 
-        let dot = graph.add_dot(&[concat], &Weights::vector(&[1, 2]));
+        let dot = graph.add_dot([concat], [1, 2]);
 
         let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN);
 
@@ -336,7 +335,7 @@ mod tests {
         let cpx_add = LevelledComplexity::ADDITION;
 
         let concat =
-            graph.add_levelled_op(&[input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
+            graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat");
 
         let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN);
 
diff --git a/concrete-optimizer/src/lib.rs b/concrete-optimizer/src/lib.rs
index d5e82847c..c8ab26cf0 100644
--- a/concrete-optimizer/src/lib.rs
+++ b/concrete-optimizer/src/lib.rs
@@ -1,27 +1,29 @@
 #![warn(clippy::nursery)]
 #![warn(clippy::pedantic)]
 #![warn(clippy::style)]
+#![allow(clippy::cast_lossless)]
 #![allow(clippy::cast_precision_loss)] // u64 to f64
 #![allow(clippy::cast_possible_truncation)] // u64 to usize
 #![allow(clippy::inline_always)] // needed by delegate
+#![allow(clippy::match_wildcard_for_single_variants)]
 #![allow(clippy::missing_panics_doc)]
 #![allow(clippy::missing_const_for_fn)]
 #![allow(clippy::module_name_repetitions)]
 #![allow(clippy::must_use_candidate)]
+#![allow(clippy::return_self_not_must_use)]
 #![allow(clippy::similar_names)]
 #![allow(clippy::suboptimal_flops)]
 #![allow(clippy::too_many_arguments)]
-#![allow(clippy::match_wildcard_for_single_variants)]
-#![allow(clippy::cast_lossless)]
 #![warn(unused_results)]
 
 pub mod computing_cost;
+pub mod dag;
 pub mod global_parameters;
-pub mod graph;
 pub mod noise_estimator;
 pub mod optimization;
 pub mod parameters;
 pub mod pareto;
 pub mod security;
+pub mod utils;
 pub mod weight;
diff --git a/concrete-optimizer/src/noise_estimator/error.rs b/concrete-optimizer/src/noise_estimator/error.rs
index 59276c7f6..55f2fa48e 100644
--- a/concrete-optimizer/src/noise_estimator/error.rs
+++ b/concrete-optimizer/src/noise_estimator/error.rs
@@ -1,3 +1,8 @@
+use concrete_commons::dispersion::DispersionParameter;
+
+use super::utils;
+use crate::utils::square;
+
 pub fn sigma_scale_of_error_probability(p_error: f64) -> f64 {
     // https://en.wikipedia.org/wiki/Error_function#Applications
     let p_in = 1.0 - p_error;
@@ -9,6 +14,31 @@ pub fn error_probability_of_sigma_scale(sigma_scale: f64) -> f64 {
     1.0 - statrs::function::erf::erf(sigma_scale / 2_f64.sqrt())
 }
 
+const LEFT_PADDING_BITS: u64 = 1;
+const RIGHT_PADDING_BITS: u64 = 1;
+
+pub fn fatal_noise_limit(precision: u64, ciphertext_modulus_log: u64) -> f64 {
+    let no_noise_bits = LEFT_PADDING_BITS + precision + RIGHT_PADDING_BITS;
+    let noise_bits = ciphertext_modulus_log - no_noise_bits;
+    2_f64.powi(noise_bits as i32)
+}
+
+pub fn variance_max(
+    precision: u64,
+    ciphertext_modulus_log: u64,
+    maximum_acceptable_error_probability: f64,
+) -> f64 {
+    let fatal_noise_limit = fatal_noise_limit(precision, ciphertext_modulus_log);
+    // We want safe_sigma such that:
+    // P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
+    // <=> P(x not in [-+fatal_noise_limit/safe_sigma] | σ = 1) = p_error
+    // <=> P(x not in [-+kappa] | σ = 1) = p_error, with safe_sigma = fatal_noise_limit / kappa
+    let kappa = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
+    let safe_sigma = fatal_noise_limit / kappa;
+    let modular_variance = square(safe_sigma);
+    utils::from_modular_variance(modular_variance, ciphertext_modulus_log).get_variance()
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/concrete-optimizer/src/optimization/atomic_pattern.rs b/concrete-optimizer/src/optimization/atomic_pattern.rs
index ab054c67d..bd0e7d22b 100644
--- a/concrete-optimizer/src/optimization/atomic_pattern.rs
+++ b/concrete-optimizer/src/optimization/atomic_pattern.rs
@@ -1,9 +1,8 @@
 use crate::computing_cost::operators::atomic_pattern as complexity_atomic_pattern;
 use crate::computing_cost::operators::keyswitch_lwe::KeySwitchLWEComplexity;
 use crate::computing_cost::operators::pbs::PbsComplexity;
-use crate::noise_estimator::error::{
-    error_probability_of_sigma_scale, sigma_scale_of_error_probability,
-};
+use crate::noise_estimator::error;
+
 use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
 use crate::parameters::{
     AtomicPatternParameters, BrDecompositionParameters, GlweParameters, KeyswitchParameters,
@@ -11,14 +10,11 @@ use crate::parameters::{
 };
 use crate::pareto;
 use crate::security;
+use crate::utils::square;
 use complexity_atomic_pattern::{AtomicPatternComplexity, DEFAULT as DEFAULT_COMPLEXITY};
 use concrete_commons::dispersion::{DispersionParameter, Variance};
 use concrete_commons::numeric::UnsignedInteger;
 
-fn square(v: f64) -> f64 {
-    v * v
-}
-
 /* enable to debug */
 const CHECKS: bool = false;
 /* disable to debug */
@@ -27,7 +23,7 @@ const CUTS: bool = true; // 80ms
 const PARETO_CUTS: bool = true; // 75ms
 const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true; // 70ms
 
-#[derive(Debug, Clone, Copy)]
+#[derive(Debug, Clone, Copy, PartialEq)]
 pub struct Solution {
     pub input_lwe_dimension: u64,              //n_big
     pub internal_ks_output_lwe_dimension: u64, //n_small
@@ -38,27 +34,28 @@ pub struct Solution {
     pub br_decomposition_level_count: u64, //l(BR)
     pub br_decomposition_base_log: u64,    //b(BR)
     pub complexity: f64,
+    pub lut_complexity: f64,
     pub noise_max: f64,
     pub p_error: f64, // error probability
 }
 
-// Constants during optimization of decompositions
-struct OptimizationDecompositionsConsts {
-    kappa: f64,
-    sum_size: u64,
-    security_level: u64,
-    noise_factor: f64,
-    ciphertext_modulus_log: u64,
-    keyswitch_decompositions: Vec,
-    blind_rotate_decompositions: Vec,
-    variance_max: f64,
+// Constants during optimisation of decompositions
+pub(crate) struct OptimizationDecompositionsConsts {
+    pub kappa: f64,
+    pub sum_size: u64,
+    pub security_level: u64,
+    pub noise_factor: f64,
+    pub ciphertext_modulus_log: u64,
+    pub keyswitch_decompositions: Vec,
+    pub blind_rotate_decompositions: Vec,
+    pub safe_variance: f64,
 }
 
 #[derive(Clone, Copy)]
-struct ComplexityNoise {
-    index: usize,
-    complexity: f64,
-    noise: f64,
+pub(crate) struct ComplexityNoise {
+    pub index: usize,
+    pub complexity: f64,
+    pub noise: f64,
 }
 
 impl ComplexityNoise {
@@ -69,7 +66,7 @@ impl ComplexityNoise {
     };
 }
 
-fn blind_rotate_quantities(
+pub(crate) fn pareto_blind_rotate(
     consts: &OptimizationDecompositionsConsts,
     internal_dim: u64,
     glwe_params: GlweParameters,
@@ -107,11 +104,11 @@ fn blind_rotate_quantities(
             variance_bsk,
         );
 
-        let noise_in = base_noise.get_variance() * square(consts.noise_factor);
-        if cut_noise < noise_in && CUTS {
+        let noise_out = base_noise.get_variance();
+        if cut_noise < noise_out && CUTS {
             continue; // noise is decreasing
         }
-        if decreasing_variance < noise_in && PARETO_CUTS {
+        if decreasing_variance < noise_out && PARETO_CUTS {
            // the current case is dominated
            continue;
        }
@@ -124,14 +121,14 @@ fn blind_rotate_quantities(
         quantities[size] = ComplexityNoise {
             index: i_br,
             complexity: complexity_pbs,
-            noise: noise_in,
+            noise: noise_out,
         };
         assert!(
             0.0 <= delta_complexity,
             "blind_rotate_decompositions should be by increasing complexity"
         );
         increasing_complexity = complexity_pbs;
-        decreasing_variance = noise_in;
+        decreasing_variance = noise_out;
         size += 1;
     }
     assert!(!PARETO_CUTS || size < 64);
@@ -139,7 +136,7 @@ fn blind_rotate_quantities(
     quantities
 }
 
-fn keyswitch_quantities(
+pub(crate) fn pareto_keyswitch(
     consts: &OptimizationDecompositionsConsts,
     in_dim: u64,
     internal_dim: u64,
@@ -226,8 +223,8 @@ fn update_state_with_best_decompositions(
         glwe_poly_size,
     )
     .get_variance();
-    let variance_max = consts.variance_max;
-    if CUTS && noise_modulus_switching > variance_max {
+    let safe_variance = consts.safe_variance;
+    if CUTS && noise_modulus_switching > safe_variance {
         return;
     }
 
@@ -236,9 +233,9 @@ fn update_state_with_best_decompositions(
     let complexity_multisum = (consts.sum_size * input_lwe_dimension) as f64;
 
     let mut cut_complexity = best_complexity - complexity_multisum;
-    let mut cut_noise = variance_max - noise_modulus_switching;
+    let mut cut_noise = safe_variance - noise_modulus_switching;
     let br_quantities =
-        blind_rotate_quantities::(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
+        pareto_blind_rotate::(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
     if br_quantities.is_empty() {
         return;
     }
@@ -246,7 +243,7 @@ fn update_state_with_best_decompositions(
         cut_noise -= br_quantities[br_quantities.len() - 1].noise;
         cut_complexity -= br_quantities[0].complexity;
     }
-    let ks_quantities = keyswitch_quantities::(
+    let ks_quantities = pareto_keyswitch::(
         consts,
         input_lwe_dimension,
         internal_dim,
@@ -259,11 +256,12 @@ fn update_state_with_best_decompositions(
 
     let i_max_ks = ks_quantities.len() - 1;
     let mut i_current_max_ks = i_max_ks;
+    let square_noise_factor = square(consts.noise_factor);
 
     for br_quantity in br_quantities {
         // increasing complexity, decreasing variance
-        let noise_in = br_quantity.noise;
+        let noise_in = br_quantity.noise * square_noise_factor;
         let noise_max = noise_in + noise_modulus_switching;
-        if noise_max > variance_max && CUTS {
+        if noise_max > safe_variance && CUTS {
             continue;
         }
         let complexity_pbs = br_quantity.complexity;
@@ -298,7 +296,7 @@ fn update_state_with_best_decompositions(
             );
         }
 
-        if noise_max > variance_max {
+        if noise_max > safe_variance {
             if CROSS_PARETO_CUTS {
                 // the pareto of 2 added pareto is scanned linearly
                 // but with all cuts, pre-computing => no gain
@@ -316,9 +314,9 @@ fn update_state_with_best_decompositions(
 
         // feasible and at least as good complexity
         if complexity < best_complexity || noise_max < best_variance {
-            let sigma = Variance(variance_max).get_standard_dev() * consts.kappa;
+            let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
             let sigma_scale = sigma / Variance(noise_max).get_standard_dev();
-            let p_error = error_probability_of_sigma_scale(sigma_scale);
+            let p_error = error::error_probability_of_sigma_scale(sigma_scale);
             let i_br = br_quantity.index;
             let i_ks = ks_quantity.index;
@@ -344,6 +342,7 @@ fn update_state_with_best_decompositions(
                 br_decomposition_base_log: br_b,
                 noise_max,
                 complexity,
+                lut_complexity: complexity_keyswitch + complexity_pbs,
                 p_error,
             });
         }
@@ -351,7 +350,7 @@ fn update_state_with_best_decompositions(
     } // br ks
 }
 
-// This function provides reference values with unoptimized code, until we have non regression tests
+// This function provides reference values with unoptimised code, until we have non regression tests
 #[allow(clippy::float_cmp)]
 #[allow(clippy::too_many_lines)]
 fn assert_checks(
@@ -367,7 +366,7 @@ fn assert_checks(
 ) {
     let i_ks = ks_c_n.index;
     let i_br = br_c_n.index;
-    let noise_in = br_c_n.noise;
+    let noise_out = br_c_n.noise;
     let noise_keyswitch = ks_c_n.noise;
     let complexity_keyswitch = ks_c_n.complexity;
     let complexity_pbs = br_c_n.complexity;
@@ -398,7 +397,7 @@ fn assert_checks(
         .complexity(pbs_parameters, ciphertext_modulus_log);
     assert_eq!(complexity_pbs, complexity_pbs_);
 
-    assert_eq!(noise_in, noise_in_);
+    assert_eq!(noise_out * square(consts.noise_factor), noise_in_);
 
     let variance_ksk =
         noise_atomic_pattern::variance_ksk(internal_dim, ciphertext_modulus_log, security_level);
@@ -429,7 +428,7 @@ fn assert_checks(
     };
 
     let check_max_noise = noise_atomic_pattern::maximal_noise::(
-        Variance(noise_in),
+        Variance(noise_in_),
         atomic_pattern_parameters,
         ciphertext_modulus_log,
         security_level,
@@ -452,8 +451,6 @@ fn assert_checks(
     assert!(diff_complexity < 0.0001);
 }
 
-const BITS_CARRY: u64 = 1;
-const BITS_PADDING_WITHOUT_NOISE: u64 = 1;
 const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
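The inline sigma computation removed in the hunk below is replaced by the new `error::variance_max` helper. A small numeric sketch of the bound it enforces — illustrative only, reusing the helpers from noise_estimator/error.rs shown above; the `p_error` value and tolerance are arbitrary example assumptions:

```rust
#[test]
fn sketch_safe_variance_bound() {
    use crate::noise_estimator::error::*;
    // For precision = 8 with a 64-bit ciphertext modulus: 1 left padding bit,
    // 8 message bits and 1 right padding bit must stay noise-free, leaving a
    // tolerated noise amplitude of 2^(64 - 10) = 2^54.
    let p_error = 1e-5; // arbitrary example value
    let kappa = sigma_scale_of_error_probability(p_error); // about 4.4 sigmas
    let limit = fatal_noise_limit(8, 64);
    assert_eq!(limit, 2_f64.powi(54));
    // variance_max is square(limit / kappa), converted out of the modular scale.
    let safe_sigma = limit / kappa;
    assert!(safe_sigma > 0.0);
    // Round-trip: kappa sigmas correspond to p_error by construction.
    assert!((error_probability_of_sigma_scale(kappa) - p_error).abs() < 1e-7);
}
```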
 
 #[allow(clippy::too_many_lines)]
@@ -479,16 +476,12 @@ pub fn optimize_one(
     // the blind rotate decomposition
     let ciphertext_modulus_log = W::BITS as u64;
-
-    let no_noise_bits = BITS_CARRY + precision + BITS_PADDING_WITHOUT_NOISE;
-    let noise_bits = ciphertext_modulus_log - no_noise_bits;
-    let fatal_noise_limit = (1_u64 << noise_bits) as f64;
-
-    // Now we search for P(x not in [-+fatal_noise_limit] | σ = safe_sigma) = p_error
-    //                    P(x not in [-+kappa] | σ = 1) = p_error
-    let kappa: f64 = sigma_scale_of_error_probability(maximum_acceptable_error_probability);
-    let safe_sigma = fatal_noise_limit / kappa;
-    let variance_max = Variance::from_modular_variance::(square(safe_sigma));
+    let safe_variance = error::variance_max(
+        precision,
+        ciphertext_modulus_log,
+        maximum_acceptable_error_probability,
+    );
+    let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);
 
     let consts = OptimizationDecompositionsConsts {
         kappa,
@@ -502,7 +495,7 @@ pub fn optimize_one(
         blind_rotate_decompositions: pareto::BR_BL
             .map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
             .to_vec(),
-        variance_max: variance_max.get_variance(),
+        safe_variance,
     };
 
     let mut state = OptimizationState {
@@ -524,7 +517,7 @@ pub fn optimize_one(
             glwe_poly_size,
         )
         .get_variance()
-            > consts.variance_max
+            > consts.safe_variance
     };
 
     let skip = |glwe_dim, glwe_poly_size| match restart_at {
diff --git a/concrete-optimizer/src/optimization/dag/mod.rs b/concrete-optimizer/src/optimization/dag/mod.rs
new file mode 100644
index 000000000..d3c6cdaf7
--- /dev/null
+++ b/concrete-optimizer/src/optimization/dag/mod.rs
@@ -0,0 +1 @@
+pub mod solo_key;
diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs
new file mode 100644
index 000000000..75ddef397
--- /dev/null
+++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs
@@ -0,0 +1,587 @@
+use super::symbolic_variance::{SymbolicVariance, VarianceOrigin};
+use crate::dag::operator::{
+    dot_kind, DotKind, LevelledComplexity, OperatorIndex, Precision, Shape,
+};
+use crate::dag::unparametrized;
+use crate::utils::square;
+
+// private short convention
+use DotKind as DK;
+use VarianceOrigin as VO;
+type Op = unparametrized::UnparameterizedOperator;
+
+fn first<'a, Property>(inputs: &[OperatorIndex], properties: &'a [Property]) -> &'a Property {
+    &properties[inputs[0].i]
+}
+
+fn assert_all_same(
+    inputs: &[OperatorIndex],
+    properties: &[Property],
+) {
+    let first = first(inputs, properties);
+    for input in inputs.iter().skip(1) {
+        assert_eq!(first, &properties[input.i]);
+    }
+}
+
+fn assert_inputs_uniform_precisions(
+    op: &unparametrized::UnparameterizedOperator,
+    out_precisions: &[Precision],
+) {
+    if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
+        assert_all_same(inputs, out_precisions);
+    }
+}
+
+fn assert_dot_uniform_inputs_shape(
+    op: &unparametrized::UnparameterizedOperator,
+    out_shapes: &[Shape],
+) {
+    if let Op::Dot { inputs, .. } = op {
+        assert_all_same(inputs, out_shapes);
+    }
+}
+
+fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) {
+    if let Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } = op {
+        assert!(!inputs.is_empty());
+    }
+}
+
+fn assert_dag_correctness(dag: &unparametrized::OperationDag) {
+    for op in &dag.operators {
+        assert_non_empty_inputs(op);
+    }
+}
+
+fn assert_valid_variances(dag: &OperationDag) {
+    for &out_variance in &dag.out_variances {
+        assert!(
+            SymbolicVariance::ZERO == out_variance // Special case of multiply by 0
+                || 1.0 <= out_variance.input_vf
+                || 1.0 <= out_variance.lut_vf
+        );
+    }
+}
+
+fn assert_properties_correctness(dag: &OperationDag) {
+    for op in &dag.operators {
+        assert_inputs_uniform_precisions(op, &dag.out_precisions);
+        assert_dot_uniform_inputs_shape(op, &dag.out_shapes);
+    }
+    assert_valid_variances(dag);
+}
+
+fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) -> VarianceOrigin {
+    let first_origin = first(inputs, out_variances).origin();
+    for input in inputs.iter().skip(1) {
+        let item = &out_variances[input.i];
+        if first_origin != item.origin() {
+            return VO::Mixed;
+        }
+    }
+    first_origin
+}
+
+#[derive(Clone, Debug)]
+pub struct OperationDag {
+    pub operators: Vec,
+    // Collect all operators output shape
+    pub out_shapes: Vec,
+    // Collect all operators output precision
+    pub out_precisions: Vec,
+    // Collect all operators output variances
+    pub out_variances: Vec,
+    pub nb_luts: u64,
+    // The full dag levelled complexity
+    pub levelled_complexity: LevelledComplexity,
+    // Global summaries of worst noise cases
+    pub noise_summary: NoiseSummary,
+}
+
+#[derive(Clone, Debug)]
+pub struct NoiseSummary {
+    // All final variance factors not entering a lut (usually the final levelledOp)
+    pub pareto_vfs_final: Vec,
+    // All variance factors entering a lut
+    pub pareto_vfs_in_lut: Vec,
+}
+
+impl OperationDag {
+    pub fn peek_variance(
+        &self,
+        input_noise_out: f64,
+        blind_rotate_noise_out: f64,
+        noise_keyswitch: f64,
+        noise_modulus_switching: f64,
+    ) -> f64 {
+        peek_variance(
+            self,
+            input_noise_out,
+            blind_rotate_noise_out,
+            noise_keyswitch,
+            noise_modulus_switching,
+        )
+    }
+
+    pub fn complexity_cost(&self, input_lwe_dimension: u64, one_lut_cost: f64) -> f64 {
+        let luts_cost = one_lut_cost * (self.nb_luts as f64);
+        let levelled_cost = self.levelled_complexity.cost(input_lwe_dimension);
+        luts_cost + levelled_cost
+    }
+}
+
+fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape {
+    match op {
+        Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(),
+        Op::Lut { input, .. } => out_shapes[input.i].clone(),
+        Op::Dot {
+            inputs, weights, ..
+        } => {
+            if inputs.is_empty() {
+                return Shape::number();
+            }
+            let input_shape = first(inputs, out_shapes);
+            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
+            match kind {
+                DK::Simple | DK::Tensor => Shape::number(),
+                DK::CompatibleTensor => weights.shape.clone(),
+                DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()),
+                DK::Unsupported { .. } => panic!("Unsupported"),
+            }
+        }
+    }
+}
+
+fn out_shapes(dag: &unparametrized::OperationDag) -> Vec {
+    let nb_ops = dag.operators.len();
+    let mut out_shapes = Vec::::with_capacity(nb_ops);
+    for op in &dag.operators {
+        let shape = out_shape(op, &mut out_shapes);
+        out_shapes.push(shape);
+    }
+    out_shapes
+}
+
+fn out_precision(
+    op: &unparametrized::UnparameterizedOperator,
+    out_precisions: &mut [Precision],
+) -> Precision {
+    match op {
+        Op::Input { out_precision, .. } => *out_precision,
+        Op::Lut { input, .. } => out_precisions[input.i],
+        Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i],
+    }
+}
+
+fn out_precisions(dag: &unparametrized::OperationDag) -> Vec {
+    let nb_ops = dag.operators.len();
+    let mut out_precisions = Vec::::with_capacity(nb_ops);
+    for op in &dag.operators {
+        let precision = out_precision(op, &mut out_precisions);
+        out_precisions.push(precision);
+    }
+    out_precisions
+}
+
+fn out_variance(
+    op: &unparametrized::UnparameterizedOperator,
+    out_shapes: &[Shape],
+    out_variances: &mut [SymbolicVariance],
+) -> SymbolicVariance {
+    // Maintain a linear combination of input_variance and lut_out_variance
+    // TODO: track each element instead of the container
+    match op {
+        Op::Input { .. } => SymbolicVariance::INPUT,
+        Op::Lut { .. } => SymbolicVariance::LUT,
+        Op::LevelledOp { inputs, manp, .. } => {
+            let variance_factor = SymbolicVariance::manp_to_variance_factor(*manp);
+            let origin = match variance_origin(inputs, out_variances) {
+                VO::Input => SymbolicVariance::INPUT,
+                VO::Lut | VO::Mixed /* Mixed: assume the worst */
+                    => SymbolicVariance::LUT
+            };
+            origin * variance_factor
+        }
+        Op::Dot {
+            inputs, weights, ..
+        } => {
+            let input_shape = first(inputs, out_shapes);
+            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
+            match kind {
+                DK::Simple | DK::Tensor => {
+                    let first_input = inputs[0];
+                    let mut out_variance = SymbolicVariance::ZERO;
+                    for (j, &weight) in weights.values.iter().enumerate() {
+                        let k = if kind == DK::Simple {
+                            inputs[j].i
+                        } else {
+                            first_input.i
+                        };
+                        out_variance += out_variances[k] * square(weight);
+                    }
+                    out_variance
+                }
+                DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
+                DK::Unsupported { .. } => panic!("Unsupported"),
+            }
+        }
+    }
+}
+
+fn out_variances(
+    dag: &unparametrized::OperationDag,
+    out_shapes: &[Shape],
+) -> Vec {
+    let nb_ops = dag.operators.len();
+    let mut out_variances = Vec::with_capacity(nb_ops);
+    for op in &dag.operators {
+        let vf = out_variance(op, out_shapes, &mut out_variances);
+        out_variances.push(vf);
+    }
+    out_variances
+}
+
+fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec {
+    let nb_ops = dag.operators.len();
+    let mut extra_values_to_check = vec![true; nb_ops];
+    for op in &dag.operators {
+        match op {
+            Op::Input { .. } => (),
+            Op::Lut { input, .. } => {
+                extra_values_to_check[input.i] = false;
+            }
+            Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => {
+                for input in inputs {
+                    extra_values_to_check[input.i] = false;
+                }
+            }
+        }
+    }
+    extra_values_to_check
+}
+
+fn extra_final_variances(
+    dag: &unparametrized::OperationDag,
+    out_variances: &[SymbolicVariance],
+) -> Vec {
+    extra_final_values_to_check(dag)
+        .iter()
+        .enumerate()
+        .filter_map(|(i, &is_final)| {
+            if is_final {
+                Some(out_variances[i])
+            } else {
+                None
+            }
+        })
+        .collect()
+}
+
+fn in_luts_variance(
+    dag: &unparametrized::OperationDag,
+    out_variances: &[SymbolicVariance],
+) -> Vec {
+    let only_luts = |op| {
+        if let &Op::Lut { input, .. } = op {
+            Some(out_variances[input.i])
+        } else {
+            None
+        }
+    };
+    dag.operators.iter().filter_map(only_luts).collect()
+}
+
+fn op_levelled_complexity(
+    op: &unparametrized::UnparameterizedOperator,
+    out_shapes: &[Shape],
+) -> LevelledComplexity {
+    match op {
+        Op::Dot {
+            inputs, weights, ..
+        } => {
+            let input_shape = first(inputs, out_shapes);
+            let kind = dot_kind(inputs.len() as u64, input_shape, weights);
+            match kind {
+                DK::Simple | DK::Tensor => LevelledComplexity::ADDITION * weights.flat_size(),
+                DK::CompatibleTensor { .. } | DK::Broadcast { .. } => todo!("TODO"),
+                DK::Unsupported { .. } => panic!("Unsupported"),
+            }
+        }
+        Op::LevelledOp { complexity, .. } => *complexity,
+        Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO,
+    }
+}
+
+fn levelled_complexity(
+    dag: &unparametrized::OperationDag,
+    out_shapes: &[Shape],
+) -> LevelledComplexity {
+    let mut levelled_complexity = LevelledComplexity::ZERO;
+    for op in &dag.operators {
+        levelled_complexity += op_levelled_complexity(op, out_shapes);
+    }
+    levelled_complexity
+}
+
+fn max_update(current: &mut f64, candidate: f64) {
+    if candidate > *current {
+        *current = candidate;
+    }
+}
+
+fn noise_summary(
+    final_variances: Vec,
+    in_luts_variance: Vec,
+) -> NoiseSummary {
+    let pareto_vfs_final = SymbolicVariance::reduce_to_pareto_front(final_variances);
+    let pareto_vfs_in_lut = SymbolicVariance::reduce_to_pareto_front(in_luts_variance);
+    NoiseSummary {
+        pareto_vfs_final,
+        pareto_vfs_in_lut,
+    }
+}
+
+pub fn analyze(dag: &unparametrized::OperationDag) -> OperationDag {
+    assert_dag_correctness(dag);
+    let out_shapes = out_shapes(dag);
+    let out_precisions = out_precisions(dag);
+    let out_variances = out_variances(dag, &out_shapes);
+    let in_luts_variance = in_luts_variance(dag, &out_variances);
+    let nb_luts = in_luts_variance.len() as u64;
+    let extra_final_variances = extra_final_variances(dag, &out_variances);
+    let levelled_complexity = levelled_complexity(dag, &out_shapes);
+    let noise_summary = noise_summary(extra_final_variances, in_luts_variance);
+    let result = OperationDag {
+        operators: dag.operators.clone(),
+        out_shapes,
+        out_precisions,
+        out_variances,
+        nb_luts,
+        levelled_complexity,
+        noise_summary,
+    };
+    assert_properties_correctness(&result);
+    result
+}
+
+// Compute the maximum attained variance for the full dag
+// TODO take a noise summary => peek_error or global error
+fn peek_variance(
+    dag: &OperationDag,
+    input_noise_out: f64,
+    blind_rotate_noise_out: f64,
+    noise_keyswitch: f64,
+    noise_modulus_switching: f64,
+) -> f64 {
+    assert!(input_noise_out < blind_rotate_noise_out);
+    let mut variance_peek_final = 0.0; // updated by the loop
+    for vf in &dag.noise_summary.pareto_vfs_final {
+        max_update(
+            &mut variance_peek_final,
+            vf.eval(input_noise_out, blind_rotate_noise_out),
+        );
+    }
+
+    let mut variance_peek_in_lut = 0.0; // updated by the loop
+    for vf in &dag.noise_summary.pareto_vfs_in_lut {
+        max_update(
+            &mut variance_peek_in_lut,
+            vf.eval(input_noise_out, blind_rotate_noise_out),
+        );
+    }
+    let peek_in_lut = variance_peek_in_lut + noise_keyswitch + noise_modulus_switching;
+    peek_in_lut.max(variance_peek_final)
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape, Weights};
+    use crate::dag::unparametrized;
+    use crate::utils::square;
+
+    fn assert_f64_eq(v: f64, expected: f64) {
+        approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
+    }
+
+    #[test]
+    fn test_1_input() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(1, Shape::number());
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT);
+        assert_eq!(analysis.out_shapes[input1.i], Shape::number());
+        assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO);
+        assert_eq!(analysis.out_precisions[input1.i], 1);
+        assert_f64_eq(complexity_cost, 0.0);
+        assert!(analysis.nb_luts == 0);
+        let summary = analysis.noise_summary;
+        assert!(summary.pareto_vfs_final.len() == 1);
+        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 1.0);
+        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
+        assert!(summary.pareto_vfs_in_lut.is_empty());
+    }
+
+    #[test]
+    fn test_1_lut() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(8, Shape::number());
+        let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT);
+        assert!(analysis.out_shapes[lut1.i] == Shape::number());
+        assert!(analysis.levelled_complexity == LevelledComplexity::ZERO);
+        assert_eq!(analysis.out_precisions[lut1.i], 8);
+        assert_f64_eq(one_lut_cost, complexity_cost);
+        let summary = analysis.noise_summary;
+        assert!(summary.pareto_vfs_final.len() == 1);
+        assert!(summary.pareto_vfs_in_lut.len() == 1);
+        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 0.0);
+        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
+        assert_f64_eq(summary.pareto_vfs_in_lut[0].input_vf, 1.0);
+        assert_f64_eq(summary.pareto_vfs_in_lut[0].lut_vf, 0.0);
+    }
+
+    #[test]
+    fn test_1_dot() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(1, Shape::number());
+        let weights = Weights::vector([1, 2]);
+        let norm2: f64 = 1.0 * 1.0 + 2.0 * 2.0;
+        let dot = graph.add_dot([input1, input1], weights);
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        let expected_var = SymbolicVariance {
+            input_vf: norm2,
+            lut_vf: 0.0,
+        };
+        assert!(analysis.out_variances[dot.i] == expected_var);
+        assert!(analysis.out_shapes[dot.i] == Shape::number());
+        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2);
+        assert_eq!(analysis.out_precisions[dot.i], 1);
+        let expected_dot_cost = (2 * lwe_dim) as f64;
+        assert_f64_eq(expected_dot_cost, complexity_cost);
+        let summary = analysis.noise_summary;
+        assert!(summary.pareto_vfs_in_lut.is_empty());
+        assert!(summary.pareto_vfs_final.len() == 1);
+        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
+        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 0.0);
+    }
+
+    #[test]
+    fn test_1_dot_levelled() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(3, Shape::number());
+        let cpx_dot = LevelledComplexity::ADDITION;
+        let weights = Weights::vector([1, 2]);
+        let manp = 1.0 * 1.0 + 2.0 * 2_f64;
+        let dot = graph.add_levelled_op([input1, input1], cpx_dot, manp, Shape::number(), "dot");
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        assert!(analysis.out_variances[dot.i].origin() == VO::Input);
+        assert_eq!(analysis.out_precisions[dot.i], 3);
+        let expected_square_norm2 = weights.square_norm2() as f64;
+        let actual_square_norm2 = analysis.out_variances[dot.i].input_vf;
+        // Due to the log2() call used to compute manp, the result is not exact
+        assert_f64_eq(actual_square_norm2, expected_square_norm2);
+        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION);
+        assert_f64_eq(lwe_dim as f64, complexity_cost);
+        let summary = analysis.noise_summary;
+        assert!(summary.pareto_vfs_in_lut.is_empty());
+        assert!(summary.pareto_vfs_final.len() == 1);
+        assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Input);
+        assert_f64_eq(summary.pareto_vfs_final[0].input_vf, 5.0);
+    }
+
+    #[test]
+    fn test_dot_tensorized_lut_dot_lut() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(1, Shape::vector(2));
+        let weights = &Weights::vector([1, 2]);
+        let dot1 = graph.add_dot([input1], weights);
+        let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
+        let dot2 = graph.add_dot([lut1, lut1], weights);
+        let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN);
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        let expected_var_dot1 = SymbolicVariance {
+            input_vf: weights.square_norm2() as f64,
+            lut_vf: 0.0,
+        };
+        let expected_var_lut1 = SymbolicVariance {
+            input_vf: 0.0,
+            lut_vf: 1.0,
+        };
+        let expected_var_dot2 = SymbolicVariance {
+            input_vf: 0.0,
+            lut_vf: weights.square_norm2() as f64,
+        };
+        let expected_var_lut2 = SymbolicVariance {
+            input_vf: 0.0,
+            lut_vf: 1.0,
+        };
+        assert!(analysis.out_variances[dot1.i] == expected_var_dot1);
+        assert!(analysis.out_variances[lut1.i] == expected_var_lut1);
+        assert!(analysis.out_variances[dot2.i] == expected_var_dot2);
+        assert!(analysis.out_variances[lut2.i] == expected_var_lut2);
+        assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 4);
+        let expected_cost = (lwe_dim * 4) as f64 + 2.0 * one_lut_cost;
+        assert_f64_eq(expected_cost, complexity_cost);
+        let summary = analysis.noise_summary;
+        assert_eq!(summary.pareto_vfs_final.len(), 1);
+        assert_eq!(summary.pareto_vfs_final[0].origin(), VO::Lut);
+        assert_f64_eq(summary.pareto_vfs_final[0].lut_vf, 1.0);
+        assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
+        assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Lut);
+        assert_f64_eq(
+            summary.pareto_vfs_in_lut[0].lut_vf,
+            weights.square_norm2() as f64,
+        );
+    }
+
+    #[test]
+    fn test_lut_dot_mixed_lut() {
+        let mut graph = unparametrized::OperationDag::new();
+        let input1 = graph.add_input(1, Shape::number());
+        let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN);
+        let weights = &Weights::vector([2, 3]);
+        let dot1 = graph.add_dot([input1, lut1], weights);
+        let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN);
+        let analysis = analyze(&graph);
+        let one_lut_cost = 100.0;
+        let lwe_dim = 1024;
+        let complexity_cost = analysis.complexity_cost(lwe_dim, one_lut_cost);
+
+        let expected_cost = (2 * lwe_dim) as f64 + 2.0 * one_lut_cost;
+        assert_f64_eq(expected_cost, complexity_cost);
+        let expected_mixed = SymbolicVariance {
+            input_vf: square(weights.values[0] as f64),
+            lut_vf: square(weights.values[1] as f64),
+        };
+        let summary = analysis.noise_summary;
+        assert_eq!(summary.pareto_vfs_final.len(), 1);
+        assert_eq!(summary.pareto_vfs_final[0], SymbolicVariance::LUT);
+        assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
+        assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VO::Mixed);
+        assert_eq!(summary.pareto_vfs_in_lut[0], expected_mixed);
+    }
+}
diff --git a/concrete-optimizer/src/optimization/dag/solo_key/mod.rs b/concrete-optimizer/src/optimization/dag/solo_key/mod.rs
new file mode 100644
index 000000000..ef4f9d882
--- /dev/null
+++ b/concrete-optimizer/src/optimization/dag/solo_key/mod.rs
@@ -0,0 +1,3 @@
+pub(crate) mod analyze;
+pub mod optimize;
+pub(crate) mod symbolic_variance;
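Before the optimizer itself (next file), a small hypothetical in-crate test sketching how the analysis above is meant to be consumed — the dag, weight and precision values are arbitrary illustrations, not part of the diff:

```rust
#[test]
fn sketch_analyze_usage() {
    use crate::dag::operator::{FunctionTable, Shape};
    use crate::dag::unparametrized;
    use crate::optimization::dag::solo_key::analyze::analyze;

    let mut graph = unparametrized::OperationDag::new();
    let input1 = graph.add_input(4, Shape::number());
    // A dot with weight 3 multiplies the input variance factor by 3^2 = 9.
    let dot = graph.add_dot([input1], [3]);
    let _lut = graph.add_lut(dot, FunctionTable::UNKWOWN);

    let dag = analyze(&graph);
    assert_eq!(dag.nb_luts, 1);
    // One symbolic variance enters the lut (9 * input noise) and the lut
    // output is the only final variance, so peek_variance reduces to
    // max(9 * input_noise + ks_noise + ms_noise, lut_noise).
    assert_eq!(dag.noise_summary.pareto_vfs_in_lut.len(), 1);
    assert_eq!(dag.noise_summary.pareto_vfs_final.len(), 1);
}
```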
diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs
new file mode 100644
index 000000000..5ca4c5630
--- /dev/null
+++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs
@@ -0,0 +1,590 @@
+use concrete_commons::dispersion::{DispersionParameter, Variance};
+use concrete_commons::numeric::UnsignedInteger;
+
+use crate::dag::operator::LevelledComplexity;
+use crate::dag::unparametrized;
+use crate::noise_estimator::error;
+use crate::noise_estimator::error::error_probability_of_sigma_scale;
+use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern;
+
+use crate::optimization::atomic_pattern::{
+    pareto_blind_rotate, pareto_keyswitch, OptimizationDecompositionsConsts, OptimizationState,
+    Solution,
+};
+
+use crate::parameters::{BrDecompositionParameters, GlweParameters, KsDecompositionParameters};
+use crate::pareto;
+use crate::security::glwe::minimal_variance;
+use crate::utils::square;
+
+use super::analyze;
+
+const CUTS: bool = true;
+const PARETO_CUTS: bool = true;
+const CROSS_PARETO_CUTS: bool = PARETO_CUTS && true;
+
+#[allow(clippy::too_many_lines)]
+fn update_best_solution_with_best_decompositions(
+    state: &mut OptimizationState,
+    consts: &OptimizationDecompositionsConsts,
+    dag: &analyze::OperationDag,
+    internal_dim: u64,
+    glwe_params: GlweParameters,
+    noise_modulus_switching: f64,
+) {
+    let safe_variance = consts.safe_variance;
+    let glwe_poly_size = glwe_params.polynomial_size();
+    let input_lwe_dimension = glwe_params.glwe_dimension * glwe_poly_size;
+
+    let mut best_complexity = state.best_solution.map_or(f64::INFINITY, |s| s.complexity);
+    let mut best_lut_complexity = state
+        .best_solution
+        .map_or(f64::INFINITY, |s| s.lut_complexity);
+    let mut best_variance = state.best_solution.map_or(f64::INFINITY, |s| s.noise_max);
+
+    let mut cut_complexity =
+        (best_complexity - dag.complexity_cost(input_lwe_dimension, 0.0)) / (dag.nb_luts as f64);
+    let mut cut_noise = safe_variance - noise_modulus_switching;
+
+    if dag.nb_luts == 0 {
+        cut_noise = f64::INFINITY;
+        cut_complexity = f64::INFINITY;
+    }
+
+    let br_pareto =
+        pareto_blind_rotate::(consts, internal_dim, glwe_params, cut_complexity, cut_noise);
+    if br_pareto.is_empty() {
+        return;
+    }
+    if PARETO_CUTS {
+        cut_noise -= br_pareto[br_pareto.len() - 1].noise;
+        cut_complexity -= br_pareto[0].complexity;
+    }
+
+    let ks_pareto = pareto_keyswitch::(
+        consts,
+        input_lwe_dimension,
+        internal_dim,
+        cut_complexity,
+        cut_noise,
+    );
+    if ks_pareto.is_empty() {
+        return;
+    }
+
+    let i_max_ks = ks_pareto.len() - 1;
+    let mut i_current_max_ks = i_max_ks;
+    let input_noise_out = minimal_variance(
+        glwe_params,
+        consts.ciphertext_modulus_log,
+        consts.security_level,
+    )
+    .get_variance();
+
+    let mut best_br_i = 0;
+    let mut best_ks_i = 0;
+    let mut update_best_solution = false;
+
+    for br_quantity in br_pareto {
+        // increasing complexity, decreasing variance
+        let peek_variance = dag.peek_variance(
+            input_noise_out,
+            br_quantity.noise,
+            0.0,
+            noise_modulus_switching,
+        );
+        if peek_variance > safe_variance && CUTS {
+            continue;
+        }
+        let one_pbs_cost = br_quantity.complexity;
+        let complexity = dag.complexity_cost(input_lwe_dimension, one_pbs_cost);
+        if complexity > best_complexity {
+            // As best_complexity can evolve, this cut is complementary to the pareto_blind_rotate cuts.
+            if PARETO_CUTS {
+                break;
+            } else if CUTS {
+                continue;
+            }
+        }
+        for i_ks_pareto in (0..=i_current_max_ks).rev() {
+            // increasing variance, decreasing complexity
+            let ks_quantity = ks_pareto[i_ks_pareto];
+            let peek_variance = dag.peek_variance(
+                input_noise_out,
+                br_quantity.noise,
+                ks_quantity.noise,
+                noise_modulus_switching,
+            );
+            // let noise_max = br_quantity.noise * dag.lut_base_noise_worst_lut + ks_quantity.noise + noise_modulus_switching;
+            if peek_variance > safe_variance {
+                if CROSS_PARETO_CUTS {
+                    // the pareto of 2 added pareto is scanned linearly
+                    // but with all cuts, pre-computing => no gain
+                    i_current_max_ks = usize::min(i_ks_pareto + 1, i_max_ks);
+                    break;
+                    // it's compatible with next i_br but with the worst complexity
+                } else if PARETO_CUTS {
+                    // increasing variance => we can skip all remaining
+                    break;
+                }
+                continue;
+            }
+            let one_lut_cost = ks_quantity.complexity + br_quantity.complexity;
+            let complexity = dag.complexity_cost(input_lwe_dimension, one_lut_cost);
+
+            let better_complexity = complexity < best_complexity;
+            #[allow(clippy::float_cmp)]
+            let same_complexity_with_less_errors =
+                complexity == best_complexity && peek_variance < best_variance;
+            if better_complexity || same_complexity_with_less_errors {
+                best_lut_complexity = one_lut_cost;
+                best_complexity = complexity;
+                best_variance = peek_variance;
+                best_br_i = br_quantity.index;
+                best_ks_i = ks_quantity.index;
+                update_best_solution = true;
+            }
+        }
+    } // br ks
+
+    if update_best_solution {
+        let sigma = Variance(safe_variance).get_standard_dev() * consts.kappa;
+        let sigma_scale = sigma / Variance(best_variance).get_standard_dev();
+        let p_error = error_probability_of_sigma_scale(sigma_scale);
+        let BrDecompositionParameters {
+            level: br_l,
+            log2_base: br_b,
+        } = consts.blind_rotate_decompositions[best_br_i];
+        let KsDecompositionParameters {
+            level: ks_l,
+            log2_base: ks_b,
+        } = consts.keyswitch_decompositions[best_ks_i];
+        state.best_solution = Some(Solution {
+            input_lwe_dimension,
+            internal_ks_output_lwe_dimension: internal_dim,
+            ks_decomposition_level_count: ks_l,
+            ks_decomposition_base_log: ks_b,
+            glwe_polynomial_size: glwe_params.polynomial_size(),
+            glwe_dimension: glwe_params.glwe_dimension,
+            br_decomposition_level_count: br_l,
+            br_decomposition_base_log: br_b,
+            noise_max: best_variance,
+            complexity: best_complexity,
+            lut_complexity: best_lut_complexity,
+            p_error,
+        });
+    }
+}
+
+const REL_EPSILON_PROBA: f64 = 1.0 + 1e-8;
+
+#[allow(clippy::too_many_lines)]
+pub fn optimize(
+    dag: &unparametrized::OperationDag,
+    security_level: u64,
+    maximum_acceptable_error_probability: f64,
+    glwe_log_polynomial_sizes: &[u64],
+    glwe_dimensions: &[u64],
+    internal_lwe_dimensions: &[u64],
+) -> OptimizationState {
+    let ciphertext_modulus_log = W::BITS as u64;
+    let dag = analyze::analyze(dag);
+
+    let &max_precision = dag.out_precisions.iter().max().unwrap();
+
+    let safe_variance = error::variance_max(
+        max_precision as u64,
+        ciphertext_modulus_log,
+        maximum_acceptable_error_probability,
+    );
+    let kappa = error::sigma_scale_of_error_probability(maximum_acceptable_error_probability);
+
+    let consts = OptimizationDecompositionsConsts {
+        kappa,
+        sum_size: 0, // superseded by dag.complexity_cost
+        security_level,
+        noise_factor: f64::NAN, // superseded by dag.lut_variance_max
+        ciphertext_modulus_log,
+        keyswitch_decompositions: pareto::KS_BL
+            .map(|(log2_base, level)| KsDecompositionParameters { level, log2_base })
+            .to_vec(),
+        blind_rotate_decompositions: pareto::BR_BL
+            .map(|(log2_base, level)| BrDecompositionParameters { level, log2_base })
+            .to_vec(),
+        safe_variance,
+    };
+
+    let mut state = OptimizationState {
+        best_solution: None,
+        count_domain: glwe_dimensions.len()
+            * glwe_log_polynomial_sizes.len()
+            * internal_lwe_dimensions.len()
+            * consts.keyswitch_decompositions.len()
+            * consts.blind_rotate_decompositions.len(),
+    };
+
+    let noise_modulus_switching = |glwe_poly_size, internal_lwe_dimensions| {
+        noise_atomic_pattern::estimate_modulus_switching_noise_with_binary_key::(
+            internal_lwe_dimensions,
+            glwe_poly_size,
+        )
+        .get_variance()
+    };
+
+    for &glwe_dim in glwe_dimensions {
+        for &glwe_log_poly_size in glwe_log_polynomial_sizes {
+            let glwe_poly_size = 1 << glwe_log_poly_size;
+            let glwe_params = GlweParameters {
+                log2_polynomial_size: glwe_log_poly_size,
+                glwe_dimension: glwe_dim,
+            };
+            for &internal_dim in internal_lwe_dimensions {
+                let noise_modulus_switching = noise_modulus_switching(glwe_poly_size, internal_dim);
+                if CUTS && noise_modulus_switching > consts.safe_variance {
+                    // assume this noise is increasing with internal_dim
+                    break;
+                }
+                update_best_solution_with_best_decompositions::(
+                    &mut state,
+                    &consts,
+                    &dag,
+                    internal_dim,
+                    glwe_params,
+                    noise_modulus_switching,
+                );
+                if dag.nb_luts == 0 && state.best_solution.is_some() {
+                    return state;
+                }
+            }
+        }
+    }
+
+    if let Some(sol) = state.best_solution {
+        assert!(0.0 <= sol.p_error && sol.p_error <= 1.0);
+        assert!(sol.p_error <= maximum_acceptable_error_probability * REL_EPSILON_PROBA);
+    }
+
+    state
+}
+
+pub fn optimize_v0(
+    sum_size: u64,
+    precision: u64,
+    security_level: u64,
+    noise_factor: f64,
+    maximum_acceptable_error_probability: f64,
+    glwe_log_polynomial_sizes: &[u64],
+    glwe_dimensions: &[u64],
+    internal_lwe_dimensions: &[u64],
+) -> OptimizationState {
+    use crate::dag::operator::{FunctionTable, Shape};
+    let same_scale_manp = 0.0;
+    let manp = square(noise_factor);
+    let out_shape = &Shape::number();
+    let complexity = LevelledComplexity::ADDITION * sum_size;
+    let comment = "dot";
+    let mut dag = unparametrized::OperationDag::new();
+    let input1 = dag.add_input(precision as u8, out_shape);
+    let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment);
+    let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
+    let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment);
+    let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
+    let mut state = optimize::(
+        &dag,
+        security_level,
+        maximum_acceptable_error_probability,
+        glwe_log_polynomial_sizes,
+        glwe_dimensions,
+        internal_lwe_dimensions,
+    );
+    if let Some(sol) = &mut state.best_solution {
+        sol.complexity /= 2.0;
+    }
+    state
+}
+
+#[cfg(test)]
+mod tests {
+    use std::time::Instant;
+
+    use crate::dag::operator::{FunctionTable, Shape, Weights};
+    use crate::global_parameters::DEFAUT_DOMAINS;
+    use crate::optimization::dag::solo_key::symbolic_variance::VarianceOrigin;
+    use crate::utils::square;
+
+    use super::*;
+    use crate::optimization::atomic_pattern;
+
+    fn small_relative_diff(v1: f64, v2: f64) -> bool {
+        f64::abs(v1 - v2) / f64::max(v1, v2) <= f64::EPSILON
+    }
+
+    impl Solution {
+        fn same(&self, other: Self) -> bool {
+            let mut other = other;
+            if small_relative_diff(self.noise_max, other.noise_max)
+                && small_relative_diff(self.p_error, other.p_error)
+            {
+                other.noise_max = self.noise_max;
+                other.p_error = self.p_error;
+            }
+            self == &other
+        }
+    }
+
+    const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516;
+
+    struct Times {
+        worst_time: u128,
+        dag_time: u128,
+    }
+
+    fn assert_f64_eq(v: f64, expected: f64) {
+        approx::assert_relative_eq!(v, expected, epsilon = f64::EPSILON);
+    }
+
+    #[test]
+    fn test_v0_parameter_ref() {
+        let mut times = Times {
+            worst_time: 0,
+            dag_time: 0,
+        };
+        for log_weight in 0..=16 {
+            let weight = 1 << log_weight;
+            for precision in 1..=9 {
+                v0_parameter_ref(precision, weight, &mut times);
+            }
+        }
+        assert!(times.worst_time * 2 > times.dag_time);
+    }
+
+    fn v0_parameter_ref(precision: u64, weight: u64, times: &mut Times) {
+        let security_level = 128;
+        let glwe_log_polynomial_sizes: Vec = DEFAUT_DOMAINS
+            .glwe_pbs_constrained
+            .log2_polynomial_size
+            .as_vec();
+        let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
+        let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
+        let sum_size = 1;
+        let maximum_acceptable_error_probability = _4_SIGMA;
+
+        let chrono = Instant::now();
+        let state = optimize_v0::(
+            sum_size,
+            precision,
+            security_level,
+            weight as f64,
+            maximum_acceptable_error_probability,
+            &glwe_log_polynomial_sizes,
+            &glwe_dimensions,
+            &internal_lwe_dimensions,
+        );
+        times.dag_time += chrono.elapsed().as_nanos();
+        let chrono = Instant::now();
+        let state_ref = atomic_pattern::optimize_one::(
+            sum_size,
+            precision,
+            security_level,
+            weight as f64,
+            maximum_acceptable_error_probability,
+            &glwe_log_polynomial_sizes,
+            &glwe_dimensions,
+            &internal_lwe_dimensions,
+            None,
+        );
+        times.worst_time += chrono.elapsed().as_nanos();
+        assert_eq!(
+            state.best_solution.is_some(),
+            state_ref.best_solution.is_some()
+        );
+        if state.best_solution.is_none() {
+            return;
+        }
+        let sol = state.best_solution.unwrap();
+        let sol_ref = state_ref.best_solution.unwrap();
+        assert!(sol.same(sol_ref));
+    }
+
+    #[test]
+    fn test_v0_parameter_ref_with_dot() {
+        for log_weight in 0..=16 {
+            let weight = 1 << log_weight;
+            for precision in 1..=9 {
+                v0_parameter_ref_with_dot(precision, weight);
+            }
+        }
+    }
+
+    fn v0_parameter_ref_with_dot(precision: u64, weight: u64) {
+        let mut dag = unparametrized::OperationDag::new();
+        {
+            let input1 = dag.add_input(precision as u8, Shape::number());
+            let dot1 = dag.add_dot([input1], [1]);
+            let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN);
+            let dot2 = dag.add_dot([lut1], [weight]);
+            let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN);
+        }
+        {
+            let dag2 = analyze::analyze(&dag);
+            let summary = dag2.noise_summary;
+            assert_eq!(summary.pareto_vfs_final.len(), 1);
+            assert_eq!(summary.pareto_vfs_in_lut.len(), 1);
+            assert_eq!(summary.pareto_vfs_final[0].origin(), VarianceOrigin::Lut);
+            assert_f64_eq(1.0, summary.pareto_vfs_final[0].lut_vf);
+            assert!(summary.pareto_vfs_in_lut.len() == 1);
+            assert_eq!(summary.pareto_vfs_in_lut[0].origin(), VarianceOrigin::Lut);
+            assert_f64_eq(square(weight) as f64, summary.pareto_vfs_in_lut[0].lut_vf);
+        }
+
+        let security_level = 128;
+        let maximum_acceptable_error_probability = _4_SIGMA;
+        let glwe_log_polynomial_sizes: Vec = DEFAUT_DOMAINS
+            .glwe_pbs_constrained
+            .log2_polynomial_size
+            .as_vec();
+        let glwe_dimensions: Vec = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
+        let internal_lwe_dimensions: Vec = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
+        let state = optimize::(
+            &dag,
+            security_level,
+            maximum_acceptable_error_probability,
+            &glwe_log_polynomial_sizes,
+            &glwe_dimensions,
+            &internal_lwe_dimensions,
+        );
+        let state_ref = atomic_pattern::optimize_one::(
+            1,
+            precision,
+            security_level,
+            weight as f64,
+            maximum_acceptable_error_probability,
+            &glwe_log_polynomial_sizes,
+            &glwe_dimensions,
+            &internal_lwe_dimensions,
+            None,
+        );
+        assert_eq!(
+            state.best_solution.is_some(),
+            state_ref.best_solution.is_some()
+        );
+        if state.best_solution.is_none() {
+            return;
+        }
+        let sol = state.best_solution.unwrap();
+        let mut sol_ref = state_ref.best_solution.unwrap();
+        sol_ref.complexity *= 2.0 /* number of luts */;
+        assert!(sol.same(sol_ref));
+    }
+
+    fn no_lut_vs_lut(precision: u64) {
+        let mut dag_lut = unparametrized::OperationDag::new();
+        let input1 = dag_lut.add_input(precision as u8, Shape::number());
+        let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN);
+
+        let mut dag_no_lut = unparametrized::OperationDag::new();
+        let _input2 = dag_no_lut.add_input(precision as u8, Shape::number());
+
+        let security_level = 128;
+        let maximum_acceptable_error_probability = _4_SIGMA;
+        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
+            .glwe_pbs_constrained
+            .log2_polynomial_size
+            .as_vec();
+        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
+        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
+
+        let opt = |dag: &unparametrized::OperationDag| {
+            optimize::<u64>(
+                dag,
+                security_level,
+                maximum_acceptable_error_probability,
+                &glwe_log_polynomial_sizes,
+                &glwe_dimensions,
+                &internal_lwe_dimensions,
+            )
+        };
+
+        let state_no_lut = opt(&dag_no_lut);
+        let state_lut = opt(&dag_lut);
+        assert_eq!(
+            state_no_lut.best_solution.is_some(),
+            state_lut.best_solution.is_some()
+        );
+
+        if state_lut.best_solution.is_none() {
+            return;
+        }
+
+        let sol_no_lut = state_no_lut.best_solution.unwrap();
+        let sol_lut = state_lut.best_solution.unwrap();
+        // a lut can only add complexity
+        assert!(sol_no_lut.complexity < sol_lut.complexity);
+    }
+
+    #[test]
+    fn test_lut_vs_no_lut() {
+        for precision in 1..=8 {
+            no_lut_vs_lut(precision);
+        }
+    }
+
+    fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision: u64, weight: u64) {
+        let weight = &Weights::number(weight);
+
+        let mut dag_1 = unparametrized::OperationDag::new();
+        {
+            let input1 = dag_1.add_input(precision as u8, Shape::number());
+            let scaled_input1 = dag_1.add_dot([input1], weight);
+            let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN);
+            let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN);
+        }
+
+        let mut dag_2 = unparametrized::OperationDag::new();
+        {
+            let input1 = dag_2.add_input(precision as u8, Shape::number());
+            let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN);
+            let scaled_lut1 = dag_2.add_dot([lut1], weight);
+            let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN);
+        }
+
+        let security_level = 128;
+        let maximum_acceptable_error_probability = _4_SIGMA;
+        let glwe_log_polynomial_sizes: Vec<u64> = DEFAUT_DOMAINS
+            .glwe_pbs_constrained
+            .log2_polynomial_size
+            .as_vec();
+        let glwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.glwe_pbs_constrained.glwe_dimension.as_vec();
+        let internal_lwe_dimensions: Vec<u64> = DEFAUT_DOMAINS.free_glwe.glwe_dimension.as_vec();
+
+        let opt = |dag: &unparametrized::OperationDag| {
+            optimize::<u64>(
+                dag,
+                security_level,
+                maximum_acceptable_error_probability,
+                &glwe_log_polynomial_sizes,
+                &glwe_dimensions,
+                &internal_lwe_dimensions,
+            )
+        };
+
+        let state_1 = opt(&dag_1);
+        let state_2 = opt(&dag_2);
+
+        if state_1.best_solution.is_none() {
+            assert!(state_2.best_solution.is_none());
+            return;
+        }
+        let sol_1 = state_1.best_solution.unwrap();
+        let sol_2 = state_2.best_solution.unwrap();
+        // weighting before the lut (on the small input noise) must not be
+        // worse than weighting after the lut (on the bigger lut output noise)
+        assert!(sol_1.complexity < sol_2.complexity || sol_1.p_error < sol_2.p_error);
+    }
+
+    #[test]
+    fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() {
+        for log_weight in 1..=16 {
+            let weight = 1 << log_weight;
+            for precision in 5..=9 {
+                lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight);
+            }
+        }
+    }
+}
diff --git a/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs
new file mode 100644
index 000000000..abcfc4ed2
--- /dev/null
+++ b/concrete-optimizer/src/optimization/dag/solo_key/symbolic_variance.rs
@@ -0,0 +1,109 @@
+use derive_more::{Add, AddAssign, Sum};
+/**
+ * A variance that is represented as a linear combination of base variances.
+ * Only the linear coefficients are known.
+ * The base variances are unknown.
+ * Each linear coefficient is a variance factor.
+ *
+ * Only 2 base variances are possible in the solo key setup:
+ * - from input,
+ * - or from lut output.
+ *
+ * We only know that the first one is lower than or equal to the second one.
+ * Variance factors are homogeneous to squared weights
+ * (i.e. summed squared weights, or a squared norm2).
+ */
+#[derive(Clone, Copy, Add, AddAssign, Sum, Debug, PartialEq, PartialOrd)]
+pub struct SymbolicVariance {
+    pub lut_vf: f64,
+    pub input_vf: f64,
+    // variance = vf.lut_vf * lut_out_noise
+    //          + vf.input_vf * input_out_noise
+    // E.g. variance(dot([lut, input], [3, 4])) = SymbolicVariance { lut_vf: 9.0, input_vf: 16.0 }
+
+    // NOTE: lut_vf is the first field since it has the higher impact,
+    // see pareto sorting and dominate_or_equal
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub enum VarianceOrigin {
+    Input,
+    Lut,
+    Mixed,
+}
+
+impl std::ops::Mul<f64> for SymbolicVariance {
+    type Output = Self;
+    fn mul(self, sq_weight: f64) -> Self {
+        Self {
+            input_vf: self.input_vf * sq_weight,
+            lut_vf: self.lut_vf * sq_weight,
+        }
+    }
+}
+
+impl std::ops::Mul<u64> for SymbolicVariance {
+    type Output = Self;
+    fn mul(self, sq_weight: u64) -> Self {
+        self * sq_weight as f64
+    }
+}
+
+impl SymbolicVariance {
+    pub const ZERO: Self = Self {
+        input_vf: 0.0,
+        lut_vf: 0.0,
+    };
+    pub const INPUT: Self = Self {
+        input_vf: 1.0,
+        lut_vf: 0.0,
+    };
+    pub const LUT: Self = Self {
+        input_vf: 0.0,
+        lut_vf: 1.0,
+    };
+
+    pub fn origin(&self) -> VarianceOrigin {
+        if self.lut_vf == 0.0 {
+            VarianceOrigin::Input
+        } else if self.input_vf == 0.0 {
+            VarianceOrigin::Lut
+        } else {
+            VarianceOrigin::Mixed
+        }
+    }
+
+    pub fn manp_to_variance_factor(manp: f64) -> f64 {
+        manp
+    }
+
+    pub fn dominate_or_equal(&self, other: &Self) -> bool {
+        let extra_other_minimal_base_noise = 0.0_f64.max(other.input_vf - self.input_vf);
+        other.lut_vf + extra_other_minimal_base_noise <= self.lut_vf
+    }
+
+    pub fn eval(&self, minimal_base_noise: f64, lut_base_noise: f64) -> f64 {
+        minimal_base_noise * self.input_vf + lut_base_noise * self.lut_vf
+    }
+
+    pub fn reduce_to_pareto_front(mut vfs: Vec<SymbolicVariance>) -> Vec<SymbolicVariance> {
+        if vfs.is_empty() {
+            return vec![];
+        }
+        vfs.sort_by(
+            // bigger first
+            |a, b| b.partial_cmp(a).unwrap(),
+        );
+        // Thanks to this sort and the structure of domination, one pass suffices
+        let mut dominator = vfs[0];
+        let mut pareto = vec![dominator];
+        for &vf in vfs.iter().skip(1) {
+            if dominator.dominate_or_equal(&vf) {
+                continue;
+            }
+            dominator = vf;
+            pareto.push(vf);
+        }
+        pareto
+    }
+}
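A note on `dominate_or_equal`, since the inequality is easy to misread: it checks that `self`'s variance is at least `other`'s for every admissible pair of base noises. Because the input base noise is known to be at most the lut base noise, any deficit `other.input_vf - self.input_vf` on the input coefficient can be covered by a surplus on the lut coefficient, which is exactly what `extra_other_minimal_base_noise` accounts for. The snippet below is an illustrative, test-style usage of the APIs introduced in this file (crate-internal paths as in the tests above); it is a sketch, not part of the patch:

    use crate::optimization::dag::solo_key::symbolic_variance::SymbolicVariance;

    // dot([input, lut_out], [3, 4]) accumulates squared weights per origin:
    let vf = SymbolicVariance::INPUT * 9_u64 + SymbolicVariance::LUT * 16_u64;
    assert_eq!(vf, SymbolicVariance { lut_vf: 16.0, input_vf: 9.0 });

    // with equal coefficients, the lut origin dominates the input origin,
    // since the lut base noise is the larger of the two:
    assert!(SymbolicVariance::LUT.dominate_or_equal(&SymbolicVariance::INPUT));

    // hence the pareto front of {INPUT, LUT} keeps only LUT:
    let front = SymbolicVariance::reduce_to_pareto_front(vec![
        SymbolicVariance::INPUT,
        SymbolicVariance::LUT,
    ]);
    assert_eq!(front, vec![SymbolicVariance::LUT]);
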
diff --git a/concrete-optimizer/src/optimization/mod.rs b/concrete-optimizer/src/optimization/mod.rs
index cb66c1737..b18d26ead 100644
--- a/concrete-optimizer/src/optimization/mod.rs
+++ b/concrete-optimizer/src/optimization/mod.rs
@@ -1 +1,2 @@
 pub mod atomic_pattern;
+pub mod dag;
diff --git a/concrete-optimizer/src/utils/mod.rs b/concrete-optimizer/src/utils/mod.rs
new file mode 100644
index 000000000..121783041
--- /dev/null
+++ b/concrete-optimizer/src/utils/mod.rs
@@ -0,0 +1,15 @@
+use std::ops::Mul;
+
+pub fn square<V>(v: V) -> V
+where
+    V: Mul<V, Output = V> + Copy,
+{
+    v * v
+}
+
+pub fn square_ref<V>(v: &V) -> V
+where
+    V: Mul<V, Output = V> + Copy,
+{
+    square(*v)
+}
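For completeness, a quick illustration of the new `square` helpers (a sketch, assuming it sits inside the crate so that `crate::utils` resolves). Variances combine with squared coefficients, which is why `optimize_v0` passes `square(noise_factor)` as `manp`:

    use crate::utils::{square, square_ref};

    // `square` works for any Copy type whose multiplication is closed
    assert_eq!(square(3_u64), 9);
    assert_eq!(square(1.5_f64), 2.25);
    assert_eq!(square_ref(&-4_i64), 16);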