From 3db828d3d0e49ce35de16f3d964174dc2e0163aa Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 21 Sep 2022 18:24:58 +0200 Subject: [PATCH] feat: rounded lut for classical pbs step 1, provide the sub-dag construction --- concrete-optimizer/Cargo.toml | 1 + concrete-optimizer/src/dag/mod.rs | 3 - .../src/dag/operator/operator.rs | 12 +- .../src/dag/parameter_indexed.rs | 25 -- .../src/dag/range_parametrized.rs | 9 - concrete-optimizer/src/dag/unparametrized.rs | 288 +++++++++++++++- .../src/dag/value_parametrized.rs | 9 - concrete-optimizer/src/global_parameters.rs | 319 +----------------- .../src/optimization/dag/solo_key/analyze.rs | 148 ++------ .../src/optimization/dag/solo_key/optimize.rs | 197 +++++++---- .../dag/solo_key/optimize_generic.rs | 19 +- 11 files changed, 441 insertions(+), 589 deletions(-) delete mode 100644 concrete-optimizer/src/dag/parameter_indexed.rs delete mode 100644 concrete-optimizer/src/dag/range_parametrized.rs delete mode 100644 concrete-optimizer/src/dag/value_parametrized.rs diff --git a/concrete-optimizer/Cargo.toml b/concrete-optimizer/Cargo.toml index 2c37b1e79..f27c5426b 100644 --- a/concrete-optimizer/Cargo.toml +++ b/concrete-optimizer/Cargo.toml @@ -15,6 +15,7 @@ puruspe = "0.2.0" [dev-dependencies] approx = "0.5" +once_cell = "1.16.0" pretty_assertions = "1.2.1" [lib] diff --git a/concrete-optimizer/src/dag/mod.rs b/concrete-optimizer/src/dag/mod.rs index 296ce434c..a21dce048 100644 --- a/concrete-optimizer/src/dag/mod.rs +++ b/concrete-optimizer/src/dag/mod.rs @@ -1,5 +1,2 @@ pub mod operator; -pub mod parameter_indexed; -pub mod range_parametrized; pub mod unparametrized; -pub mod value_parametrized; diff --git a/concrete-optimizer/src/dag/operator/operator.rs b/concrete-optimizer/src/dag/operator/operator.rs index c0071c72e..79b4b4296 100644 --- a/concrete-optimizer/src/dag/operator/operator.rs +++ b/concrete-optimizer/src/dag/operator/operator.rs @@ -65,22 +65,19 @@ pub type Precision = u8; pub const MIN_PRECISION: Precision = 1; #[derive(Clone, PartialEq, Debug)] -pub enum Operator { +pub enum Operator { Input { out_precision: Precision, out_shape: Shape, - extra_data: InputExtraData, }, Lut { input: OperatorIndex, table: FunctionTable, out_precision: Precision, - extra_data: LutExtraData, }, Dot { inputs: Vec, weights: Weights, - extra_data: DotExtraData, }, LevelledOp { inputs: Vec, @@ -88,7 +85,12 @@ pub enum Operator; - -pub struct OperationDag { - pub(crate) operators: Vec, - pub(crate) parameters_count: ParameterCount, - pub(crate) reverse_map: ParameterToOperation, -} diff --git a/concrete-optimizer/src/dag/range_parametrized.rs b/concrete-optimizer/src/dag/range_parametrized.rs deleted file mode 100644 index 7ebb24b8b..000000000 --- a/concrete-optimizer/src/dag/range_parametrized.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::dag::parameter_indexed::OperatorParameterIndexed; -use crate::global_parameters::{ParameterRanges, ParameterToOperation}; - -#[allow(dead_code)] -pub struct OperationDag { - pub(crate) operators: Vec, - pub(crate) parameter_ranges: ParameterRanges, - pub(crate) reverse_map: ParameterToOperation, -} diff --git a/concrete-optimizer/src/dag/unparametrized.rs b/concrete-optimizer/src/dag/unparametrized.rs index 51d3576b4..c1fe33a6f 100644 --- a/concrete-optimizer/src/dag/unparametrized.rs +++ b/concrete-optimizer/src/dag/unparametrized.rs @@ -1,24 +1,36 @@ use std::fmt::Write; use crate::dag::operator::{ - FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, Shape, Weights, + dot_kind, DotKind, FunctionTable, LevelledComplexity, Operator, OperatorIndex, Precision, + Shape, Weights, }; -pub(crate) type UnparameterizedOperator = Operator<(), (), (), ()>; +pub(crate) type UnparameterizedOperator = Operator; #[derive(Clone, PartialEq, Debug)] #[must_use] pub struct OperationDag { pub(crate) operators: Vec, + // Collect all operators ouput shape + pub(crate) out_shapes: Vec, + // Collect all operators ouput precision + pub(crate) out_precisions: Vec, } impl OperationDag { pub const fn new() -> Self { - Self { operators: vec![] } + Self { + operators: vec![], + out_shapes: vec![], + out_precisions: vec![], + } } fn add_operator(&mut self, operator: UnparameterizedOperator) -> OperatorIndex { let i = self.operators.len(); + self.out_precisions + .push(self.infer_out_precision(&operator)); + self.out_shapes.push(self.infer_out_shape(&operator)); self.operators.push(operator); OperatorIndex { i } } @@ -32,7 +44,6 @@ impl OperationDag { self.add_operator(Operator::Input { out_precision, out_shape, - extra_data: (), }) } @@ -46,7 +57,6 @@ impl OperationDag { input, table, out_precision, - extra_data: (), }) } @@ -57,11 +67,7 @@ impl OperationDag { ) -> OperatorIndex { let inputs = inputs.into(); let weights = weights.into(); - self.add_operator(Operator::Dot { - inputs, - weights, - extra_data: (), - }) + self.add_operator(Operator::Dot { inputs, weights }) } pub fn add_levelled_op( @@ -81,11 +87,25 @@ impl OperationDag { manp, out_shape, comment, - extra_data: (), }; self.add_operator(op) } + pub fn add_unsafe_cast( + &mut self, + input: OperatorIndex, + out_precision: Precision, + ) -> OperatorIndex { + let input_precision = self.out_precisions[input.i]; + if input_precision == out_precision { + return input; + } + self.add_operator(Operator::UnsafeCast { + input, + out_precision, + }) + } + #[allow(clippy::len_without_is_empty)] pub fn len(&self) -> usize { self.operators.len() @@ -100,6 +120,139 @@ impl OperationDag { } acc } + + fn add_shift_left_lsb_to_msb_no_padding(&mut self, input: OperatorIndex) -> OperatorIndex { + // Convert any input to simple 1bit msb replacing the padding + // For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding. + let in_precision = self.out_precisions[input.i]; + let shift_factor = Weights::number(1 << (in_precision as i64)); + let lsb_as_msb = self.add_dot([input], shift_factor); + self.add_unsafe_cast(lsb_as_msb, 0 as Precision) + } + + fn add_lut_1bit_no_padding( + &mut self, + input: OperatorIndex, + table: FunctionTable, + out_precision: Precision, + ) -> OperatorIndex { + // For now encoding is not explicit, so 1 bit content without padding <=> 0 bit content with padding. + let in_precision = self.out_precisions[input.i]; + assert!(in_precision == 0); + // An add after with a clear constant is skipped here as it doesn't change noise handling. + self.add_lut(input, table, out_precision) + } + + fn add_shift_right_msb_no_padding_to_lsb( + &mut self, + input: OperatorIndex, + out_precision: Precision, + ) -> OperatorIndex { + // Convert simple 1 bit msb to a nbit with zero padding + let to_nbits_padded = FunctionTable::UNKWOWN; + self.add_lut_1bit_no_padding(input, to_nbits_padded, out_precision) + } + + fn add_isolate_lowest_bit(&mut self, input: OperatorIndex) -> OperatorIndex { + // The lowest bit is converted to a cyphertext of same precision as input. + // Introduce a pbs of input precision but this precision is only used on 1 levelled op and converted to lower precision + // Noise is reduced by a pbs. + let out_precision = self.out_precisions[input.i]; + let lsb_as_msb = self.add_shift_left_lsb_to_msb_no_padding(input); + self.add_shift_right_msb_no_padding_to_lsb(lsb_as_msb, out_precision) + } + + pub fn add_truncate_1_bit(&mut self, input: OperatorIndex) -> OperatorIndex { + // Reset a bit. + // ex: 10110 is truncated to 1011, 10111 is truncated to 1011 + let in_precision = self.out_precisions[input.i]; + let lowest_bit = self.add_isolate_lowest_bit(input); + let bit_cleared = self.add_dot([input, lowest_bit], [1, -1]); + self.add_unsafe_cast(bit_cleared, in_precision - 1) + } + + pub fn add_expanded_round( + &mut self, + input: OperatorIndex, + rounded_precision: Precision, + ) -> OperatorIndex { + // Round such that the ouput has precision out_precision. + // We round by adding 2**(removed_precision - 1) to the last remaining bit to clear (this step is a no-op). + // Than all lower bits are cleared. + // Note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice. + // Note: reset and rounds could be done by 4, 3, 2 and 1 bits groups for efficiency. + // bit efficiency is better for 4 precision then 3, but the feasability is lower for high noise + let in_precision = self.out_precisions[input.i]; + assert!(rounded_precision <= in_precision); + if in_precision == rounded_precision { + return input; + } + // Add rounding constant, this is a represented as non-op since it doesn't influence crypto parameters. + let mut rounded = input; + // The rounded is in high precision with garbage lowest bits + let bits_to_truncate = in_precision - rounded_precision; + for _ in 1..=bits_to_truncate as i64 { + rounded = self.add_truncate_1_bit(rounded); + } + rounded + } + + pub fn add_expanded_rounded_lut( + &mut self, + input: OperatorIndex, + table: FunctionTable, + rounded_precision: Precision, + out_precision: Precision, + ) -> OperatorIndex { + // note: this is a simplified graph, some constant additions are missing without consequence on crypto parameter choice. + let rounded = self.add_expanded_round(input, rounded_precision); + self.add_lut(rounded, table, out_precision) + } + + fn infer_out_shape(&self, op: &UnparameterizedOperator) -> Shape { + match op { + Operator::Input { out_shape, .. } | Operator::LevelledOp { out_shape, .. } => { + out_shape.clone() + } + Operator::Lut { input, .. } | Operator::UnsafeCast { input, .. } => { + self.out_shapes[input.i].clone() + } + Operator::Dot { + inputs, weights, .. + } => { + let input_shape = self.out_shapes[inputs[0].i].clone(); + let kind = dot_kind(inputs.len() as u64, &input_shape, weights); + match kind { + DotKind::Simple | DotKind::Tensor | DotKind::CompatibleTensor => { + Shape::number() + } + DotKind::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()), + DotKind::Unsupported { .. } => { + let weights_shape = &weights.shape; + println!(); + println!(); + println!("Error diagnostic on dot operation:"); + println!( + "Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>" + ); + println!(); + panic!("Unsupported or invalid dot operation") + } + } + } + } + } + + fn infer_out_precision(&self, op: &UnparameterizedOperator) -> Precision { + match op { + Operator::Input { out_precision, .. } + | Operator::Lut { out_precision, .. } + | Operator::UnsafeCast { out_precision, .. } => *out_precision, + Operator::Dot { inputs, .. } | Operator::LevelledOp { inputs, .. } => { + self.out_precisions[inputs[0].i] + } + } + } } #[cfg(test)] @@ -138,12 +291,10 @@ mod tests { Operator::Input { out_precision: 1, out_shape: Shape::number(), - extra_data: () }, Operator::Input { out_precision: 2, out_shape: Shape::number(), - extra_data: () }, Operator::LevelledOp { inputs: vec![input1, input2], @@ -151,13 +302,11 @@ mod tests { manp: 1.0, out_shape: Shape::number(), comment: "sum".to_string(), - extra_data: () }, Operator::Lut { input: sum1, table: FunctionTable::UNKWOWN, out_precision: 1, - extra_data: () }, Operator::LevelledOp { inputs: vec![input1, lut1], @@ -165,7 +314,6 @@ mod tests { manp: 1.0, out_shape: Shape::vector(2), comment: "concat".to_string(), - extra_data: () }, Operator::Dot { inputs: vec![concat], @@ -173,15 +321,119 @@ mod tests { shape: Shape::vector(2), values: vec![1, 2] }, - extra_data: () }, Operator::Lut { input: dot, table: FunctionTable::UNKWOWN, out_precision: 2, - extra_data: () } ] ); } + + #[test] + fn test_rounded_lut() { + let mut graph = OperationDag::new(); + let out_precision = 5; + let rounded_precision = 2; + let input1 = graph.add_input(out_precision, Shape::number()); + let _ = graph.add_expanded_rounded_lut( + input1, + FunctionTable::UNKWOWN, + rounded_precision, + out_precision, + ); + let expecteds = [ + Operator::Input { + out_precision, + out_shape: Shape::number(), + }, + // The rounding addition skipped, it's a no-op wrt crypto parameter + // Clear: cleared = input - bit0 + //// Extract bit + Operator::Dot { + inputs: vec![input1], + weights: Weights::number(1 << 5), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 1 }, + out_precision: 0, + }, + //// 1 Bit to out_precision + Operator::Lut { + input: OperatorIndex { i: 2 }, + table: FunctionTable::UNKWOWN, + out_precision: 5, + }, + //// Erase bit + Operator::Dot { + inputs: vec![input1, OperatorIndex { i: 3 }], + weights: Weights::vector([1, -1]), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 4 }, + out_precision: 4, + }, + // Clear: cleared = input - bit0 - bit1 + //// Extract bit + Operator::Dot { + inputs: vec![OperatorIndex { i: 5 }], + weights: Weights::number(1 << 4), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 6 }, + out_precision: 0, + }, + //// 1 Bit to out_precision + Operator::Lut { + input: OperatorIndex { i: 7 }, + table: FunctionTable::UNKWOWN, + out_precision: 4, + }, + //// Erase bit + Operator::Dot { + inputs: vec![OperatorIndex { i: 5 }, OperatorIndex { i: 8 }], + weights: Weights::vector([1, -1]), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 9 }, + out_precision: 3, + }, + // Clear: cleared = input - bit0 - bit1 - bit2 + //// Extract bit + Operator::Dot { + inputs: vec![OperatorIndex { i: 10 }], + weights: Weights::number(1 << 3), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 11 }, + out_precision: 0, + }, + //// 1 Bit to out_precision + Operator::Lut { + input: OperatorIndex { i: 12 }, + table: FunctionTable::UNKWOWN, + out_precision: 3, + }, + //// Erase bit + Operator::Dot { + inputs: vec![OperatorIndex { i: 10 }, OperatorIndex { i: 13 }], + weights: Weights::vector([1, -1]), + }, + Operator::UnsafeCast { + input: OperatorIndex { i: 14 }, + out_precision: 2, + }, + // Lut on rounded precision + Operator::Lut { + input: OperatorIndex { i: 15 }, + table: FunctionTable::UNKWOWN, + out_precision: 5, + }, + ]; + assert_eq!(expecteds.len(), graph.operators.len()); + for (i, (expected, actual)) in std::iter::zip(expecteds, graph.operators).enumerate() { + assert_eq!(expected, actual, "{i}-th operation"); + } + } } diff --git a/concrete-optimizer/src/dag/value_parametrized.rs b/concrete-optimizer/src/dag/value_parametrized.rs deleted file mode 100644 index 4658b6a62..000000000 --- a/concrete-optimizer/src/dag/value_parametrized.rs +++ /dev/null @@ -1,9 +0,0 @@ -use crate::dag::parameter_indexed::OperatorParameterIndexed; -use crate::global_parameters::{ParameterToOperation, ParameterValues}; - -#[allow(dead_code)] -pub struct OperationDag { - pub(crate) operators: Vec, - pub(crate) parameter_ranges: ParameterValues, - pub(crate) reverse_map: ParameterToOperation, -} diff --git a/concrete-optimizer/src/global_parameters.rs b/concrete-optimizer/src/global_parameters.rs index af2fa8a88..8ff244fce 100644 --- a/concrete-optimizer/src/global_parameters.rs +++ b/concrete-optimizer/src/global_parameters.rs @@ -1,43 +1,14 @@ -use std::collections::HashSet; - -use crate::dag::operator::{Operator, OperatorIndex}; -use crate::dag::parameter_indexed::{ - InputParameterIndexed, LutParametersIndexed, OperatorParameterIndexed, -}; -use crate::dag::unparametrized::UnparameterizedOperator; -use crate::dag::{parameter_indexed, range_parametrized, unparametrized}; use crate::parameters::{ - BrDecompositionParameterRanges, BrDecompositionParameters, GlweParameterRanges, GlweParameters, - KsDecompositionParameterRanges, KsDecompositionParameters, + BrDecompositionParameterRanges, GlweParameterRanges, KsDecompositionParameterRanges, }; -#[derive(Clone)] -pub(crate) struct ParameterToOperation { - pub glwe: Vec>, - pub br_decomposition: Vec>, - pub ks_decomposition: Vec>, -} - #[derive(Clone, Copy, PartialEq, Eq, Debug)] -pub struct ParameterCount { +struct ParameterCount { pub glwe: usize, pub br_decomposition: usize, pub ks_decomposition: usize, } -#[derive(Clone)] -pub struct ParameterRanges { - pub glwe: Vec, - pub br_decomposition: Vec, // 0 => lpetit , 1 => l plus grand - pub ks_decomposition: Vec, -} - -pub struct ParameterValues { - pub glwe: Vec, - pub br_decomposition: Vec, - pub ks_decomposition: Vec, -} - #[derive(Clone, Copy)] pub struct ParameterDomains { // move next comment to pareto ranges definition @@ -91,289 +62,3 @@ impl Range { self.into_iter().collect() } } - -#[must_use] -#[allow(clippy::needless_pass_by_value)] -pub fn minimal_unify(_g: unparametrized::OperationDag) -> parameter_indexed::OperationDag { - todo!() -} - -fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed { - let external_glwe_index = 0; - let internal_lwe_index = 1; - let br_decomposition_index = 0; - let ks_decomposition_index = 0; - match op { - Operator::Input { - out_precision, - out_shape, - .. - } => Operator::Input { - out_precision, - out_shape, - extra_data: InputParameterIndexed { - lwe_dimension_index: external_glwe_index, - }, - }, - Operator::Lut { - input, - table, - out_precision, - .. - } => Operator::Lut { - input, - table, - out_precision, - extra_data: LutParametersIndexed { - input_lwe_dimension_index: external_glwe_index, - ks_decomposition_parameter_index: ks_decomposition_index, - internal_lwe_dimension_index: internal_lwe_index, - br_decomposition_parameter_index: br_decomposition_index, - output_glwe_params_index: external_glwe_index, - }, - }, - Operator::Dot { - inputs, weights, .. - } => Operator::Dot { - inputs, - weights, - extra_data: (), - }, - Operator::LevelledOp { - inputs, - complexity, - manp, - out_shape, - comment, - .. - } => Operator::LevelledOp { - inputs, - complexity, - manp, - comment, - out_shape, - extra_data: (), - }, - } -} - -#[must_use] -pub fn maximal_unify(g: unparametrized::OperationDag) -> parameter_indexed::OperationDag { - let operators: Vec<_> = g.operators.into_iter().map(convert_maximal).collect(); - - let parameters = ParameterCount { - glwe: 2, - br_decomposition: 1, - ks_decomposition: 1, - }; - - let mut reverse_map = ParameterToOperation { - glwe: vec![vec![], vec![]], - br_decomposition: vec![vec![]], - ks_decomposition: vec![vec![]], - }; - - for (i, op) in operators.iter().enumerate() { - let index = OperatorIndex { i }; - match op { - Operator::Input { .. } => { - reverse_map.glwe[0].push(index); - } - Operator::Lut { .. } => { - reverse_map.glwe[0].push(index); - reverse_map.glwe[1].push(index); - reverse_map.br_decomposition[0].push(index); - reverse_map.ks_decomposition[0].push(index); - } - Operator::Dot { .. } | Operator::LevelledOp { .. } => { - reverse_map.glwe[0].push(index); - reverse_map.glwe[1].push(index); - } - } - } - - parameter_indexed::OperationDag { - operators, - parameters_count: parameters, - reverse_map, - } -} - -#[must_use] -pub fn domains_to_ranges( - parameter_indexed::OperationDag { - operators, - parameters_count, - reverse_map, - }: parameter_indexed::OperationDag, - domains: ParameterDomains, -) -> range_parametrized::OperationDag { - let mut constrained_glwe_parameter_indexes = HashSet::new(); - for op in &operators { - if let Operator::Lut { extra_data, .. } = op { - let _ = constrained_glwe_parameter_indexes.insert(extra_data.output_glwe_params_index); - } - } - - let mut glwe = vec![]; - - for i in 0..parameters_count.glwe { - if constrained_glwe_parameter_indexes.contains(&i) { - glwe.push(domains.glwe_pbs_constrained); - } else { - glwe.push(domains.free_glwe); - } - } - - let parameter_ranges = ParameterRanges { - glwe, - br_decomposition: vec![domains.br_decomposition; parameters_count.br_decomposition], - ks_decomposition: vec![domains.ks_decomposition; parameters_count.ks_decomposition], - }; - - range_parametrized::OperationDag { - operators, - parameter_ranges, - reverse_map, - } -} - -// fn fill_ranges(g: parameter_indexed::AtomicPatternDag) -> parameter_ranged::AtomicPatternDag { -// //check unconstrained GlweDim -> set range_poly_size=[1, 2[ -// todo!() -// } - -#[cfg(test)] -mod tests { - use super::*; - use crate::dag::operator::{FunctionTable, LevelledComplexity, Shape}; - - #[test] - fn test_maximal_unify() { - let mut graph = unparametrized::OperationDag::new(); - - let input1 = graph.add_input(1, Shape::number()); - - let input2 = graph.add_input(2, Shape::number()); - - let cpx_add = LevelledComplexity::ADDITION; - let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); - - let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 2); - - let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat"); - - let dot = graph.add_dot([concat], [1, 2]); - - let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2); - - let graph_params = maximal_unify(graph); - - assert_eq!( - graph_params.parameters_count, - ParameterCount { - glwe: 2, - br_decomposition: 1, - ks_decomposition: 1, - } - ); - - assert_eq!( - graph_params.reverse_map.glwe, - vec![ - vec![input1, input2, sum1, lut1, concat, dot, lut2], - vec![sum1, lut1, concat, dot, lut2] - ] - ); - - assert_eq!( - graph_params.reverse_map.br_decomposition, - vec![vec![lut1, lut2]] - ); - - assert_eq!( - graph_params.reverse_map.ks_decomposition, - vec![vec![lut1, lut2]] - ); - // collectes l'ensemble des parametres - // unify structurellement les parametres identiques - // => counts - // => - // let parametrized_expr = { global, dag + indexation} - } - - #[test] - fn test_simple_lwe() { - let mut graph = unparametrized::OperationDag::new(); - let input1 = graph.add_input(1, Shape::number()); - let _input2 = graph.add_input(2, Shape::number()); - - let graph_params = maximal_unify(graph); - - let range_parametrized::OperationDag { - operators, - parameter_ranges, - reverse_map: _, - } = domains_to_ranges(graph_params, DEFAUT_DOMAINS); - - let input_1_lwe_params = match &operators[input1.i] { - Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index, - _ => unreachable!(), - }; - - dbg!(¶meter_ranges.glwe); - - assert_eq!( - DEFAUT_DOMAINS.free_glwe, - parameter_ranges.glwe[input_1_lwe_params] - ); - } - - #[test] - fn test_simple_lwe2() { - let mut graph = unparametrized::OperationDag::new(); - let input1 = graph.add_input(1, Shape::number()); - let input2 = graph.add_input(2, Shape::number()); - - let cpx_add = LevelledComplexity::ADDITION; - let concat = - graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat"); - - let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN, 2); - - let graph_params = maximal_unify(graph); - - let range_parametrized::OperationDag { - operators, - parameter_ranges, - reverse_map: _, - } = domains_to_ranges(graph_params, DEFAUT_DOMAINS); - - let input_1_lwe_params = match &operators[input1.i] { - Operator::Input { extra_data, .. } => extra_data.lwe_dimension_index, - _ => unreachable!(), - }; - assert_eq!( - DEFAUT_DOMAINS.glwe_pbs_constrained, - parameter_ranges.glwe[input_1_lwe_params] - ); - - let lut1_out_glwe_params = match &operators[lut1.i] { - Operator::Lut { extra_data, .. } => extra_data.output_glwe_params_index, - _ => unreachable!(), - }; - assert_eq!( - DEFAUT_DOMAINS.glwe_pbs_constrained, - parameter_ranges.glwe[lut1_out_glwe_params] - ); - - let lut1_internal_glwe_params = match &operators[lut1.i] { - Operator::Lut { extra_data, .. } => extra_data.internal_lwe_dimension_index, - _ => unreachable!(), - }; - assert_eq!( - DEFAUT_DOMAINS.free_glwe, - parameter_ranges.glwe[lut1_internal_glwe_params] - ); - } -} diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index cdf0339dc..169d41b9d 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -54,6 +54,8 @@ fn assert_non_empty_inputs(op: &unparametrized::UnparameterizedOperator) { fn assert_dag_correctness(dag: &unparametrized::OperationDag) { for op in &dag.operators { assert_non_empty_inputs(op); + assert_inputs_uniform_precisions(op, &dag.out_precisions); + assert_dot_uniform_inputs_shape(op, &dag.out_shapes); } } @@ -68,10 +70,6 @@ fn assert_valid_variances(dag: &OperationDag) { } fn assert_properties_correctness(dag: &OperationDag) { - for op in &dag.operators { - assert_inputs_uniform_precisions(op, &dag.out_precisions); - assert_dot_uniform_inputs_shape(op, &dag.out_shapes); - } assert_valid_variances(dag); } @@ -89,10 +87,6 @@ fn variance_origin(inputs: &[OperatorIndex], out_variances: &[SymbolicVariance]) #[derive(Clone, Debug)] pub struct OperationDag { pub operators: Vec, - // Collect all operators ouput shape - pub out_shapes: Vec, - // Collect all operators ouput precision - pub out_precisions: Vec, // Collect all operators ouput variances pub out_variances: Vec, pub nb_luts: u64, @@ -118,66 +112,6 @@ pub struct VariancesAndBound { pub all_in_lut: Vec<(u64, SymbolicVariance)>, } -fn out_shape(op: &unparametrized::UnparameterizedOperator, out_shapes: &mut [Shape]) -> Shape { - match op { - Op::Input { out_shape, .. } | Op::LevelledOp { out_shape, .. } => out_shape.clone(), - Op::Lut { input, .. } => out_shapes[input.i].clone(), - Op::Dot { - inputs, weights, .. - } => { - if inputs.is_empty() { - return Shape::number(); - } - let input_shape = first(inputs, out_shapes); - let kind = dot_kind(inputs.len() as u64, input_shape, weights); - match kind { - DK::Simple | DK::Tensor => Shape::number(), - DK::CompatibleTensor => weights.shape.clone(), - DK::Broadcast { .. } => Shape::vector(input_shape.first_dim_size()), - DK::Unsupported { .. } => { - let weights_shape = &weights.shape; - println!(); - println!(); - println!("Error diagnostic on dot operation:"); - println!("Incompatible operands: <{input_shape:?}> DOT <{weights_shape:?}>"); - println!(); - panic!("Unsupported or invalid dot operation") - } - } - } - } -} - -fn out_shapes(dag: &unparametrized::OperationDag) -> Vec { - let nb_ops = dag.operators.len(); - let mut out_shapes = Vec::::with_capacity(nb_ops); - for op in &dag.operators { - let shape = out_shape(op, &mut out_shapes); - out_shapes.push(shape); - } - out_shapes -} - -fn out_precision( - op: &unparametrized::UnparameterizedOperator, - out_precisions: &[Precision], -) -> Precision { - match op { - Op::Input { out_precision, .. } | Op::Lut { out_precision, .. } => *out_precision, - Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i], - } -} - -fn out_precisions(dag: &unparametrized::OperationDag) -> Vec { - let nb_ops = dag.operators.len(); - let mut out_precisions = Vec::::with_capacity(nb_ops); - for op in &dag.operators { - let precision = out_precision(op, &out_precisions); - out_precisions.push(precision); - } - out_precisions -} - fn out_variance( op: &unparametrized::UnparameterizedOperator, out_shapes: &[Shape], @@ -220,17 +154,15 @@ fn out_variance( DK::Unsupported { .. } => panic!("Unsupported"), } } + Op::UnsafeCast { input, .. } => out_variances[input.i], } } -fn out_variances( - dag: &unparametrized::OperationDag, - out_shapes: &[Shape], -) -> Vec { +fn out_variances(dag: &unparametrized::OperationDag) -> Vec { let nb_ops = dag.operators.len(); let mut out_variances = Vec::with_capacity(nb_ops); for op in &dag.operators { - let vf = out_variance(op, out_shapes, &mut out_variances); + let vf = out_variance(op, &dag.out_shapes, &mut out_variances); out_variances.push(vf); } out_variances @@ -242,7 +174,7 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec for op in &dag.operators { match op { Op::Input { .. } => (), - Op::Lut { input, .. } => { + Op::Lut { input, .. } | Op::UnsafeCast { input, .. } => { extra_values_to_check[input.i] = false; } Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => { @@ -257,8 +189,6 @@ fn extra_final_values_to_check(dag: &unparametrized::OperationDag) -> Vec fn extra_final_variances( dag: &unparametrized::OperationDag, - out_shapes: &[Shape], - out_precisions: &[Precision], out_variances: &[SymbolicVariance], ) -> Vec<(Precision, Shape, SymbolicVariance)> { extra_final_values_to_check(dag) @@ -266,7 +196,11 @@ fn extra_final_variances( .enumerate() .filter_map(|(i, &is_final)| { if is_final { - Some((out_precisions[i], out_shapes[i].clone(), out_variances[i])) + Some(( + dag.out_precisions[i], + dag.out_shapes[i].clone(), + out_variances[i], + )) } else { None } @@ -276,8 +210,6 @@ fn extra_final_variances( fn in_luts_variance( dag: &unparametrized::OperationDag, - out_shapes: &[Shape], - out_precisions: &[Precision], out_variances: &[SymbolicVariance], ) -> Vec<(Precision, Shape, SymbolicVariance)> { dag.operators @@ -286,8 +218,8 @@ fn in_luts_variance( .filter_map(|(i, op)| { if let &Op::Lut { input, .. } = op { Some(( - out_precisions[input.i], - out_shapes[i].clone(), + dag.out_precisions[input.i], + dag.out_shapes[i].clone(), out_variances[input.i], )) } else { @@ -315,26 +247,23 @@ fn op_levelled_complexity( } } Op::LevelledOp { complexity, .. } => *complexity, - Op::Input { .. } | Op::Lut { .. } => LevelledComplexity::ZERO, + Op::Input { .. } | Op::Lut { .. } | Op::UnsafeCast { .. } => LevelledComplexity::ZERO, } } -fn levelled_complexity( - dag: &unparametrized::OperationDag, - out_shapes: &[Shape], -) -> LevelledComplexity { +fn levelled_complexity(dag: &unparametrized::OperationDag) -> LevelledComplexity { let mut levelled_complexity = LevelledComplexity::ZERO; for op in &dag.operators { - levelled_complexity += op_levelled_complexity(op, out_shapes); + levelled_complexity += op_levelled_complexity(op, &dag.out_shapes); } levelled_complexity } -fn lut_count(dag: &unparametrized::OperationDag, out_shapes: &[Shape]) -> u64 { +pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 { let mut count = 0; for (i, op) in dag.operators.iter().enumerate() { if let Op::Lut { .. } = op { - count += out_shapes[i].flat_size(); + count += dag.out_shapes[i].flat_size(); } } count @@ -433,10 +362,8 @@ fn constraint_for_one_precision( pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 { assert_dag_correctness(dag); - let out_shapes = out_shapes(dag); - let out_precisions = out_precisions(dag); - let out_variances = out_variances(dag, &out_shapes); - let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances); + let out_variances = out_variances(dag); + let in_luts_variance = in_luts_variance(dag, &out_variances); let coeffs = in_luts_variance .iter() .map(|(_precision, _shape, symbolic_variance)| { @@ -447,25 +374,18 @@ pub fn worst_log_norm(dag: &unparametrized::OperationDag) -> f64 { worst.log2() } -pub fn lut_count_from_dag(dag: &unparametrized::OperationDag) -> u64 { - lut_count(dag, &out_shapes(dag)) -} - pub fn analyze( dag: &unparametrized::OperationDag, noise_config: &NoiseBoundConfig, ) -> OperationDag { assert_dag_correctness(dag); - let out_shapes = out_shapes(dag); - let out_precisions = out_precisions(dag); - let out_variances = out_variances(dag, &out_shapes); - let in_luts_variance = in_luts_variance(dag, &out_shapes, &out_precisions, &out_variances); - let nb_luts = lut_count(dag, &out_shapes); - let extra_final_variances = - extra_final_variances(dag, &out_shapes, &out_precisions, &out_variances); - let levelled_complexity = levelled_complexity(dag, &out_shapes); + let out_variances = out_variances(dag); + let in_luts_variance = in_luts_variance(dag, &out_variances); + let nb_luts = lut_count_from_dag(dag); + let extra_final_variances = extra_final_variances(dag, &out_variances); + let levelled_complexity = levelled_complexity(dag); let constraints_by_precisions = constraints_by_precisions( - &out_precisions, + &dag.out_precisions, &extra_final_variances, &in_luts_variance, noise_config, @@ -475,8 +395,6 @@ pub fn analyze( .all(|(_, _, sb)| sb.origin() == VarianceOrigin::Input); let result = OperationDag { operators: dag.operators.clone(), - out_shapes, - out_precisions, out_variances, nb_luts, levelled_complexity, @@ -712,9 +630,9 @@ mod tests { let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert_eq!(analysis.out_variances[input1.i], SymbolicVariance::INPUT); - assert_eq!(analysis.out_shapes[input1.i], Shape::number()); + assert_eq!(graph.out_shapes[input1.i], Shape::number()); assert_eq!(analysis.levelled_complexity, LevelledComplexity::ZERO); - assert_eq!(analysis.out_precisions[input1.i], 1); + assert_eq!(graph.out_precisions[input1.i], 1); assert_f64_eq(complexity_cost, 0.0); assert!(analysis.nb_luts == 0); let constraint = analysis.constraint(); @@ -735,9 +653,9 @@ mod tests { let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert!(analysis.out_variances[lut1.i] == SymbolicVariance::LUT); - assert!(analysis.out_shapes[lut1.i] == Shape::number()); + assert!(graph.out_shapes[lut1.i] == Shape::number()); assert!(analysis.levelled_complexity == LevelledComplexity::ZERO); - assert_eq!(analysis.out_precisions[lut1.i], 8); + assert_eq!(graph.out_precisions[lut1.i], 8); assert_f64_eq(one_lut_cost, complexity_cost); let constraint = analysis.constraint(); assert!(constraint.pareto_output.len() == 1); @@ -765,9 +683,9 @@ mod tests { lut_coeff: 0.0, }; assert!(analysis.out_variances[dot.i] == expected_var); - assert!(analysis.out_shapes[dot.i] == Shape::number()); + assert!(graph.out_shapes[dot.i] == Shape::number()); assert!(analysis.levelled_complexity == LevelledComplexity::ADDITION * 2); - assert_eq!(analysis.out_precisions[dot.i], 1); + assert_eq!(graph.out_precisions[dot.i], 1); let expected_dot_cost = (2 * lwe_dim) as f64; assert_f64_eq(expected_dot_cost, complexity_cost); let constraint = analysis.constraint(); @@ -792,7 +710,7 @@ mod tests { let complexity_cost = analysis.complexity(lwe_dim, one_lut_cost); assert!(analysis.out_variances[dot.i].origin() == VO::Input); - assert_eq!(analysis.out_precisions[dot.i], 3); + assert_eq!(graph.out_precisions[dot.i], 3); let expected_square_norm2 = weights.square_norm2() as f64; let actual_square_norm2 = analysis.out_variances[dot.i].input_coeff; // Due to call on log2() to compute manp the result is not exact diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index 8065e45cb..0061845a3 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -240,10 +240,10 @@ pub fn optimize( maximum_acceptable_error_probability: config.maximum_acceptable_error_probability, ciphertext_modulus_log, }; - let dag = analyze::analyze(dag, &noise_config); - let &min_precision = dag.out_precisions.iter().min().unwrap(); + let dag = analyze::analyze(dag, &noise_config); + let safe_variance = error::safe_variance_bound_2padbits( min_precision as u64, ciphertext_modulus_log, @@ -357,6 +357,8 @@ pub fn optimize_v0( mod tests { use std::time::Instant; + use once_cell::sync::Lazy; + use super::*; use crate::computing_cost::cpu::CpuComplexity; use crate::config; @@ -388,10 +390,12 @@ mod tests { const _4_SIGMA: f64 = 1.0 - 0.999_936_657_516; - fn optimize( - dag: &unparametrized::OperationDag, - cache: &PersistDecompCache, - ) -> OptimizationState { + static SHARED_CACHES: Lazy = Lazy::new(|| { + let processing_unit = config::ProcessingUnit::Cpu; + decomposition::cache(128, processing_unit, None) + }); + + fn optimize(dag: &unparametrized::OperationDag) -> OptimizationState { let config = Config { security_level: 128, maximum_acceptable_error_probability: _4_SIGMA, @@ -401,7 +405,7 @@ mod tests { let search_space = SearchSpace::default_cpu(); - super::optimize(dag, config, &search_space, cache) + super::optimize(dag, config, &search_space, &SHARED_CACHES) } struct Times { @@ -442,15 +446,13 @@ mod tests { complexity_model: &CpuComplexity::default(), }; - let cache = decomposition::cache(config.security_level, processing_unit, None); - let _ = optimize_v0( sum_size, precision, config, weight as f64, &search_space, - &cache, + &SHARED_CACHES, ); // ensure cache is filled @@ -461,7 +463,7 @@ mod tests { config, weight as f64, &search_space, - &cache, + &SHARED_CACHES, ); times.dag_time += chrono.elapsed().as_nanos(); @@ -472,7 +474,7 @@ mod tests { config, weight as f64, &search_space, - &cache, + &SHARED_CACHES, ); times.worst_time += chrono.elapsed().as_nanos(); assert_eq!( @@ -504,8 +506,6 @@ mod tests { let processing_unit = config::ProcessingUnit::Cpu; let security_level = 128; - let cache = decomposition::cache(security_level, processing_unit, None); - let mut dag = unparametrized::OperationDag::new(); { let input1 = dag.add_input(precision, Shape::number()); @@ -542,14 +542,14 @@ mod tests { complexity_model: &CpuComplexity::default(), }; - let state = optimize(&dag, &cache); + let state = optimize(&dag); let state_ref = atomic_pattern::optimize_one( 1, precision as u64, config, weight as f64, &search_space, - &cache, + &SHARED_CACHES, ); assert_eq!( state.best_solution.is_some(), @@ -567,7 +567,7 @@ mod tests { assert!(sol.global_p_error <= 1.0); } - fn no_lut_vs_lut(precision: Precision, cache: &PersistDecompCache) { + fn no_lut_vs_lut(precision: Precision) { let mut dag_lut = unparametrized::OperationDag::new(); let input1 = dag_lut.add_input(precision, Shape::number()); let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision); @@ -575,8 +575,8 @@ mod tests { let mut dag_no_lut = unparametrized::OperationDag::new(); let _input2 = dag_no_lut.add_input(precision, Shape::number()); - let state_no_lut = optimize(&dag_no_lut, cache); - let state_lut = optimize(&dag_lut, cache); + let state_no_lut = optimize(&dag_no_lut); + let state_lut = optimize(&dag_lut); assert_eq!( state_no_lut.best_solution.is_some(), state_lut.best_solution.is_some() @@ -592,17 +592,14 @@ mod tests { } #[test] fn test_lut_vs_no_lut() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); for precision in 1..=8 { - no_lut_vs_lut(precision, &cache); + no_lut_vs_lut(precision); } } fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise( precision: Precision, weight: i64, - cache: &PersistDecompCache, ) { let weight = &Weights::number(weight); @@ -622,8 +619,8 @@ mod tests { let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision); } - let state_1 = optimize(&dag_1, cache); - let state_2 = optimize(&dag_2, cache); + let state_1 = optimize(&dag_1); + let state_2 = optimize(&dag_2); if state_1.best_solution.is_none() { assert!(state_2.best_solution.is_none()); @@ -636,19 +633,15 @@ mod tests { #[test] fn test_lut_with_input_base_noise_better_than_lut_with_lut_base_noise() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); for log_weight in 1..=16 { let weight = 1 << log_weight; for precision in 5..=9 { - lut_with_input_base_noise_better_than_lut_with_lut_base_noise( - precision, weight, &cache, - ); + lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision, weight); } } } - fn lut_1_layer_has_better_complexity(precision: Precision, cache: &PersistDecompCache) { + fn lut_1_layer_has_better_complexity(precision: Precision) { let dag_1_layer = { let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(precision, Shape::number()); @@ -664,19 +657,17 @@ mod tests { dag }; - let sol_1_layer = optimize(&dag_1_layer, cache).best_solution.unwrap(); - let sol_2_layer = optimize(&dag_2_layer, cache).best_solution.unwrap(); + let sol_1_layer = optimize(&dag_1_layer).best_solution.unwrap(); + let sol_2_layer = optimize(&dag_2_layer).best_solution.unwrap(); assert!(sol_1_layer.complexity < sol_2_layer.complexity); } #[test] fn test_lut_1_layer_is_better() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); // for some reason on 4, 5, 6, the complexity is already minimal // this could be due to pre-defined pareto set for precision in [1, 2, 3, 7, 8] { - lut_1_layer_has_better_complexity(precision, &cache); + lut_1_layer_has_better_complexity(precision); } } @@ -688,10 +679,7 @@ mod tests { let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } - fn assert_multi_precision_dominate_single( - weight: i64, - cache: &PersistDecompCache, - ) -> Option { + fn assert_multi_precision_dominate_single(weight: i64) -> Option { let low_precision = 4u8; let high_precision = 5u8; let mut dag_low = unparametrized::OperationDag::new(); @@ -704,12 +692,12 @@ mod tests { circuit(&mut dag_multi, low_precision, weight); circuit(&mut dag_multi, high_precision, 1); } - let state_multi = optimize(&dag_multi, cache); + let state_multi = optimize(&dag_multi); let mut sol_multi = state_multi.best_solution?; - let state_low = optimize(&dag_low, cache); - let state_high = optimize(&dag_high, cache); + let state_low = optimize(&dag_low); + let state_high = optimize(&dag_high); let sol_low = state_low.best_solution.unwrap(); let sol_high = state_high.best_solution.unwrap(); @@ -728,12 +716,10 @@ mod tests { #[test] fn test_multi_precision_dominate_single() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); let mut prev = Some(true); // true -> ... -> true -> false -> ... -> false for log2_weight in 0..29 { let weight = 1 << log2_weight; - let current = assert_multi_precision_dominate_single(weight, &cache); + let current = assert_multi_precision_dominate_single(weight); #[allow(clippy::match_like_matches_macro)] // less readable let authorized = match (prev, current) { (Some(false), Some(true)) => false, @@ -763,29 +749,22 @@ mod tests { #[test] fn test_global_p_error_input() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); for precision in [4_u8, 8] { for weight in [1, 3, 27, 243, 729] { for dim in [1, 2, 16, 32] { - let _ = check_global_p_error_input(dim, weight, precision, &cache); + let _ = check_global_p_error_input(dim, weight, precision); } } } } - fn check_global_p_error_input( - dim: u64, - weight: i64, - precision: u8, - cache: &PersistDecompCache, - ) -> f64 { + fn check_global_p_error_input(dim: u64, weight: i64, precision: u8) -> f64 { let shape = Shape::vector(dim); let weights = Weights::number(weight); let mut dag = unparametrized::OperationDag::new(); let input1 = dag.add_input(precision, shape); let _dot1 = dag.add_dot([input1], weights); // this is just several multiply - let state = optimize(&dag, cache); + let state = optimize(&dag); let sol = state.best_solution.unwrap(); let worst_expected_p_error_dim = local_to_approx_global_p_error(sol.p_error, dim); approx::assert_relative_eq!(sol.global_p_error, worst_expected_p_error_dim); @@ -794,23 +773,16 @@ mod tests { #[test] fn test_global_p_error_lut() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); for precision in [4_u8, 8] { for weight in [1, 3, 27, 243, 729] { for depth in [2, 16, 32] { - check_global_p_error_lut(depth, weight, precision, &cache); + check_global_p_error_lut(depth, weight, precision); } } } } - fn check_global_p_error_lut( - depth: u64, - weight: i64, - precision: u8, - cache: &PersistDecompCache, - ) { + fn check_global_p_error_lut(depth: u64, weight: i64, precision: u8) { let shape = Shape::number(); let weights = Weights::number(weight); let mut dag = unparametrized::OperationDag::new(); @@ -819,7 +791,7 @@ mod tests { let dot = dag.add_dot([last_val], &weights); last_val = dag.add_lut(dot, FunctionTable::UNKWOWN, precision); } - let state = optimize(&dag, cache); + let state = optimize(&dag); let sol = state.best_solution.unwrap(); // the first lut on input has reduced impact on error probability let lower_nb_dominating_lut = depth - 1; @@ -856,8 +828,6 @@ mod tests { #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const #[test] fn test_global_p_error_dominating_lut() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); let depth = 128; let weights_low = 1; let weights_high = 1; @@ -870,7 +840,7 @@ mod tests { weights_low, weights_high, ); - let sol = optimize(&dag, &cache).best_solution.unwrap(); + let sol = optimize(&dag).best_solution.unwrap(); // the 2 first luts and low precision/weight luts have little impact on error probability let nb_dominating_lut = depth - 1; let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); @@ -885,8 +855,6 @@ mod tests { #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const #[test] fn test_global_p_error_non_dominating_lut() { - let processing_unit = config::ProcessingUnit::Cpu; - let cache = decomposition::cache(128, processing_unit, None); let depth = 128; let weights_low = 1024 * 1024 * 3; let weights_high = 1; @@ -899,7 +867,7 @@ mod tests { weights_low, weights_high, ); - let sol = optimize(&dag, &cache).best_solution.unwrap(); + let sol = optimize(&dag).best_solution.unwrap(); // all intern luts have an impact on error probability almost equaly let nb_dominating_lut = (2 * depth) - 1; let approx_global_p_error = local_to_approx_global_p_error(sol.p_error, nb_dominating_lut); @@ -910,4 +878,89 @@ mod tests { max_relative = 0.05 ); } + + fn circuit_with_rounded_lut( + rounded_precision: Precision, + precision: Precision, + weight: i64, + ) -> unparametrized::OperationDag { + // circuit with intermediate high precision in levelled op + let shape = Shape::number(); + let mut dag = unparametrized::OperationDag::new(); + let weight = Weights::number(weight); + let val = dag.add_input(precision, &shape); + let lut1 = + dag.add_expanded_rounded_lut(val, FunctionTable::UNKWOWN, rounded_precision, precision); + let dot = dag.add_dot([lut1], &weight); + let _lut2 = dag.add_expanded_rounded_lut( + dot, + FunctionTable::UNKWOWN, + rounded_precision, + rounded_precision, + ); + dag + } + + fn check_global_p_error_rounded_lut( + precision: Precision, + rounded_precision: Precision, + weight: i64, + ) { + let dag_no_rounded = circuit_with_rounded_lut(precision, precision, weight); + let dag_rounded = circuit_with_rounded_lut(rounded_precision, precision, weight); + let dag_reduced = circuit_with_rounded_lut(rounded_precision, rounded_precision, weight); + let best_reduced = optimize(&dag_reduced).best_solution.unwrap(); + let best_rounded = optimize(&dag_rounded).best_solution.unwrap(); + let best_no_rounded_complexity = optimize(&dag_no_rounded) + .best_solution + .map_or(f64::INFINITY, |s| s.complexity); + // println!("Slowdown acc {rounded_precision} -> {precision}, {best_rounded.complexity/best_reduced.complexity}"); + // println!("Speedup tlu {precision} -> {rounded_precision}, {best_no_rounded_complexity/best_rounded.complexity}"); + if weight == 0 && precision - rounded_precision <= 4 + || weight == 16 && precision - rounded_precision <= 3 + { + // linear slowdown with almost no margin + assert!( + best_rounded.complexity + <= best_reduced.complexity + * (1.0 + 1.01 * (precision - rounded_precision) as f64) + ); + } else if precision - rounded_precision <= 4 { + // linear slowdown with margin + assert!( + best_rounded.complexity + <= best_reduced.complexity + * (1.0 + 1.5 * (precision - rounded_precision) as f64) + ); + } else if precision != rounded_precision { + // slowdown + assert!(best_reduced.complexity < best_rounded.complexity); + } + // linear speedup + assert!( + best_rounded.complexity * (precision - rounded_precision) as f64 + <= best_no_rounded_complexity + ); + } + + #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const + #[test] + fn test_global_p_error_rounded_lut() { + let precision = 8 as Precision; + for rounded_precision in 4..9 { + check_global_p_error_rounded_lut(precision, rounded_precision, 1); + } + } + + #[allow(clippy::unnecessary_cast)] // clippy bug refusing as Precision on const + #[test] + fn test_global_p_error_increased_accumulator() { + let rounded_precision = 8 as Precision; + for precision in [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20] { + for weight in [1, 2, 4, 8, 16, 32, 64, 128] { + println!("{precision} {weight}"); + check_global_p_error_rounded_lut(precision, rounded_precision, weight); + } + } + } } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index 593b8341d..5edd7046f 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -1,5 +1,5 @@ -use crate::dag::operator::{Operator, Precision}; -use crate::dag::unparametrized::{OperationDag, UnparameterizedOperator}; +use crate::dag::operator::Precision; +use crate::dag::unparametrized::OperationDag; use crate::noise_estimator::p_error::repeat_p_error; use crate::optimization::atomic_pattern::Solution as WpSolution; use crate::optimization::config::{Config, SearchSpace}; @@ -18,21 +18,8 @@ pub enum Solution { WopSolution(WopSolution), } -fn precision_op(op: &UnparameterizedOperator) -> Option { - match op { - Operator::Input { out_precision, .. } | Operator::Lut { out_precision, .. } => { - Some(*out_precision) - } - Operator::Dot { .. } | Operator::LevelledOp { .. } => None, - } -} - fn max_precision(dag: &OperationDag) -> Precision { - dag.operators - .iter() - .filter_map(precision_op) - .max() - .unwrap_or(0) + dag.out_precisions.iter().copied().max().unwrap_or(0) } fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {