feat: option to force optimization with a particular encoding

This commit is contained in:
rudy
2022-12-02 11:56:10 +01:00
committed by rudy-6-4
parent 0d7cb97e7e
commit 34764edf50
6 changed files with 115 additions and 25 deletions

View File

@@ -37,7 +37,7 @@ $(INTERFACE_LIB): $(INTERFACE_LIB_ORIG)
TESTS_SOURCES = tests/src/main.cpp
TEST_DEP_LIBS = -l pthread -ldl
tests/tests_exe: $(INTERFACE_LIB) $(INTERFACE_HEADER) $(INTERFACE_CPP) $(TESTS_SOURCES)
g++ -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS)
g++ -Wall -Werror -Wextra -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS)
chmod +x $@
test: tests/tests_exe

View File

@@ -6,7 +6,9 @@ use concrete_optimizer::dag::operator::{
};
use concrete_optimizer::dag::unparametrized;
use concrete_optimizer::optimization::config::{Config, SearchSpace};
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution;
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{
Encoding, Solution as DagSolution,
};
use concrete_optimizer::optimization::decomposition;
fn no_solution() -> ffi::Solution {
@@ -248,11 +250,12 @@ impl OperationDag {
processing_unit,
Some(ProcessingUnit::Cpu.complexity_model()),
);
let encoding = options.encoding.into();
let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize(
&self.0,
config,
&search_space,
encoding,
options.default_log_norm2_woppbs,
&cache,
);
@@ -283,6 +286,18 @@ impl Into<OperatorIndex> for ffi::OperatorIndex {
}
}
#[allow(clippy::from_over_into)]
impl Into<Encoding> for ffi::Encoding {
fn into(self) -> Encoding {
match self {
Self::Auto => Encoding::Auto,
Self::Native => Encoding::Native,
Self::Crt => Encoding::Crt,
_ => unreachable!("Internal error: Invalid encoding"),
}
}
}
#[allow(unused_must_use)]
#[cxx::bridge]
mod ffi {
@@ -342,6 +357,14 @@ mod ffi {
fn vector(weights: &[i64]) -> Box<Weights>;
}
#[derive(Debug, Clone, Copy)]
#[namespace = "concrete_optimizer"]
pub enum Encoding {
Auto,
Native,
Crt,
}
#[derive(Clone, Copy)]
#[namespace = "concrete_optimizer::dag"]
struct OperatorIndex {
@@ -386,12 +409,13 @@ mod ffi {
}
#[namespace = "concrete_optimizer"]
#[derive(Debug, Clone, Copy, Default)]
#[derive(Debug, Clone, Copy)]
pub struct Options {
pub security_level: u64,
pub maximum_acceptable_error_probability: f64,
pub default_log_norm2_woppbs: f64,
pub use_gpu_constraints: bool,
pub encoding: Encoding,
}
}

View File

@@ -941,6 +941,7 @@ union MaybeUninit {
namespace concrete_optimizer {
struct OperationDag;
struct Weights;
enum class Encoding : ::std::uint8_t;
struct Options;
namespace dag {
struct OperatorIndex;
@@ -987,6 +988,15 @@ private:
};
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights
#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
enum class Encoding : ::std::uint8_t {
Auto = 0,
Native = 1,
Crt = 2,
};
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
namespace dag {
#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
@@ -1052,6 +1062,7 @@ struct Options final {
double maximum_acceptable_error_probability;
double default_log_norm2_woppbs;
bool use_gpu_constraints;
::concrete_optimizer::Encoding encoding;
using IsRelocatable = ::std::true_type;
};

View File

@@ -922,6 +922,7 @@ std::size_t align_of() {
namespace concrete_optimizer {
struct OperationDag;
struct Weights;
enum class Encoding : ::std::uint8_t;
struct Options;
namespace dag {
struct OperatorIndex;
@@ -968,6 +969,15 @@ private:
};
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights
#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
enum class Encoding : ::std::uint8_t {
Auto = 0,
Native = 1,
Crt = 2,
};
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
namespace dag {
#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
@@ -1033,6 +1043,7 @@ struct Options final {
double maximum_acceptable_error_probability;
double default_log_norm2_woppbs;
bool use_gpu_constraints;
::concrete_optimizer::Encoding encoding;
using IsRelocatable = ::std::true_type;
};

View File

@@ -18,11 +18,13 @@ const double WOP_FALLBACK_LOG_NORM = 8;
const double NOISE_DEVIATION_COEFF = 1.0;
concrete_optimizer::Options default_options() {
concrete_optimizer::Options options;
options.security_level = SECURITY_128B;
options.maximum_acceptable_error_probability = P_ERROR;
options.use_gpu_constraints = false;
return options;
return concrete_optimizer::Options {
.security_level = SECURITY_128B,
.maximum_acceptable_error_probability = P_ERROR,
.default_log_norm2_woppbs = WOP_FALLBACK_LOG_NORM,
.use_gpu_constraints = false,
.encoding = concrete_optimizer::Encoding::Auto
};
}
void test_v0() {
@@ -48,8 +50,7 @@ void test_dag_no_lut() {
rust::cxxbridge1::Box<concrete_optimizer::Weights> weights =
concrete_optimizer::weights::vector(slice(weight_vec));
concrete_optimizer::dag::OperatorIndex node2 =
dag->add_dot(slice(inputs), std::move(weights));
dag->add_dot(slice(inputs), std::move(weights));
auto solution = dag->optimize_v0(default_options());
assert(solution.glwe_polynomial_size == 256);
@@ -64,8 +65,7 @@ void test_dag_lut() {
dag->add_input(PRECISION_8B, slice(shape));
std::vector<u_int64_t> table = {};
concrete_optimizer::dag::OperatorIndex node2 =
dag->add_lut(input, slice(table), PRECISION_8B);
dag->add_lut(input, slice(table), PRECISION_8B);
auto solution = dag->optimize(default_options());
assert(solution.glwe_dimension == 1);
@@ -82,8 +82,7 @@ void test_dag_lut_wop() {
dag->add_input(PRECISION_16B, slice(shape));
std::vector<u_int64_t> table = {};
concrete_optimizer::dag::OperatorIndex node2 =
dag->add_lut(input, slice(table), PRECISION_16B);
dag->add_lut(input, slice(table), PRECISION_16B);
auto solution = dag->optimize(default_options());
assert(solution.glwe_dimension == 2);
@@ -91,11 +90,29 @@ void test_dag_lut_wop() {
assert(solution.use_wop_pbs);
}
int main(int argc, char *argv[]) {
void test_dag_lut_force_wop() {
auto dag = concrete_optimizer::dag::empty();
std::vector<uint64_t> shape = {3};
concrete_optimizer::dag::OperatorIndex input =
dag->add_input(PRECISION_8B, slice(shape));
std::vector<u_int64_t> table = {};
dag->add_lut(input, slice(table), PRECISION_8B);
auto options = default_options();
options.encoding = concrete_optimizer::Encoding::Crt;
auto solution = dag->optimize(options);
assert(solution.use_wop_pbs);
}
int main() {
test_v0();
test_dag_no_lut();
test_dag_lut();
test_dag_lut_wop();
test_dag_lut_force_wop();
return 0;
}

View File

@@ -13,6 +13,13 @@ pub enum Solution {
WopSolution(WopSolution),
}
#[derive(Clone, Copy)]
pub enum Encoding {
Auto,
Native,
Crt,
}
fn max_precision(dag: &OperationDag) -> Precision {
dag.out_precisions.iter().copied().max().unwrap_or(0)
}
@@ -26,22 +33,42 @@ fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
}
}
pub fn optimize(
fn optimize_with_wop_pbs(
dag: &OperationDag,
config: Config,
search_space: &SearchSpace,
default_log_norm2_woppbs: f64,
caches: &PersistDecompCaches,
) -> Option<Solution> {
let opt_sol = optimize::optimize(dag, config, search_space, caches).best_solution;
if opt_sol.is_some() {
return opt_sol.map(Solution::WpSolution);
}
) -> Option<WopSolution> {
let max_precision = max_precision(dag);
let nb_luts = analyze::lut_count_from_dag(dag);
let worst_log_norm = analyze::worst_log_norm(dag);
let log_norm = default_log_norm2_woppbs.min(worst_log_norm);
let opt_sol =
wop_optimize(max_precision as u64, config, log_norm, search_space, caches).best_solution;
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
wop_optimize(max_precision as u64, config, log_norm, search_space, caches)
.best_solution
.map(|sol| updated_global_p_error(nb_luts, sol))
}
pub fn optimize(
dag: &OperationDag,
config: Config,
search_space: &SearchSpace,
encoding: Encoding,
default_log_norm2_woppbs: f64,
caches: &PersistDecompCaches,
) -> Option<Solution> {
let native = || {
optimize::optimize(dag, config, search_space, caches)
.best_solution
.map(Solution::WpSolution)
};
let crt = || {
optimize_with_wop_pbs(dag, config, search_space, default_log_norm2_woppbs, caches)
.map(Solution::WopSolution)
};
match encoding {
Encoding::Auto => native().or_else(crt),
Encoding::Native => native(),
Encoding::Crt => crt(),
}
}