From 48bf6e269684732f6b3152cc156dec0ecc7a5346 Mon Sep 17 00:00:00 2001 From: rudy Date: Wed, 20 Jul 2022 15:22:38 +0200 Subject: [PATCH] feat(optimizer): report or warn using global p-error --- .../concretelang/Support/V0Parameters.h | 12 ++- .../lib/Bindings/Python/CompilerAPIModule.cpp | 4 + .../concrete/compiler/compilation_options.py | 20 +++- compiler/lib/Support/V0Parameters.cpp | 91 +++++++++++++++++-- compiler/src/main.cpp | 15 +++ compiler/tests/python/test_compilation.py | 17 ++++ 6 files changed, 147 insertions(+), 12 deletions(-) diff --git a/compiler/include/concretelang/Support/V0Parameters.h b/compiler/include/concretelang/Support/V0Parameters.h index 70264a435..8eb5abccf 100644 --- a/compiler/include/concretelang/Support/V0Parameters.h +++ b/compiler/include/concretelang/Support/V0Parameters.h @@ -16,19 +16,25 @@ namespace concretelang { namespace optimizer { constexpr double P_ERROR_4_SIGMA = 1.0 - 0.999936657516; +constexpr double UNSPECIFIED_P_ERROR = NAN; // will use the default p error +constexpr double NO_GLOBAL_P_ERROR = NAN; // will fallback on p error constexpr uint DEFAULT_SECURITY = 128; constexpr uint DEFAULT_FALLBACK_LOG_NORM_WOPPBS = 8; +constexpr bool DEFAULT_DISPLAY = false; +constexpr bool DEFAULT_STARTEGY_V0 = false; struct Config { double p_error; + double global_p_error; bool display; bool strategy_v0; std::uint64_t security; double fallback_log_norm_woppbs; }; -constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false, false, - DEFAULT_SECURITY, - DEFAULT_FALLBACK_LOG_NORM_WOPPBS}; + +constexpr Config DEFAULT_CONFIG = { + UNSPECIFIED_P_ERROR, NO_GLOBAL_P_ERROR, DEFAULT_DISPLAY, + DEFAULT_STARTEGY_V0, DEFAULT_SECURITY, DEFAULT_FALLBACK_LOG_NORM_WOPPBS}; using Dag = rust::Box; using Solution = concrete_optimizer::v0::Solution; diff --git a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp index a62ad49e8..039ca53e8 100644 --- a/compiler/lib/Bindings/Python/CompilerAPIModule.cpp +++ b/compiler/lib/Bindings/Python/CompilerAPIModule.cpp @@ -66,6 +66,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule( .def("set_strategy_v0", [](CompilationOptions &options, bool strategy_v0) { options.optimizerConfig.strategy_v0 = strategy_v0; + }) + .def("set_global_p_error", + [](CompilationOptions &options, double global_p_error) { + options.optimizerConfig.global_p_error = global_p_error; }); pybind11::class_( diff --git a/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py b/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py index ca5a4f28a..a07d628a8 100644 --- a/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py +++ b/compiler/lib/Bindings/Python/concrete/compiler/compilation_options.py @@ -135,7 +135,7 @@ class CompilationOptions(WrapperCpp): self.cpp().set_funcname(funcname) def set_p_error(self, p_error: float): - """Set global error probability for each pbs. + """Set error probability for shared by each pbs. Args: p_error (float): probability of error for each lut @@ -177,3 +177,21 @@ class CompilationOptions(WrapperCpp): if not isinstance(enable, bool): raise TypeError("enable should be a bool") self.cpp().set_strategy_v0(enable) + + def set_global_p_error(self, global_p_error: float): + """Set global error probability for the full circuit. + + Args: + global_p_error (float): probability of error for the full circuit + + Raises: + TypeError: if the value to set is not float + ValueError: if the value to set is not in interval ]0; 1[ + """ + if not isinstance(global_p_error, float): + raise TypeError("can't set global_p_error to a non-float value") + if global_p_error in (0.0, 1.0): + raise ValueError("global_p_error cannot be 0 or 1") + if not 0.0 <= global_p_error <= 1.0: + raise ValueError("global_p_error be a probability in ]0; 1[") + self.cpp().set_global_p_error(global_p_error) diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp index cdc6b3958..7d53466ae 100644 --- a/compiler/lib/Support/V0Parameters.cpp +++ b/compiler/lib/Support/V0Parameters.cpp @@ -34,19 +34,52 @@ optimizer::DagSolution getV0Parameter(V0FHEConstraint constraint, return concrete_optimizer::utils::convert_to_dag_solution(solution); } +const int MAXIMUM_OPTIMIZER_CALL = 5; +optimizer::DagSolution getV1ParameterGlobalPError(optimizer::Dag &dag, + optimizer::Config config) { + // We find the approximate translation between local and global error with a + // calibration call + auto ref_p_error = std::min(config.p_error, config.global_p_error); + auto ref_global_p_success = 1.0 - config.global_p_error; + auto sol = dag->optimize(config.security, ref_p_error, + config.fallback_log_norm_woppbs); + for (int i = 2; i <= MAXIMUM_OPTIMIZER_CALL; i++) { + auto local_p_success = 1.0 - sol.p_error; + auto global_p_success = 1.0 - sol.global_p_error; + auto power_global_to_local = log(local_p_success) / log(global_p_success); + auto surrogate_p_local_success = + pow(ref_global_p_success, power_global_to_local); + config.p_error = 1.0 - surrogate_p_local_success; + sol = dag->optimize(config.security, config.p_error, + config.fallback_log_norm_woppbs); + if (sol.global_p_error <= config.global_p_error) { + break; + } + } + return sol; +} + optimizer::DagSolution getV1Parameter(optimizer::Dag &dag, optimizer::Config config) { + if (!std::isnan(config.global_p_error)) { + return getV1ParameterGlobalPError(dag, config); + } return dag->optimize(config.security, config.p_error, config.fallback_log_norm_woppbs); } -static void display(V0FHEConstraint constraint, +constexpr double WARN_ABOVE_GLOBAL_ERROR_RATE = 1.0 / 1000.0; + +static void display(optimizer::Description &descr, optimizer::Config optimizerConfig, - optimizer::DagSolution sol, + optimizer::DagSolution sol, bool naive_user, std::chrono::milliseconds duration) { if (!optimizerConfig.display && !mlir::concretelang::isVerbose()) { return; } + auto constraint = descr.constraint; + auto complexity_label = + descr.dag ? "for the full circuit" : "for each Pbs call"; auto o = llvm::outs; o() << "--- Circuit\n" << " " << constraint.p << " bits integers\n" @@ -54,12 +87,19 @@ static void display(V0FHEConstraint constraint, << " " << duration.count() << "ms to solve\n" << "--- Optimizer config\n" << " " << optimizerConfig.p_error << " error per pbs call\n" - << "--- Complexity for each Pbs call\n" + << " " << optimizerConfig.global_p_error << " error per circuit call\n" + << "--- Complexity " << complexity_label << "\n" << " " << (long)sol.complexity / (1000 * 1000) << " Millions Operations\n" << "--- Correctness for each Pbs call\n" - << " 1/" << int(1.0 / sol.p_error) << " errors (" << sol.p_error << ")\n" - << "--- Parameters resolution\n" + << " 1/" << int(1.0 / sol.p_error) << " errors (" << sol.p_error + << ")\n"; + if (descr.dag) { + o() << "--- Correctness for the full circuit\n" + << " 1/" << int(1.0 / sol.global_p_error) << " errors (" + << sol.global_p_error << ")\n"; + } + o() << "--- Parameters resolution\n" << " " << sol.glwe_dimension << "x glwe_dimension\n" << " 2**" << (size_t)std::log2l(sol.glwe_polynomial_size) << " polynomial (" << sol.glwe_polynomial_size << ")\n" @@ -74,12 +114,36 @@ static void display(V0FHEConstraint constraint, << sol.cb_decomposition_base_log << "\n"; } o() << "---\n"; + + if (descr.dag && naive_user && + sol.global_p_error > WARN_ABOVE_GLOBAL_ERROR_RATE) { + auto dominating_pbs = + (int)(log(1.0 - sol.global_p_error) / log(1.0 - sol.p_error)); + o() << "---\n" + << "!!!!! WARNING !!!!!\n" + << "\n" + << "HIGH ERROR RATE: 1/" << int(1.0 / sol.global_p_error) + << " errors \n\n" + << "Resolve by using command line option: \n" + << "--global-error-probability=" << WARN_ABOVE_GLOBAL_ERROR_RATE + << "\n\n" + << "Reason:\n" + << dominating_pbs << " pbs dominate at 1/" << int(1.0 / sol.p_error) + << " errors rate\n"; + o() << "\n!!!!!!!!!!!!!!!!!!!\n"; + } } llvm::Expected getParameter(optimizer::Description &descr, optimizer::Config config) { namespace chrono = std::chrono; auto start = chrono::high_resolution_clock::now(); + + auto naive_user = + std::isnan(config.p_error) && std::isnan(config.global_p_error); + if (std::isnan(config.p_error)) { + config.p_error = optimizer::P_ERROR_4_SIGMA; + } auto sol = (!descr.dag || config.strategy_v0) ? getV0Parameter(descr.constraint, config) : getV1Parameter(descr.dag.getValue(), config); @@ -91,13 +155,24 @@ llvm::Expected getParameter(optimizer::Description &descr, llvm::errs() << "concrete-optimizer time: " << duration_s.count() << "s\n"; } - display(descr.constraint, config, sol, duration); + display(descr, config, sol, naive_user, duration); - if (sol.p_error == 1.0) { - // The optimizer return a p_error = 1 if there is no solution + // The optimizer return a p_error = 1 if there is no solution + bool no_solution = sol.p_error == 1.0; + // The global_p_error is best effort only, so we must verify + bool bad_solution = !std::isnan(config.global_p_error) && + config.global_p_error < sol.global_p_error; + + if (no_solution || bad_solution) { return StreamStringError() << "Cannot find crypto parameters"; } + if (descr.dag && !config.display && naive_user && + sol.global_p_error > WARN_ABOVE_GLOBAL_ERROR_RATE) { + llvm::errs() << "WARNING: high error rate, more details with " + "--display-optimizer-choice\n"; + } + V0Parameter params; params.glweDimension = sol.glwe_dimension; params.logPolynomialSize = (size_t)std::log2l(sol.glwe_polynomial_size); diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index f7c8a3b12..4fe7625a9 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -182,6 +182,13 @@ llvm::cl::opt pbsErrorProbability( llvm::cl::desc("Change the default probability of error for all pbs"), llvm::cl::init(mlir::concretelang::optimizer::DEFAULT_CONFIG.p_error)); +llvm::cl::opt globalErrorProbability( + "global-error-probability", + llvm::cl::desc( + "Use global error probability (override pbs error probability)"), + llvm::cl::init( + mlir::concretelang::optimizer::DEFAULT_CONFIG.global_p_error)); + llvm::cl::opt displayOptimizerChoice( "display-optimizer-choice", llvm::cl::desc("Display the information returned by the optimizer"), @@ -342,10 +349,18 @@ cmdlineCompilationOptions() { cmdline::largeIntegerCircuitBootstrap[1]; } + options.optimizerConfig.global_p_error = cmdline::globalErrorProbability; options.optimizerConfig.p_error = cmdline::pbsErrorProbability; options.optimizerConfig.display = cmdline::displayOptimizerChoice; options.optimizerConfig.strategy_v0 = cmdline::optimizerV0; + if (!std::isnan(options.optimizerConfig.global_p_error) && + options.optimizerConfig.strategy_v0) { + return llvm::make_error( + "--global-error-probability is not compatible with --optimizer-v0", + llvm::inconvertibleErrorCode()); + } + return options; } diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 5c2e6c12f..0b953f167 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -223,6 +223,23 @@ def test_lib_compile_and_run_p_error(keyset_cache): compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options) +def test_lib_compile_and_run_p_error(keyset_cache): + mlir_input = """ + func.func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> { + %tlu = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127]> : tensor<128xi64> + %1 = "FHE.apply_lookup_table"(%arg0, %tlu): (!FHE.eint<7>, tensor<128xi64>) -> (!FHE.eint<7>) + return %1: !FHE.eint<7> + } + """ + args = (73,) + expected_result = 73 + engine = LibrarySupport.new("./py_test_lib_compile_and_run_custom_perror") + options = CompilationOptions.new("main") + options.set_global_p_error(0.00001) + options.set_display_optimizer_choice(True) + compile_run_assert(engine, mlir_input, args, expected_result, keyset_cache, options) + + @pytest.mark.parallel @pytest.mark.parametrize( "mlir_input, args, expected_result", end_to_end_parallel_fixture