From 34764edf50167122ab27759c3ce6d257007b80fa Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 2 Dec 2022 11:56:10 +0100 Subject: [PATCH] feat: option to force optimization with a particular encoding --- concrete-optimizer-cpp/Makefile | 2 +- .../src/concrete-optimizer.rs | 30 +++++++++++-- .../src/cpp/concrete-optimizer.cpp | 11 +++++ .../src/cpp/concrete-optimizer.hpp | 11 +++++ concrete-optimizer-cpp/tests/src/main.cpp | 41 ++++++++++++----- .../dag/solo_key/optimize_generic.rs | 45 +++++++++++++++---- 6 files changed, 115 insertions(+), 25 deletions(-) diff --git a/concrete-optimizer-cpp/Makefile b/concrete-optimizer-cpp/Makefile index 6756dcd0c..82b6b2cc6 100644 --- a/concrete-optimizer-cpp/Makefile +++ b/concrete-optimizer-cpp/Makefile @@ -37,7 +37,7 @@ $(INTERFACE_LIB): $(INTERFACE_LIB_ORIG) TESTS_SOURCES = tests/src/main.cpp TEST_DEP_LIBS = -l pthread -ldl tests/tests_exe: $(INTERFACE_LIB) $(INTERFACE_HEADER) $(INTERFACE_CPP) $(TESTS_SOURCES) - g++ -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS) + g++ -Wall -Werror -Wextra -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS) chmod +x $@ test: tests/tests_exe diff --git a/concrete-optimizer-cpp/src/concrete-optimizer.rs b/concrete-optimizer-cpp/src/concrete-optimizer.rs index e8f1966f0..cf768481e 100644 --- a/concrete-optimizer-cpp/src/concrete-optimizer.rs +++ b/concrete-optimizer-cpp/src/concrete-optimizer.rs @@ -6,7 +6,9 @@ use concrete_optimizer::dag::operator::{ }; use concrete_optimizer::dag::unparametrized; use concrete_optimizer::optimization::config::{Config, SearchSpace}; -use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution; +use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{ + Encoding, Solution as DagSolution, +}; use concrete_optimizer::optimization::decomposition; fn no_solution() -> ffi::Solution { @@ -248,11 +250,12 @@ impl OperationDag { processing_unit, Some(ProcessingUnit::Cpu.complexity_model()), ); - + let encoding = options.encoding.into(); let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize( &self.0, config, &search_space, + encoding, options.default_log_norm2_woppbs, &cache, ); @@ -283,6 +286,18 @@ impl Into for ffi::OperatorIndex { } } +#[allow(clippy::from_over_into)] +impl Into for ffi::Encoding { + fn into(self) -> Encoding { + match self { + Self::Auto => Encoding::Auto, + Self::Native => Encoding::Native, + Self::Crt => Encoding::Crt, + _ => unreachable!("Internal error: Invalid encoding"), + } + } +} + #[allow(unused_must_use)] #[cxx::bridge] mod ffi { @@ -342,6 +357,14 @@ mod ffi { fn vector(weights: &[i64]) -> Box; } + #[derive(Debug, Clone, Copy)] + #[namespace = "concrete_optimizer"] + pub enum Encoding { + Auto, + Native, + Crt, + } + #[derive(Clone, Copy)] #[namespace = "concrete_optimizer::dag"] struct OperatorIndex { @@ -386,12 +409,13 @@ mod ffi { } #[namespace = "concrete_optimizer"] - #[derive(Debug, Clone, Copy, Default)] + #[derive(Debug, Clone, Copy)] pub struct Options { pub security_level: u64, pub maximum_acceptable_error_probability: f64, pub default_log_norm2_woppbs: f64, pub use_gpu_constraints: bool, + pub encoding: Encoding, } } diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp index 8152f0250..a98abce1c 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.cpp @@ -941,6 +941,7 @@ union MaybeUninit { namespace concrete_optimizer { struct OperationDag; struct Weights; + enum class Encoding : ::std::uint8_t; struct Options; namespace dag { struct OperatorIndex; @@ -987,6 +988,15 @@ private: }; #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding +#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding +enum class Encoding : ::std::uint8_t { + Auto = 0, + Native = 1, + Crt = 2, +}; +#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding + namespace dag { #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex #define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex @@ -1052,6 +1062,7 @@ struct Options final { double maximum_acceptable_error_probability; double default_log_norm2_woppbs; bool use_gpu_constraints; + ::concrete_optimizer::Encoding encoding; using IsRelocatable = ::std::true_type; }; diff --git a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp index 654940f06..7c5a7fa1b 100644 --- a/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp +++ b/concrete-optimizer-cpp/src/cpp/concrete-optimizer.hpp @@ -922,6 +922,7 @@ std::size_t align_of() { namespace concrete_optimizer { struct OperationDag; struct Weights; + enum class Encoding : ::std::uint8_t; struct Options; namespace dag { struct OperatorIndex; @@ -968,6 +969,15 @@ private: }; #endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights +#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding +#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding +enum class Encoding : ::std::uint8_t { + Auto = 0, + Native = 1, + Crt = 2, +}; +#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding + namespace dag { #ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex #define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex @@ -1033,6 +1043,7 @@ struct Options final { double maximum_acceptable_error_probability; double default_log_norm2_woppbs; bool use_gpu_constraints; + ::concrete_optimizer::Encoding encoding; using IsRelocatable = ::std::true_type; }; diff --git a/concrete-optimizer-cpp/tests/src/main.cpp b/concrete-optimizer-cpp/tests/src/main.cpp index 4f7601c9f..c49d942cf 100644 --- a/concrete-optimizer-cpp/tests/src/main.cpp +++ b/concrete-optimizer-cpp/tests/src/main.cpp @@ -18,11 +18,13 @@ const double WOP_FALLBACK_LOG_NORM = 8; const double NOISE_DEVIATION_COEFF = 1.0; concrete_optimizer::Options default_options() { - concrete_optimizer::Options options; - options.security_level = SECURITY_128B; - options.maximum_acceptable_error_probability = P_ERROR; - options.use_gpu_constraints = false; - return options; + return concrete_optimizer::Options { + .security_level = SECURITY_128B, + .maximum_acceptable_error_probability = P_ERROR, + .default_log_norm2_woppbs = WOP_FALLBACK_LOG_NORM, + .use_gpu_constraints = false, + .encoding = concrete_optimizer::Encoding::Auto + }; } void test_v0() { @@ -48,8 +50,7 @@ void test_dag_no_lut() { rust::cxxbridge1::Box weights = concrete_optimizer::weights::vector(slice(weight_vec)); - concrete_optimizer::dag::OperatorIndex node2 = - dag->add_dot(slice(inputs), std::move(weights)); + dag->add_dot(slice(inputs), std::move(weights)); auto solution = dag->optimize_v0(default_options()); assert(solution.glwe_polynomial_size == 256); @@ -64,8 +65,7 @@ void test_dag_lut() { dag->add_input(PRECISION_8B, slice(shape)); std::vector table = {}; - concrete_optimizer::dag::OperatorIndex node2 = - dag->add_lut(input, slice(table), PRECISION_8B); + dag->add_lut(input, slice(table), PRECISION_8B); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 1); @@ -82,8 +82,7 @@ void test_dag_lut_wop() { dag->add_input(PRECISION_16B, slice(shape)); std::vector table = {}; - concrete_optimizer::dag::OperatorIndex node2 = - dag->add_lut(input, slice(table), PRECISION_16B); + dag->add_lut(input, slice(table), PRECISION_16B); auto solution = dag->optimize(default_options()); assert(solution.glwe_dimension == 2); @@ -91,11 +90,29 @@ void test_dag_lut_wop() { assert(solution.use_wop_pbs); } -int main(int argc, char *argv[]) { +void test_dag_lut_force_wop() { + auto dag = concrete_optimizer::dag::empty(); + + std::vector shape = {3}; + + concrete_optimizer::dag::OperatorIndex input = + dag->add_input(PRECISION_8B, slice(shape)); + + std::vector table = {}; + dag->add_lut(input, slice(table), PRECISION_8B); + + auto options = default_options(); + options.encoding = concrete_optimizer::Encoding::Crt; + auto solution = dag->optimize(options); + assert(solution.use_wop_pbs); +} + +int main() { test_v0(); test_dag_no_lut(); test_dag_lut(); test_dag_lut_wop(); + test_dag_lut_force_wop(); return 0; } diff --git a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs index d285f062a..ad81e4c04 100644 --- a/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs +++ b/concrete-optimizer/src/optimization/dag/solo_key/optimize_generic.rs @@ -13,6 +13,13 @@ pub enum Solution { WopSolution(WopSolution), } +#[derive(Clone, Copy)] +pub enum Encoding { + Auto, + Native, + Crt, +} + fn max_precision(dag: &OperationDag) -> Precision { dag.out_precisions.iter().copied().max().unwrap_or(0) } @@ -26,22 +33,42 @@ fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution { } } -pub fn optimize( +fn optimize_with_wop_pbs( dag: &OperationDag, config: Config, search_space: &SearchSpace, default_log_norm2_woppbs: f64, caches: &PersistDecompCaches, -) -> Option { - let opt_sol = optimize::optimize(dag, config, search_space, caches).best_solution; - if opt_sol.is_some() { - return opt_sol.map(Solution::WpSolution); - } +) -> Option { let max_precision = max_precision(dag); let nb_luts = analyze::lut_count_from_dag(dag); let worst_log_norm = analyze::worst_log_norm(dag); let log_norm = default_log_norm2_woppbs.min(worst_log_norm); - let opt_sol = - wop_optimize(max_precision as u64, config, log_norm, search_space, caches).best_solution; - opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol))) + wop_optimize(max_precision as u64, config, log_norm, search_space, caches) + .best_solution + .map(|sol| updated_global_p_error(nb_luts, sol)) +} + +pub fn optimize( + dag: &OperationDag, + config: Config, + search_space: &SearchSpace, + encoding: Encoding, + default_log_norm2_woppbs: f64, + caches: &PersistDecompCaches, +) -> Option { + let native = || { + optimize::optimize(dag, config, search_space, caches) + .best_solution + .map(Solution::WpSolution) + }; + let crt = || { + optimize_with_wop_pbs(dag, config, search_space, default_log_norm2_woppbs, caches) + .map(Solution::WopSolution) + }; + match encoding { + Encoding::Auto => native().or_else(crt), + Encoding::Native => native(), + Encoding::Crt => crt(), + } }