mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat(optimizer): report or warn using global p-error
This commit is contained in:
@@ -16,19 +16,25 @@ namespace concretelang {
|
||||
|
||||
namespace optimizer {
|
||||
constexpr double P_ERROR_4_SIGMA = 1.0 - 0.999936657516;
|
||||
constexpr double UNSPECIFIED_P_ERROR = NAN; // will use the default p error
|
||||
constexpr double NO_GLOBAL_P_ERROR = NAN; // will fallback on p error
|
||||
constexpr uint DEFAULT_SECURITY = 128;
|
||||
constexpr uint DEFAULT_FALLBACK_LOG_NORM_WOPPBS = 8;
|
||||
constexpr bool DEFAULT_DISPLAY = false;
|
||||
constexpr bool DEFAULT_STARTEGY_V0 = false;
|
||||
|
||||
struct Config {
|
||||
double p_error;
|
||||
double global_p_error;
|
||||
bool display;
|
||||
bool strategy_v0;
|
||||
std::uint64_t security;
|
||||
double fallback_log_norm_woppbs;
|
||||
};
|
||||
constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false, false,
|
||||
DEFAULT_SECURITY,
|
||||
DEFAULT_FALLBACK_LOG_NORM_WOPPBS};
|
||||
|
||||
constexpr Config DEFAULT_CONFIG = {
|
||||
UNSPECIFIED_P_ERROR, NO_GLOBAL_P_ERROR, DEFAULT_DISPLAY,
|
||||
DEFAULT_STARTEGY_V0, DEFAULT_SECURITY, DEFAULT_FALLBACK_LOG_NORM_WOPPBS};
|
||||
|
||||
using Dag = rust::Box<concrete_optimizer::OperationDag>;
|
||||
using Solution = concrete_optimizer::v0::Solution;
|
||||
|
||||
@@ -66,6 +66,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
|
||||
.def("set_strategy_v0",
|
||||
[](CompilationOptions &options, bool strategy_v0) {
|
||||
options.optimizerConfig.strategy_v0 = strategy_v0;
|
||||
})
|
||||
.def("set_global_p_error",
|
||||
[](CompilationOptions &options, double global_p_error) {
|
||||
options.optimizerConfig.global_p_error = global_p_error;
|
||||
});
|
||||
|
||||
pybind11::class_<mlir::concretelang::JitCompilationResult>(
|
||||
|
||||
@@ -135,7 +135,7 @@ class CompilationOptions(WrapperCpp):
|
||||
self.cpp().set_funcname(funcname)
|
||||
|
||||
def set_p_error(self, p_error: float):
|
||||
"""Set global error probability for each pbs.
|
||||
"""Set error probability for shared by each pbs.
|
||||
|
||||
Args:
|
||||
p_error (float): probability of error for each lut
|
||||
@@ -177,3 +177,21 @@ class CompilationOptions(WrapperCpp):
|
||||
if not isinstance(enable, bool):
|
||||
raise TypeError("enable should be a bool")
|
||||
self.cpp().set_strategy_v0(enable)
|
||||
|
||||
def set_global_p_error(self, global_p_error: float):
|
||||
"""Set global error probability for the full circuit.
|
||||
|
||||
Args:
|
||||
global_p_error (float): probability of error for the full circuit
|
||||
|
||||
Raises:
|
||||
TypeError: if the value to set is not float
|
||||
ValueError: if the value to set is not in interval ]0; 1[
|
||||
"""
|
||||
if not isinstance(global_p_error, float):
|
||||
raise TypeError("can't set global_p_error to a non-float value")
|
||||
if global_p_error in (0.0, 1.0):
|
||||
raise ValueError("global_p_error cannot be 0 or 1")
|
||||
if not 0.0 <= global_p_error <= 1.0:
|
||||
raise ValueError("global_p_error be a probability in ]0; 1[")
|
||||
self.cpp().set_global_p_error(global_p_error)
|
||||
|
||||
@@ -34,19 +34,52 @@ optimizer::DagSolution getV0Parameter(V0FHEConstraint constraint,
|
||||
return concrete_optimizer::utils::convert_to_dag_solution(solution);
|
||||
}
|
||||
|
||||
const int MAXIMUM_OPTIMIZER_CALL = 5;
|
||||
optimizer::DagSolution getV1ParameterGlobalPError(optimizer::Dag &dag,
|
||||
optimizer::Config config) {
|
||||
// We find the approximate translation between local and global error with a
|
||||
// calibration call
|
||||
auto ref_p_error = std::min(config.p_error, config.global_p_error);
|
||||
auto ref_global_p_success = 1.0 - config.global_p_error;
|
||||
auto sol = dag->optimize(config.security, ref_p_error,
|
||||
config.fallback_log_norm_woppbs);
|
||||
for (int i = 2; i <= MAXIMUM_OPTIMIZER_CALL; i++) {
|
||||
auto local_p_success = 1.0 - sol.p_error;
|
||||
auto global_p_success = 1.0 - sol.global_p_error;
|
||||
auto power_global_to_local = log(local_p_success) / log(global_p_success);
|
||||
auto surrogate_p_local_success =
|
||||
pow(ref_global_p_success, power_global_to_local);
|
||||
config.p_error = 1.0 - surrogate_p_local_success;
|
||||
sol = dag->optimize(config.security, config.p_error,
|
||||
config.fallback_log_norm_woppbs);
|
||||
if (sol.global_p_error <= config.global_p_error) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return sol;
|
||||
}
|
||||
|
||||
optimizer::DagSolution getV1Parameter(optimizer::Dag &dag,
|
||||
optimizer::Config config) {
|
||||
if (!std::isnan(config.global_p_error)) {
|
||||
return getV1ParameterGlobalPError(dag, config);
|
||||
}
|
||||
return dag->optimize(config.security, config.p_error,
|
||||
config.fallback_log_norm_woppbs);
|
||||
}
|
||||
|
||||
static void display(V0FHEConstraint constraint,
|
||||
constexpr double WARN_ABOVE_GLOBAL_ERROR_RATE = 1.0 / 1000.0;
|
||||
|
||||
static void display(optimizer::Description &descr,
|
||||
optimizer::Config optimizerConfig,
|
||||
optimizer::DagSolution sol,
|
||||
optimizer::DagSolution sol, bool naive_user,
|
||||
std::chrono::milliseconds duration) {
|
||||
if (!optimizerConfig.display && !mlir::concretelang::isVerbose()) {
|
||||
return;
|
||||
}
|
||||
auto constraint = descr.constraint;
|
||||
auto complexity_label =
|
||||
descr.dag ? "for the full circuit" : "for each Pbs call";
|
||||
auto o = llvm::outs;
|
||||
o() << "--- Circuit\n"
|
||||
<< " " << constraint.p << " bits integers\n"
|
||||
@@ -54,12 +87,19 @@ static void display(V0FHEConstraint constraint,
|
||||
<< " " << duration.count() << "ms to solve\n"
|
||||
<< "--- Optimizer config\n"
|
||||
<< " " << optimizerConfig.p_error << " error per pbs call\n"
|
||||
<< "--- Complexity for each Pbs call\n"
|
||||
<< " " << optimizerConfig.global_p_error << " error per circuit call\n"
|
||||
<< "--- Complexity " << complexity_label << "\n"
|
||||
<< " " << (long)sol.complexity / (1000 * 1000)
|
||||
<< " Millions Operations\n"
|
||||
<< "--- Correctness for each Pbs call\n"
|
||||
<< " 1/" << int(1.0 / sol.p_error) << " errors (" << sol.p_error << ")\n"
|
||||
<< "--- Parameters resolution\n"
|
||||
<< " 1/" << int(1.0 / sol.p_error) << " errors (" << sol.p_error
|
||||
<< ")\n";
|
||||
if (descr.dag) {
|
||||
o() << "--- Correctness for the full circuit\n"
|
||||
<< " 1/" << int(1.0 / sol.global_p_error) << " errors ("
|
||||
<< sol.global_p_error << ")\n";
|
||||
}
|
||||
o() << "--- Parameters resolution\n"
|
||||
<< " " << sol.glwe_dimension << "x glwe_dimension\n"
|
||||
<< " 2**" << (size_t)std::log2l(sol.glwe_polynomial_size)
|
||||
<< " polynomial (" << sol.glwe_polynomial_size << ")\n"
|
||||
@@ -74,12 +114,36 @@ static void display(V0FHEConstraint constraint,
|
||||
<< sol.cb_decomposition_base_log << "\n";
|
||||
}
|
||||
o() << "---\n";
|
||||
|
||||
if (descr.dag && naive_user &&
|
||||
sol.global_p_error > WARN_ABOVE_GLOBAL_ERROR_RATE) {
|
||||
auto dominating_pbs =
|
||||
(int)(log(1.0 - sol.global_p_error) / log(1.0 - sol.p_error));
|
||||
o() << "---\n"
|
||||
<< "!!!!! WARNING !!!!!\n"
|
||||
<< "\n"
|
||||
<< "HIGH ERROR RATE: 1/" << int(1.0 / sol.global_p_error)
|
||||
<< " errors \n\n"
|
||||
<< "Resolve by using command line option: \n"
|
||||
<< "--global-error-probability=" << WARN_ABOVE_GLOBAL_ERROR_RATE
|
||||
<< "\n\n"
|
||||
<< "Reason:\n"
|
||||
<< dominating_pbs << " pbs dominate at 1/" << int(1.0 / sol.p_error)
|
||||
<< " errors rate\n";
|
||||
o() << "\n!!!!!!!!!!!!!!!!!!!\n";
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
|
||||
optimizer::Config config) {
|
||||
namespace chrono = std::chrono;
|
||||
auto start = chrono::high_resolution_clock::now();
|
||||
|
||||
auto naive_user =
|
||||
std::isnan(config.p_error) && std::isnan(config.global_p_error);
|
||||
if (std::isnan(config.p_error)) {
|
||||
config.p_error = optimizer::P_ERROR_4_SIGMA;
|
||||
}
|
||||
auto sol = (!descr.dag || config.strategy_v0)
|
||||
? getV0Parameter(descr.constraint, config)
|
||||
: getV1Parameter(descr.dag.getValue(), config);
|
||||
@@ -91,13 +155,24 @@ llvm::Expected<V0Parameter> getParameter(optimizer::Description &descr,
|
||||
llvm::errs() << "concrete-optimizer time: " << duration_s.count() << "s\n";
|
||||
}
|
||||
|
||||
display(descr.constraint, config, sol, duration);
|
||||
display(descr, config, sol, naive_user, duration);
|
||||
|
||||
if (sol.p_error == 1.0) {
|
||||
// The optimizer return a p_error = 1 if there is no solution
|
||||
// The optimizer return a p_error = 1 if there is no solution
|
||||
bool no_solution = sol.p_error == 1.0;
|
||||
// The global_p_error is best effort only, so we must verify
|
||||
bool bad_solution = !std::isnan(config.global_p_error) &&
|
||||
config.global_p_error < sol.global_p_error;
|
||||
|
||||
if (no_solution || bad_solution) {
|
||||
return StreamStringError() << "Cannot find crypto parameters";
|
||||
}
|
||||
|
||||
if (descr.dag && !config.display && naive_user &&
|
||||
sol.global_p_error > WARN_ABOVE_GLOBAL_ERROR_RATE) {
|
||||
llvm::errs() << "WARNING: high error rate, more details with "
|
||||
"--display-optimizer-choice\n";
|
||||
}
|
||||
|
||||
V0Parameter params;
|
||||
params.glweDimension = sol.glwe_dimension;
|
||||
params.logPolynomialSize = (size_t)std::log2l(sol.glwe_polynomial_size);
|
||||
|
||||
@@ -182,6 +182,13 @@ llvm::cl::opt<double> pbsErrorProbability(
|
||||
llvm::cl::desc("Change the default probability of error for all pbs"),
|
||||
llvm::cl::init(mlir::concretelang::optimizer::DEFAULT_CONFIG.p_error));
|
||||
|
||||
llvm::cl::opt<double> globalErrorProbability(
|
||||
"global-error-probability",
|
||||
llvm::cl::desc(
|
||||
"Use global error probability (override pbs error probability)"),
|
||||
llvm::cl::init(
|
||||
mlir::concretelang::optimizer::DEFAULT_CONFIG.global_p_error));
|
||||
|
||||
llvm::cl::opt<bool> displayOptimizerChoice(
|
||||
"display-optimizer-choice",
|
||||
llvm::cl::desc("Display the information returned by the optimizer"),
|
||||
@@ -342,10 +349,18 @@ cmdlineCompilationOptions() {
|
||||
cmdline::largeIntegerCircuitBootstrap[1];
|
||||
}
|
||||
|
||||
options.optimizerConfig.global_p_error = cmdline::globalErrorProbability;
|
||||
options.optimizerConfig.p_error = cmdline::pbsErrorProbability;
|
||||
options.optimizerConfig.display = cmdline::displayOptimizerChoice;
|
||||
options.optimizerConfig.strategy_v0 = cmdline::optimizerV0;
|
||||
|
||||
if (!std::isnan(options.optimizerConfig.global_p_error) &&
|
||||
options.optimizerConfig.strategy_v0) {
|
||||
return llvm::make_error<llvm::StringError>(
|
||||
"--global-error-probability is not compatible with --optimizer-v0",
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
return options;
|
||||
}
|
||||
|
||||
|
||||
@@ -223,6 +223,23 @@ def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
|
||||
|
||||
|
||||
def test_lib_compile_and_run_p_error(keyset_cache):
|
||||
mlir_input = """
|
||||
func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64>
|
||||
%1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>)
|
||||
return %1: !FHE.eint<7>
|
||||
}
|
||||
"""
|
||||
args = (73,)
|
||||
expected_result = 73
|
||||
engine = LibrarySupport.new("./py_test_lib_compile_and_run_custom_perror")
|
||||
options = CompilationOptions.new("main")
|
||||
options.set_global_p_error(0.00001)
|
||||
options.set_display_optimizer_choice(True)
|
||||
compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options)
|
||||
|
||||
|
||||
@pytest.mark.parallel
|
||||
@pytest.mark.parametrize(
|
||||
"mlir_input, args, expected_result", end_to_end_parallel_fixture
|
||||
|
||||
Reference in New Issue
Block a user