mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat: option to force optimization with a particular encoding
This commit is contained in:
@@ -37,7 +37,7 @@ $(INTERFACE_LIB): $(INTERFACE_LIB_ORIG)
|
||||
TESTS_SOURCES = tests/src/main.cpp
|
||||
TEST_DEP_LIBS = -l pthread -ldl
|
||||
tests/tests_exe: $(INTERFACE_LIB) $(INTERFACE_HEADER) $(INTERFACE_CPP) $(TESTS_SOURCES)
|
||||
g++ -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS)
|
||||
g++ -Wall -Werror -Wextra -o $@ $(TESTS_SOURCES) $(INTERFACE_CPP) $(INTERFACE_LIB) -I $(shell dirname $(INTERFACE_HEADER)) $(TEST_DEP_LIBS)
|
||||
chmod +x $@
|
||||
|
||||
test: tests/tests_exe
|
||||
|
||||
@@ -6,7 +6,9 @@ use concrete_optimizer::dag::operator::{
|
||||
};
|
||||
use concrete_optimizer::dag::unparametrized;
|
||||
use concrete_optimizer::optimization::config::{Config, SearchSpace};
|
||||
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::Solution as DagSolution;
|
||||
use concrete_optimizer::optimization::dag::solo_key::optimize_generic::{
|
||||
Encoding, Solution as DagSolution,
|
||||
};
|
||||
use concrete_optimizer::optimization::decomposition;
|
||||
|
||||
fn no_solution() -> ffi::Solution {
|
||||
@@ -248,11 +250,12 @@ impl OperationDag {
|
||||
processing_unit,
|
||||
Some(ProcessingUnit::Cpu.complexity_model()),
|
||||
);
|
||||
|
||||
let encoding = options.encoding.into();
|
||||
let result = concrete_optimizer::optimization::dag::solo_key::optimize_generic::optimize(
|
||||
&self.0,
|
||||
config,
|
||||
&search_space,
|
||||
encoding,
|
||||
options.default_log_norm2_woppbs,
|
||||
&cache,
|
||||
);
|
||||
@@ -283,6 +286,18 @@ impl Into<OperatorIndex> for ffi::OperatorIndex {
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::from_over_into)]
|
||||
impl Into<Encoding> for ffi::Encoding {
|
||||
fn into(self) -> Encoding {
|
||||
match self {
|
||||
Self::Auto => Encoding::Auto,
|
||||
Self::Native => Encoding::Native,
|
||||
Self::Crt => Encoding::Crt,
|
||||
_ => unreachable!("Internal error: Invalid encoding"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_must_use)]
|
||||
#[cxx::bridge]
|
||||
mod ffi {
|
||||
@@ -342,6 +357,14 @@ mod ffi {
|
||||
fn vector(weights: &[i64]) -> Box<Weights>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
#[namespace = "concrete_optimizer"]
|
||||
pub enum Encoding {
|
||||
Auto,
|
||||
Native,
|
||||
Crt,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
#[namespace = "concrete_optimizer::dag"]
|
||||
struct OperatorIndex {
|
||||
@@ -386,12 +409,13 @@ mod ffi {
|
||||
}
|
||||
|
||||
#[namespace = "concrete_optimizer"]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct Options {
|
||||
pub security_level: u64,
|
||||
pub maximum_acceptable_error_probability: f64,
|
||||
pub default_log_norm2_woppbs: f64,
|
||||
pub use_gpu_constraints: bool,
|
||||
pub encoding: Encoding,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -941,6 +941,7 @@ union MaybeUninit {
|
||||
namespace concrete_optimizer {
|
||||
struct OperationDag;
|
||||
struct Weights;
|
||||
enum class Encoding : ::std::uint8_t;
|
||||
struct Options;
|
||||
namespace dag {
|
||||
struct OperatorIndex;
|
||||
@@ -987,6 +988,15 @@ private:
|
||||
};
|
||||
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights
|
||||
|
||||
#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
enum class Encoding : ::std::uint8_t {
|
||||
Auto = 0,
|
||||
Native = 1,
|
||||
Crt = 2,
|
||||
};
|
||||
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
|
||||
namespace dag {
|
||||
#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
|
||||
#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
|
||||
@@ -1052,6 +1062,7 @@ struct Options final {
|
||||
double maximum_acceptable_error_probability;
|
||||
double default_log_norm2_woppbs;
|
||||
bool use_gpu_constraints;
|
||||
::concrete_optimizer::Encoding encoding;
|
||||
|
||||
using IsRelocatable = ::std::true_type;
|
||||
};
|
||||
|
||||
@@ -922,6 +922,7 @@ std::size_t align_of() {
|
||||
namespace concrete_optimizer {
|
||||
struct OperationDag;
|
||||
struct Weights;
|
||||
enum class Encoding : ::std::uint8_t;
|
||||
struct Options;
|
||||
namespace dag {
|
||||
struct OperatorIndex;
|
||||
@@ -968,6 +969,15 @@ private:
|
||||
};
|
||||
#endif // CXXBRIDGE1_STRUCT_concrete_optimizer$Weights
|
||||
|
||||
#ifndef CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
#define CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
enum class Encoding : ::std::uint8_t {
|
||||
Auto = 0,
|
||||
Native = 1,
|
||||
Crt = 2,
|
||||
};
|
||||
#endif // CXXBRIDGE1_ENUM_concrete_optimizer$Encoding
|
||||
|
||||
namespace dag {
|
||||
#ifndef CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
|
||||
#define CXXBRIDGE1_STRUCT_concrete_optimizer$dag$OperatorIndex
|
||||
@@ -1033,6 +1043,7 @@ struct Options final {
|
||||
double maximum_acceptable_error_probability;
|
||||
double default_log_norm2_woppbs;
|
||||
bool use_gpu_constraints;
|
||||
::concrete_optimizer::Encoding encoding;
|
||||
|
||||
using IsRelocatable = ::std::true_type;
|
||||
};
|
||||
|
||||
@@ -18,11 +18,13 @@ const double WOP_FALLBACK_LOG_NORM = 8;
|
||||
const double NOISE_DEVIATION_COEFF = 1.0;
|
||||
|
||||
concrete_optimizer::Options default_options() {
|
||||
concrete_optimizer::Options options;
|
||||
options.security_level = SECURITY_128B;
|
||||
options.maximum_acceptable_error_probability = P_ERROR;
|
||||
options.use_gpu_constraints = false;
|
||||
return options;
|
||||
return concrete_optimizer::Options {
|
||||
.security_level = SECURITY_128B,
|
||||
.maximum_acceptable_error_probability = P_ERROR,
|
||||
.default_log_norm2_woppbs = WOP_FALLBACK_LOG_NORM,
|
||||
.use_gpu_constraints = false,
|
||||
.encoding = concrete_optimizer::Encoding::Auto
|
||||
};
|
||||
}
|
||||
|
||||
void test_v0() {
|
||||
@@ -48,8 +50,7 @@ void test_dag_no_lut() {
|
||||
rust::cxxbridge1::Box<concrete_optimizer::Weights> weights =
|
||||
concrete_optimizer::weights::vector(slice(weight_vec));
|
||||
|
||||
concrete_optimizer::dag::OperatorIndex node2 =
|
||||
dag->add_dot(slice(inputs), std::move(weights));
|
||||
dag->add_dot(slice(inputs), std::move(weights));
|
||||
|
||||
auto solution = dag->optimize_v0(default_options());
|
||||
assert(solution.glwe_polynomial_size == 256);
|
||||
@@ -64,8 +65,7 @@ void test_dag_lut() {
|
||||
dag->add_input(PRECISION_8B, slice(shape));
|
||||
|
||||
std::vector<u_int64_t> table = {};
|
||||
concrete_optimizer::dag::OperatorIndex node2 =
|
||||
dag->add_lut(input, slice(table), PRECISION_8B);
|
||||
dag->add_lut(input, slice(table), PRECISION_8B);
|
||||
|
||||
auto solution = dag->optimize(default_options());
|
||||
assert(solution.glwe_dimension == 1);
|
||||
@@ -82,8 +82,7 @@ void test_dag_lut_wop() {
|
||||
dag->add_input(PRECISION_16B, slice(shape));
|
||||
|
||||
std::vector<u_int64_t> table = {};
|
||||
concrete_optimizer::dag::OperatorIndex node2 =
|
||||
dag->add_lut(input, slice(table), PRECISION_16B);
|
||||
dag->add_lut(input, slice(table), PRECISION_16B);
|
||||
|
||||
auto solution = dag->optimize(default_options());
|
||||
assert(solution.glwe_dimension == 2);
|
||||
@@ -91,11 +90,29 @@ void test_dag_lut_wop() {
|
||||
assert(solution.use_wop_pbs);
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
void test_dag_lut_force_wop() {
|
||||
auto dag = concrete_optimizer::dag::empty();
|
||||
|
||||
std::vector<uint64_t> shape = {3};
|
||||
|
||||
concrete_optimizer::dag::OperatorIndex input =
|
||||
dag->add_input(PRECISION_8B, slice(shape));
|
||||
|
||||
std::vector<u_int64_t> table = {};
|
||||
dag->add_lut(input, slice(table), PRECISION_8B);
|
||||
|
||||
auto options = default_options();
|
||||
options.encoding = concrete_optimizer::Encoding::Crt;
|
||||
auto solution = dag->optimize(options);
|
||||
assert(solution.use_wop_pbs);
|
||||
}
|
||||
|
||||
int main() {
|
||||
test_v0();
|
||||
test_dag_no_lut();
|
||||
test_dag_lut();
|
||||
test_dag_lut_wop();
|
||||
test_dag_lut_force_wop();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
@@ -13,6 +13,13 @@ pub enum Solution {
|
||||
WopSolution(WopSolution),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Encoding {
|
||||
Auto,
|
||||
Native,
|
||||
Crt,
|
||||
}
|
||||
|
||||
fn max_precision(dag: &OperationDag) -> Precision {
|
||||
dag.out_precisions.iter().copied().max().unwrap_or(0)
|
||||
}
|
||||
@@ -26,22 +33,42 @@ fn updated_global_p_error(nb_luts: u64, sol: WopSolution) -> WopSolution {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn optimize(
|
||||
fn optimize_with_wop_pbs(
|
||||
dag: &OperationDag,
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
default_log_norm2_woppbs: f64,
|
||||
caches: &PersistDecompCaches,
|
||||
) -> Option<Solution> {
|
||||
let opt_sol = optimize::optimize(dag, config, search_space, caches).best_solution;
|
||||
if opt_sol.is_some() {
|
||||
return opt_sol.map(Solution::WpSolution);
|
||||
}
|
||||
) -> Option<WopSolution> {
|
||||
let max_precision = max_precision(dag);
|
||||
let nb_luts = analyze::lut_count_from_dag(dag);
|
||||
let worst_log_norm = analyze::worst_log_norm(dag);
|
||||
let log_norm = default_log_norm2_woppbs.min(worst_log_norm);
|
||||
let opt_sol =
|
||||
wop_optimize(max_precision as u64, config, log_norm, search_space, caches).best_solution;
|
||||
opt_sol.map(|sol| Solution::WopSolution(updated_global_p_error(nb_luts, sol)))
|
||||
wop_optimize(max_precision as u64, config, log_norm, search_space, caches)
|
||||
.best_solution
|
||||
.map(|sol| updated_global_p_error(nb_luts, sol))
|
||||
}
|
||||
|
||||
pub fn optimize(
|
||||
dag: &OperationDag,
|
||||
config: Config,
|
||||
search_space: &SearchSpace,
|
||||
encoding: Encoding,
|
||||
default_log_norm2_woppbs: f64,
|
||||
caches: &PersistDecompCaches,
|
||||
) -> Option<Solution> {
|
||||
let native = || {
|
||||
optimize::optimize(dag, config, search_space, caches)
|
||||
.best_solution
|
||||
.map(Solution::WpSolution)
|
||||
};
|
||||
let crt = || {
|
||||
optimize_with_wop_pbs(dag, config, search_space, default_log_norm2_woppbs, caches)
|
||||
.map(Solution::WopSolution)
|
||||
};
|
||||
match encoding {
|
||||
Encoding::Auto => native().or_else(crt),
|
||||
Encoding::Native => native(),
|
||||
Encoding::Crt => crt(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user