From d220eb4009f06d2c5d153ae291bdfd54068158fb Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 1 Jun 2022 18:47:08 +0200 Subject: [PATCH] feat: multiprecision, allow precision change in lut --- .../src/concrete-optimizer.rs | 16 +++++-- .../src/cpp/concrete-optimizer.cpp | 8 ++-- .../src/cpp/concrete-optimizer.hpp | 2 +- .../src/dag/operator/operator.rs | 2 +- concrete-optimizer/src/dag/unparametrized.rs | 14 +++++-- concrete-optimizer/src/global_parameters.rs | 14 +++++-- .../src/optimization/dag/solo_key/analyze.rs | 37 ++++++++-------- .../src/optimization/dag/solo_key/optimize.rs | 42 ++++++++++--------- 8 files changed, 82 insertions(+), 53 deletions(-) diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs index 9c50fb220..2c6781cab 100644 --- a/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -71,12 +71,17 @@ impl OperationDag { self.0.add_input(out_precision, out_shape).into() } - fn add_lut(&mut self, input: ffi::OperatorIndex, table: &[u64]) -> ffi::OperatorIndex { + fn add_lut( + &mut self, + input: ffi::OperatorIndex, + table: &[u64], + out_precision: Precision, + ) -> ffi::OperatorIndex { let table = FunctionTable { values: table.to_owned(), }; - self.0.add_lut(input.into(), table).into() + self.0.add_lut(input.into(), table, out_precision).into() } #[allow(clippy::boxed_local)] @@ -160,7 +165,12 @@ mod ffi { out_shape: &[u64], ) -> OperatorIndex; - fn add_lut(self: &mut OperationDag, input: OperatorIndex, table: &[u64]) -> OperatorIndex; + fn add_lut( + self: &mut OperationDag, + input: OperatorIndex, + table: &[u64], + out_precision: u8, + ) -> OperatorIndex; fn add_dot( self: &mut OperationDag, diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 893a28fa8..12d69195c 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -620,7 +620,7 @@ namespace concrete_optimizer { #define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept; ~OperationDag() = delete; @@ -698,7 +698,7 @@ extern "C" { extern "C" { ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_input(::concrete_optimizer::OperationDag &self, ::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; -::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; +::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_lut(::concrete_optimizer::OperationDag &self, ::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex concrete_optimizer$cxxbridge1$OperationDag$add_dot(::concrete_optimizer::OperationDag &self, ::rust::Slice inputs, ::concrete_optimizer::Weights *weights) noexcept; @@ -737,8 +737,8 @@ namespace dag { return concrete_optimizer$cxxbridge1$OperationDag$add_input(*this, out_precision, out_shape); } -::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept { - return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table); +::concrete_optimizer::dag::OperatorIndex OperationDag::add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table, ::std::uint8_t out_precision) noexcept { + return concrete_optimizer$cxxbridge1$OperationDag$add_lut(*this, input, table, out_precision); } ::concrete_optimizer::dag::OperatorIndex OperationDag::add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept { diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 8d7546e0f..c58c2bc15 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -621,7 +621,7 @@ namespace concrete_optimizer { #define CXXBRIDGE1_STRUCT_concrete_optimizer$OperationDag struct OperationDag final : public ::rust::Opaque { ::concrete_optimizer::dag::OperatorIndex add_input(::std::uint8_t out_precision, ::rust::Slice out_shape) noexcept; - ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table) noexcept; + ::concrete_optimizer::dag::OperatorIndex add_lut(::concrete_optimizer::dag::OperatorIndex input, ::rust::Slice table, ::std::uint8_t out_precision) noexcept; ::concrete_optimizer::dag::OperatorIndex add_dot(::rust::Slice inputs, ::rust::Box<::concrete_optimizer::Weights> weights) noexcept; ::concrete_optimizer::dag::OperatorIndex add_levelled_op(::rust::Slice inputs, double lwe_dim_cost_factor, double fixed_cost, double manp, ::rust::Slice out_shape, ::rust::Str comment) noexcept; ~OperationDag() = delete; diff --git a/concrete-optimizer/src/dag/operator/operator.rs b/concrete-optimizer/src/dag/operator/operator.rs index bb2d974ab..6f126b05c 100644 --- a/concrete-optimizer/src/dag/operator/operator.rs +++ b/concrete-optimizer/src/dag/operator/operator.rs @@ -63,7 +63,7 @@ pub enum Operator OperatorIndex { + pub fn add_lut( + &mut self, + input: OperatorIndex, + table: FunctionTable, + out_precision: Precision, + ) -> OperatorIndex { self.add_operator(Operator::Lut { input, table, + out_precision, extra_data: (), }) } @@ -100,14 +106,14 @@ mod tests { let cpx_add = LevelledComplexity::ADDITION; let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); - let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 1); let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::vector(2), "concat"); let dot = graph.add_dot([concat], [1, 2]); - let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN); + let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2); let ops_index = [input1, input2, sum1, lut1, concat, dot, lut2]; for (expected_i, op_index) in ops_index.iter().enumerate() { @@ -138,6 +144,7 @@ mod tests { Operator::Lut { input: sum1, table: FunctionTable::UNKWOWN, + out_precision: 1, extra_data: () }, Operator::LevelledOp { @@ -159,6 +166,7 @@ mod tests { Operator::Lut { input: dot, table: FunctionTable::UNKWOWN, + out_precision: 2, extra_data: () } ] diff --git a/concrete-optimizer/src/global_parameters.rs b/concrete-optimizer/src/global_parameters.rs index 835f084fd..3367796ce 100644 --- a/concrete-optimizer/src/global_parameters.rs +++ b/concrete-optimizer/src/global_parameters.rs @@ -116,9 +116,15 @@ fn convert_maximal(op: UnparameterizedOperator) -> OperatorParameterIndexed { lwe_dimension_index: external_glwe_index, }, }, - Operator::Lut { input, table, .. } => Operator::Lut { + Operator::Lut { input, table, + out_precision, + .. + } => Operator::Lut { + input, + table, + out_precision, extra_data: LutParametersIndexed { input_lwe_dimension_index: external_glwe_index, ks_decomposition_parameter_index: ks_decomposition_index, @@ -257,13 +263,13 @@ mod tests { let cpx_add = LevelledComplexity::ADDITION; let sum1 = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::number(), "sum"); - let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(sum1, FunctionTable::UNKWOWN, 2); let concat = graph.add_levelled_op([input1, lut1], cpx_add, 1.0, Shape::number(), "concat"); let dot = graph.add_dot([concat], [1, 2]); - let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN); + let lut2 = graph.add_lut(dot, FunctionTable::UNKWOWN, 2); let graph_params = maximal_unify(graph); @@ -337,7 +343,7 @@ mod tests { let concat = graph.add_levelled_op([input1, input2], cpx_add, 1.0, Shape::vector(2), "concat"); - let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(concat, FunctionTable::UNKWOWN, 2); let graph_params = maximal_unify(graph); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs index 0f6a1348e..532f1df27 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/analyze.rs @@ -189,11 +189,10 @@ fn out_shapes(dag: &unparametrized::OperationDag) -> Vec { fn out_precision( op: &unparametrized::UnparameterizedOperator, - out_precisions: &mut [Precision], + out_precisions: &[Precision], ) -> Precision { match op { - Op::Input { out_precision, .. } => *out_precision, - Op::Lut { input, .. } => out_precisions[input.i], + Op::Input { out_precision, .. } | Op::Lut { out_precision, .. } => *out_precision, Op::Dot { inputs, .. } | Op::LevelledOp { inputs, .. } => out_precisions[inputs[0].i], } } @@ -202,7 +201,7 @@ fn out_precisions(dag: &unparametrized::OperationDag) -> Vec { let nb_ops = dag.operators.len(); let mut out_precisions = Vec::::with_capacity(nb_ops); for op in &dag.operators { - let precision = out_precision(op, &mut out_precisions); + let precision = out_precision(op, &out_precisions); out_precisions.push(precision); } out_precisions @@ -612,7 +611,7 @@ mod tests { fn test_1_lut() { let mut graph = unparametrized::OperationDag::new(); let input1 = graph.add_input(8, Shape::number()); - let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 8); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -695,9 +694,9 @@ mod tests { let input1 = graph.add_input(1, Shape::vector(2)); let weights = &Weights::vector([1, 2]); let dot1 = graph.add_dot([input1], weights); - let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); let dot2 = graph.add_dot([lut1, lut1], weights); - let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN); + let lut2 = graph.add_lut(dot2, FunctionTable::UNKWOWN, 1); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -742,10 +741,10 @@ mod tests { fn test_lut_dot_mixed_lut() { let mut graph = unparametrized::OperationDag::new(); let input1 = graph.add_input(1, Shape::number()); - let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN); + let lut1 = graph.add_lut(input1, FunctionTable::UNKWOWN, 1); let weights = &Weights::vector([2, 3]); let dot1 = graph.add_dot([input1, lut1], weights); - let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN); + let _lut2 = graph.add_lut(dot1, FunctionTable::UNKWOWN, 1); let analysis = analyze(&graph); let one_lut_cost = 100.0; let lwe_dim = 1024; @@ -768,15 +767,16 @@ mod tests { #[test] fn test_multi_precision_input() { let mut graph = unparametrized::OperationDag::new(); - let max_precision = 5_usize; + let max_precision: Precision = 5; for i in 1..=max_precision { let _ = graph.add_input(i as u8, Shape::number()); } let analysis = analyze(&graph); - assert!(analysis.constraints_by_precisions.len() == max_precision); + assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() { - assert_eq!(ns.precision, (max_precision - i) as u8); + let i_prec = i as Precision; + assert_eq!(ns.precision, max_precision - i_prec); assert_f64_eq(ns.pareto_output[0].input_coeff, 1.0); assert!(prev_safe_noise_bound < ns.safe_variance_bound); prev_safe_noise_bound = ns.safe_variance_bound; @@ -786,16 +786,17 @@ mod tests { #[test] fn test_multi_precision_lut() { let mut graph = unparametrized::OperationDag::new(); - let max_precision = 5_usize; - for i in 1..=max_precision { - let input = graph.add_input(i as u8, Shape::number()); - let _lut = graph.add_lut(input, FunctionTable::UNKWOWN); + let max_precision: Precision = 5; + for p in 1..=max_precision { + let input = graph.add_input(p, Shape::number()); + let _lut = graph.add_lut(input, FunctionTable::UNKWOWN, p); } let analysis = analyze(&graph); - assert!(analysis.constraints_by_precisions.len() == max_precision); + assert!(analysis.constraints_by_precisions.len() == max_precision as usize); let mut prev_safe_noise_bound = 0.0; for (i, ns) in analysis.constraints_by_precisions.iter().enumerate() { - assert_eq!(ns.precision, (max_precision - i) as u8); + let i_prec = i as Precision; + assert_eq!(ns.precision, max_precision - i_prec); assert_eq!(ns.pareto_output.len(), 1); assert_eq!(ns.pareto_in_lut.len(), 1); assert_f64_eq(ns.pareto_output[0].input_coeff, 0.0); diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs index a8887a611..399d1226c 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize.rs @@ -1,7 +1,7 @@ use concrete_commons::dispersion::DispersionParameter; use concrete_commons::numeric::UnsignedInteger; -use crate::dag::operator::LevelledComplexity; +use crate::dag::operator::{LevelledComplexity, Precision}; use crate::dag::unparametrized; use crate::noise_estimator::error; use crate::noise_estimator::operators::atomic_pattern as noise_atomic_pattern; @@ -300,11 +300,12 @@ pub fn optimize_v0( let complexity = LevelledComplexity::ADDITION * sum_size; let comment = "dot"; let mut dag = unparametrized::OperationDag::new(); - let input1 = dag.add_input(precision as u8, out_shape); + let precision = precision as Precision; + let input1 = dag.add_input(precision, out_shape); let dot1 = dag.add_levelled_op([input1], complexity, same_scale_manp, out_shape, comment); - let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN); + let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_levelled_op([lut1], complexity, manp, out_shape, comment); - let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN); + let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); let mut state = optimize::( &dag, security_level, @@ -456,14 +457,14 @@ mod tests { } } - fn v0_parameter_ref_with_dot(precision: u64, weight: u64) { + fn v0_parameter_ref_with_dot(precision: Precision, weight: u64) { let mut dag = unparametrized::OperationDag::new(); { - let input1 = dag.add_input(precision as u8, Shape::number()); + let input1 = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input1], [1]); - let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN); + let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_dot([lut1], [weight]); - let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN); + let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } { let dag2 = analyze::analyze(&dag, &CONFIG); @@ -488,7 +489,7 @@ mod tests { let state = optimize(&dag); let state_ref = atomic_pattern::optimize_one::( 1, - precision, + precision as u64, security_level, weight as f64, maximum_acceptable_error_probability, @@ -510,10 +511,10 @@ mod tests { assert!(sol.assert_same(sol_ref)); } - fn no_lut_vs_lut(precision: u64) { + fn no_lut_vs_lut(precision: Precision) { let mut dag_lut = unparametrized::OperationDag::new(); let input1 = dag_lut.add_input(precision as u8, Shape::number()); - let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN); + let _lut1 = dag_lut.add_lut(input1, FunctionTable::UNKWOWN, precision); let mut dag_no_lut = unparametrized::OperationDag::new(); let _input2 = dag_no_lut.add_input(precision as u8, Shape::number()); @@ -540,23 +541,26 @@ mod tests { } } - fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise(precision: u64, weight: u64) { + fn lut_with_input_base_noise_better_than_lut_with_lut_base_noise( + precision: Precision, + weight: u64, + ) { let weight = &Weights::number(weight); let mut dag_1 = unparametrized::OperationDag::new(); { let input1 = dag_1.add_input(precision as u8, Shape::number()); let scaled_input1 = dag_1.add_dot([input1], weight); - let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN); - let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN); + let lut1 = dag_1.add_lut(scaled_input1, FunctionTable::UNKWOWN, precision); + let _lut2 = dag_1.add_lut(lut1, FunctionTable::UNKWOWN, precision); } let mut dag_2 = unparametrized::OperationDag::new(); { let input1 = dag_2.add_input(precision as u8, Shape::number()); - let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN); + let lut1 = dag_2.add_lut(input1, FunctionTable::UNKWOWN, precision); let scaled_lut1 = dag_2.add_dot([lut1], weight); - let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN); + let _lut2 = dag_2.add_lut(scaled_lut1, FunctionTable::UNKWOWN, precision); } let state_1 = optimize(&dag_1); @@ -581,12 +585,12 @@ mod tests { } } - fn circuit(dag: &mut unparametrized::OperationDag, precision: u8, weight: u64) { + fn circuit(dag: &mut unparametrized::OperationDag, precision: Precision, weight: u64) { let input = dag.add_input(precision, Shape::number()); let dot1 = dag.add_dot([input], [weight]); - let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN); + let lut1 = dag.add_lut(dot1, FunctionTable::UNKWOWN, precision); let dot2 = dag.add_dot([lut1], [weight]); - let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN); + let _lut2 = dag.add_lut(dot2, FunctionTable::UNKWOWN, precision); } fn assert_multi_precision_dominate_single(weight: u64) -> Option {