From cc6c2576ec6c97f3446a24d394ef58eaa92821b3 Mon Sep 17 00:00:00 2001 From: rudy Date: Fri, 27 May 2022 18:18:10 +0200 Subject: [PATCH] feat(optimizer): create optimizer dag and use it --- compiler/CMakeLists.txt | 3 +- compiler/concrete-optimizer | 2 +- .../Dialect/FHE/Analysis/CMakeLists.txt | 7 + .../Dialect/FHE/Analysis/ConcreteOptimizer.h | 29 ++ .../Dialect/FHE/Analysis/ConcreteOptimizer.td | 15 + .../concretelang/Dialect/FHE/Analysis/MANP.h | 2 + .../concretelang/Dialect/FHE/Analysis/utils.h | 24 ++ .../concretelang/Support/CompilerEngine.h | 4 +- .../include/concretelang/Support/Pipeline.h | 7 +- .../concretelang/Support/V0Parameters.h | 23 +- .../lib/Dialect/FHE/Analysis/CMakeLists.txt | 2 + .../FHE/Analysis/ConcreteOptimizer.cpp | 329 ++++++++++++++++++ compiler/lib/Dialect/FHE/Analysis/MANP.cpp | 34 +- compiler/lib/Dialect/FHE/Analysis/utils.cpp | 51 +++ compiler/lib/Support/CMakeLists.txt | 3 - compiler/lib/Support/CompilerEngine.cpp | 92 +++-- compiler/lib/Support/Pipeline.cpp | 39 ++- compiler/lib/Support/V0ClientParameters.cpp | 3 +- compiler/lib/Support/V0Parameters.cpp | 41 ++- compiler/src/main.cpp | 12 + .../TFHEGlobalParametrization/pbs_ks_bs.mlir | 2 +- .../Dialect/FHE/eint_error_p_too_big.mlir | 5 +- .../end_to_end_fixture/EndToEndFixture.cpp | 4 +- .../tests/end_to_end_tests/CMakeLists.txt | 1 + .../end_to_end_tests/end_to_end_jit_test.h | 4 +- compiler/tests/python/test_compilation.py | 2 +- 26 files changed, 628 insertions(+), 112 deletions(-) create mode 100644 compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h create mode 100644 compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td create mode 100644 compiler/include/concretelang/Dialect/FHE/Analysis/utils.h create mode 100644 compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp create mode 100644 compiler/lib/Dialect/FHE/Analysis/utils.cpp diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index c7de48b93..efe9dfb6c 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -9,7 +9,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Wouldn't be able to compile LLVM without this on Mac (using either Clang or AppleClang) if (APPLE) - add_definitions("-Wno-narrowing") + add_definitions("-Wno-narrowing -Wno-dollar-in-identifier-extension") endif() # If we are trying to build the compiler with LLVM/MLIR as libraries @@ -116,6 +116,7 @@ option(CONCRETELANG_BENCHMARK "Enables the build of benchmarks" ON) #------------------------------------------------------------------------------- # Handling sub dirs #------------------------------------------------------------------------------- +include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) add_subdirectory(include) add_subdirectory(lib) diff --git a/compiler/concrete-optimizer b/compiler/concrete-optimizer index bc52e3cd2..b446d3124 160000 --- a/compiler/concrete-optimizer +++ b/compiler/concrete-optimizer @@ -1 +1 @@ -Subproject commit bc52e3cd2185ff20d2315a496f2d043ae7a02fa7 +Subproject commit b446d3124d89e0e5783df947770ecec19e7a6582 diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt index 341e51e21..3bde2e875 100644 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt @@ -4,3 +4,10 @@ mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis) mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) add_public_tablegen_target(MANPPassIncGen) add_dependencies(mlir-headers MANPPassIncGen) + +set(LLVM_TARGET_DEFINITIONS ConcreteOptimizer.td) +mlir_tablegen(ConcreteOptimizer.h.inc -gen-pass-decls -name Analysis) +mlir_tablegen(ConcreteOptimizer.capi.h.inc -gen-pass-capi-header --prefix Analysis) +mlir_tablegen(ConcreteOptimizer.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) +add_public_tablegen_target(ConcreteOptimizerPassIncGen) +add_dependencies(mlir-headers ConcreteOptimizerPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h new file mode 100644 index 000000000..0ad7746c4 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h @@ -0,0 +1,29 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H +#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H + +#include +#include + +#include "concrete-optimizer.hpp" + +#include "concretelang/Support/V0Parameters.h" + +namespace mlir { +namespace concretelang { + +namespace optimizer { +using FunctionsDag = std::map>; + +std::unique_ptr createDagPass(optimizer::Config config, + optimizer::FunctionsDag &dags); + +} // namespace optimizer +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td new file mode 100644 index 000000000..2b112edf8 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td @@ -0,0 +1,15 @@ +#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER +#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER + +include "mlir/Pass/PassBase.td" + +def ConcreteOptimizer : Pass<"ConcreteOptimizer", "::mlir::func::FuncOp"> { + let summary = "Call concrete-optimizer"; + let description = [{ + The pass calls the concrete-optimizer to provide crypto parameter. + It construct a simplified representation of the FHE circuit and send it to the concrete optimizer. + It uses on the values from the MANP pass to indicate how noise is propagate and amplified in levelled operations. + }]; +} + +#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h index bac220628..a88710bae 100644 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/MANP.h @@ -11,6 +11,8 @@ namespace mlir { namespace concretelang { +bool isEncryptedValue(mlir::Value value); +unsigned int getEintPrecision(mlir::Value value); std::unique_ptr createMANPPass(bool debug = false); std::unique_ptr diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h b/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h new file mode 100644 index 000000000..0da6574c2 --- /dev/null +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h @@ -0,0 +1,24 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H +#define CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H + +#include + +namespace mlir { +namespace concretelang { +namespace fhe { +namespace utils { + +bool isEncryptedValue(mlir::Value value); +unsigned int getEintPrecision(mlir::Value value); + +} // namespace utils +} // namespace fhe +} // namespace concretelang +} // namespace mlir + +#endif diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 7164641b6..3b753a63e 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -243,8 +243,8 @@ protected: std::shared_ptr compilationContext; private: - llvm::Expected> - getV0FHEConstraint(CompilationResult &res); + llvm::Expected> + getConcreteOptimizerDescription(CompilationResult &res); llvm::Error determineFHEParameters(CompilationResult &res); }; diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index e87f89fcc..a888e5ca5 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -20,9 +20,10 @@ namespace pipeline { mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); -llvm::Expected> -getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass); +llvm::Expected>> +getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + optimizer::Config config, + std::function enablePass); mlir::LogicalResult tileMarkedFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/include/concretelang/Support/V0Parameters.h b/compiler/include/concretelang/Support/V0Parameters.h index b7273077e..865411574 100644 --- a/compiler/include/concretelang/Support/V0Parameters.h +++ b/compiler/include/concretelang/Support/V0Parameters.h @@ -8,6 +8,7 @@ #include "llvm/ADT/Optional.h" +#include "concrete-optimizer.hpp" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" namespace mlir { @@ -15,16 +16,30 @@ namespace concretelang { namespace optimizer { constexpr double P_ERROR_4_SIGMA = 1.0 - 0.999936657516; +constexpr uint DEFAULT_SECURITY = 128; + struct Config { double p_error; bool display; + bool strategy_v0; + std::uint64_t security; }; -constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false}; +constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false, false, + DEFAULT_SECURITY}; + +using Dag = rust::Box; +using Solution = concrete_optimizer::v0::Solution; + +/* Contains any circuit description usable by the concrete-optimizer */ +struct Description { + V0FHEConstraint constraint; + llvm::Optional dag; +}; + } // namespace optimizer -llvm::Optional getV0Parameter(V0FHEConstraint constraint, - optimizer::Config optimizerConfig); - +llvm::Optional getParameter(optimizer::Description &descr, + optimizer::Config optimizerConfig); } // namespace concretelang } // namespace mlir #endif diff --git a/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt b/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt index 057c086d4..62de96c07 100644 --- a/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt @@ -1,4 +1,6 @@ add_mlir_library(FHEDialectAnalysis + utils.cpp + ConcreteOptimizer.cpp MANP.cpp ADDITIONAL_HEADER_DIRS diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp new file mode 100644 index 000000000..727aa70c5 --- /dev/null +++ b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp @@ -0,0 +1,329 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include + +#include "boost/outcome.h" + +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Pass/PassManager.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Pass.h" +#include "llvm/Support/raw_ostream.h" + +#include "concrete-optimizer.hpp" + +#include "concretelang/Common/Error.h" +#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h" +#include "concretelang/Dialect/FHE/Analysis/utils.h" +#include "concretelang/Dialect/FHE/IR/FHEOps.h" +#include "concretelang/Dialect/FHE/IR/FHETypes.h" +#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" +#include "concretelang/Support/V0Parameters.h" +#include "concretelang/Support/logging.h" + +#define GEN_PASS_CLASSES +#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h.inc" + +namespace mlir { +namespace concretelang { +namespace optimizer { + +namespace { + +template rust::Slice slice(const std::vector &vec) { + return rust::Slice(vec.data(), vec.size()); +} + +template rust::Slice slice(const llvm::ArrayRef &vec) { + return rust::Slice(vec.data(), vec.size()); +} + +struct FunctionToDag { + // Inputs of operators + using Inputs = std::vector; + + const double NEGLIGIBLE_COMPLEXITY = 0.0; + + mlir::func::FuncOp func; + optimizer::Config config; + llvm::DenseMap index; + + FunctionToDag(mlir::func::FuncOp func, optimizer::Config config) + : func(func), config(config) {} + +#define DEBUG(MSG) \ + if (mlir::concretelang::isVerbose()) { \ + mlir::concretelang::log_verbose() << MSG << "\n"; \ + } + + outcome::checked, + ::concretelang::error::StringError> + build() { + auto dag = concrete_optimizer::dag::empty(); + // Converting arguments as Input + for (auto &arg : func.getArguments()) { + addArg(dag, arg); + } + // Converting ops + for (auto &bb : func.getBody().getBlocks()) { + for (auto &op : bb.getOperations()) { + addOperation(dag, op); + } + } + if (index.empty()) { + // Dag is empty <=> classical function without encryption + DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName() + << "\n"); + return llvm::None; + }; + DEBUG(std::string(dag->dump())); + return std::move(dag); + } + + void addArg(optimizer::Dag &dag, mlir::Value &arg) { + DEBUG("Arg " << arg << " " << arg.getType()); + if (!fhe::utils::isEncryptedValue(arg)) { + return; + } + auto precision = fhe::utils::getEintPrecision(arg); + auto shape = getShape(arg); + auto opI = dag->add_input(precision, slice(shape)); + index[arg] = opI; + } + + bool hasEncryptedResult(mlir::Operation &op) { + for (auto val : op.getResults()) { + if (fhe::utils::isEncryptedValue(val)) { + return true; + } + } + return false; + } + + void addOperation(optimizer::Dag &dag, mlir::Operation &op) { + DEBUG("Instr " << op); + + if (isReturn(op)) { + // This op has no result + return; + } + + auto encrypted_inputs = encryptedInputs(op); + if (!hasEncryptedResult(op)) { + // This op is unrelated to FHE + assert(encrypted_inputs.empty()); + return; + } + assert(op.getNumResults() == 1); + auto val = op.getResult(0); + auto precision = fhe::utils::getEintPrecision(val); + if (isLut(op)) { + addLut(dag, val, encrypted_inputs, precision); + return; + } + if (auto dot = asDot(op)) { + auto weightsOpt = dotWeights(dot); + if (weightsOpt) { + addDot(dag, val, encrypted_inputs, weightsOpt.getValue()); + return; + } + DEBUG("Replace Dot by LevelledOp on " << op); + } + // default + addLevelledOp(dag, op, encrypted_inputs); + } + + void addLut(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, + int precision) { + assert(encrypted_inputs.size() == 1); + // No need to distinguish different lut kind until we do approximate + // paradigm on outputs + auto encrypted_input = encrypted_inputs[0]; + std::vector unknowFunction; + index[val] = + dag->add_lut(encrypted_input, slice(unknowFunction), precision); + } + + void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, + std::vector &weights_vector) { + assert(encrypted_inputs.size() == 1); + auto weights = concrete_optimizer::weights::vector(slice(weights_vector)); + index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights)); + } + + std::string loc_to_string(mlir::Location location) { + std::string loc; + llvm::raw_string_ostream loc_stream(loc); + location.print(loc_stream); + return loc; + } + + void addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) { + auto val = op.getResult(0); + auto out_shape = getShape(val); + if (inputs.empty()) { + // Trivial encrypted constants encoding + // There are converted to input + levelledop + auto precision = fhe::utils::getEintPrecision(val); + auto opI = dag->add_input(precision, slice(out_shape)); + inputs.push_back(opI); + } + // Default complexity is negligible + double fixed_cost = NEGLIGIBLE_COMPLEXITY; + double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; + auto manp_int = op.getAttrOfType("MANP"); + auto loc = loc_to_string(op.getLoc()); + if (!manp_int) { + DEBUG("Cannot read manp on " << op << "\n" << loc); + } + assert(manp_int && "Missing manp value on a crypto operation"); + double manp = (double)manp_int.getValue().getZExtValue(); + auto comment = std::string(op.getName().getStringRef()) + " " + loc; + index[val] = + dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, + manp, slice(out_shape), comment); + } + + Inputs encryptedInputs(mlir::Operation &op) { + Inputs inputs; + for (auto operand : op.getOperands()) { + auto entry = index.find(operand); + if (entry == index.end()) { + assert(!fhe::utils::isEncryptedValue(operand)); + DEBUG("Ignoring as input " << operand); + continue; + } + inputs.push_back(entry->getSecond()); + } + return inputs; + } + + bool isLut(mlir::Operation &op) { + return llvm::isa< + mlir::concretelang::FHE::ApplyLookupTableEintOp, + mlir::concretelang::FHELinalg::ApplyLookupTableEintOp, + mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp, + mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op); + } + + mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) { + return llvm::dyn_cast(op); + } + + bool isReturn(mlir::Operation &op) { + return llvm::isa(op); + } + + bool isConst(mlir::Operation &op) { + return llvm::isa(op); + } + + bool isArg(const mlir::Value &value) { + return value.isa(); + } + + std::vector + resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { + std::vector values; + mlir::DenseIntElementsAttr denseVals = + cstOp->getAttrOfType("value"); + + for (llvm::APInt val : denseVals.getValues()) { + values.push_back(val.getZExtValue()); + } + return values; + } + + llvm::Optional> + resolveConstantWeights(mlir::Value &value) { + if (auto cstOp = llvm::dyn_cast_or_null( + value.getDefiningOp())) { + auto shape = getShape(value); + switch (shape.size()) { + case 1: + return resolveConstantVectorWeights(cstOp); + default: + DEBUG("High-Rank tensor: rely on MANP and levelledOp"); + return llvm::None; + } + } else { + DEBUG("Dynamic Weights: rely on MANP and levelledOp"); + return llvm::None; + } + } + + llvm::Optional> + dotWeights(mlir::concretelang::FHELinalg::Dot &dot) { + if (dot.getOperands().size() != 2) { + return llvm::None; + } + auto weights = dot.getOperands()[1]; + return resolveConstantWeights(weights); + } + + std::vector getShape(mlir::Value &value) { + return getShape(value.getType()); + } + + std::vector getShape(mlir::Type type_) { + if (auto ranked_tensor = type_.dyn_cast_or_null()) { + std::vector shape; + for (auto v : ranked_tensor.getShape()) { + shape.push_back(v); + } + return shape; + } else { + return {}; + } + } +}; + +} // namespace + +struct DagPass : ConcreteOptimizerBase { + optimizer::Config config; + optimizer::FunctionsDag &dags; + + void runOnOperation() override { + mlir::func::FuncOp func = getOperation(); + auto name = std::string(func.getName()); + DEBUG("ConcreteOptimizer Dag: " << name); + if (config.strategy_v0) { + // we avoid building the dag since it's not used in this case + // so strategy_v0 can be used to avoid dag creation issues + dags.insert(optimizer::FunctionsDag::value_type(name, llvm::None)); + } + auto dag = FunctionToDag(func, config).build(); + if (dag) { + dags.insert( + optimizer::FunctionsDag::value_type(name, std::move(dag.value()))); + } else { + this->signalPassFailure(); + } + } + + DagPass() = delete; + DagPass(optimizer::Config config, optimizer::FunctionsDag &dags) + : config(config), dags(dags) {} +}; + +// Create an instance of the ConcreteOptimizerPass pass. +// A global pass result is communicated using `dags`. +// If `debug` is true, for each operation, the pass emits a +// remark containing the squared Minimal Arithmetic Noise Padding of +// the equivalent dot operation. +std::unique_ptr createDagPass(optimizer::Config config, + optimizer::FunctionsDag &dags) { + return std::make_unique(config, dags); +} + +} // namespace optimizer +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 39bce1179..85c66b201 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -4,6 +4,7 @@ // for license information. #include +#include #include #include #include @@ -45,36 +46,7 @@ static bool isEncryptedFunctionParameter(mlir::Value value) { return false; } - return ( - value.getType().isa() || - (value.getType().isa() && - value.getType() - .cast() - .getElementType() - .isa())); -} - -/// Returns the bit width of `value` if `value` is an encrypted integer -/// or the bit width of the elements if `value` is a tensor of -/// encrypted integers. -static unsigned int getEintPrecision(mlir::Value value) { - if (auto ty = value.getType() - .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) { - return ty.getWidth(); - } else if (auto tensorTy = - value.getType().dyn_cast_or_null()) { - if (auto ty = tensorTy.getElementType() - .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) - return ty.getWidth(); - } - - assert(false && - "Value is neither an encrypted integer nor a tensor of encrypted " - "integers"); - - return 0; + return mlir::concretelang::fhe::utils::isEncryptedValue(value); } /// The `MANPLatticeValue` represents the squared Minimal Arithmetic @@ -1573,7 +1545,7 @@ protected: llvm::dyn_cast_or_null(op)) { for (mlir::BlockArgument blockArg : func.getBody().getArguments()) { if (isEncryptedFunctionParameter(blockArg)) { - unsigned int width = getEintPrecision(blockArg); + unsigned int width = fhe::utils::getEintPrecision(blockArg); if (this->maxEintWidth < width) { this->maxEintWidth = width; diff --git a/compiler/lib/Dialect/FHE/Analysis/utils.cpp b/compiler/lib/Dialect/FHE/Analysis/utils.cpp new file mode 100644 index 000000000..d0d3b270c --- /dev/null +++ b/compiler/lib/Dialect/FHE/Analysis/utils.cpp @@ -0,0 +1,51 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include "concretelang/Dialect/FHE/IR/FHETypes.h" +#include + +namespace mlir { +namespace concretelang { +namespace fhe { +namespace utils { +/// Returns `true` if the given value is a scalar or tensor argument of +/// a function, for which a MANP of 1 can be assumed. +bool isEncryptedValue(mlir::Value value) { + return ( + value.getType().isa() || + (value.getType().isa() && + value.getType() + .cast() + .getElementType() + .isa())); +} + +/// Returns the bit width of `value` if `value` is an encrypted integer +/// or the bit width of the elements if `value` is a tensor of +/// encrypted integers. +unsigned int getEintPrecision(mlir::Value value) { + if (auto ty = value.getType() + .dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedIntegerType>()) { + return ty.getWidth(); + } else if (auto tensorTy = + value.getType().dyn_cast_or_null()) { + if (auto ty = tensorTy.getElementType() + .dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedIntegerType>()) + return ty.getWidth(); + } + + assert(false && + "Value is neither an encrypted integer nor a tensor of encrypted " + "integers"); + + return 0; +} + +} // namespace utils +} // namespace fhe +} // namespace concretelang +} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 836f89d9b..9cca7dbfc 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,6 +1,3 @@ -# not working in ADDITIONAL_HEADER_DIRS -include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) - add_mlir_library(ConcretelangSupport Pipeline.cpp Jit.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 9da028caf..e91bdc0ee 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -25,6 +25,7 @@ #include #include +#include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include #include #include @@ -122,63 +123,81 @@ void CompilerEngine::setEnablePass( this->enablePass = enablePass; } -/// Returns the overwritten V0FHEConstraint or try to compute them from FHE -llvm::Expected> -CompilerEngine::getV0FHEConstraint(CompilationResult &res) { +/// Returns the optimizer::Description +llvm::Expected> +CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); mlir::ModuleOp module = res.mlirModuleRef->get(); // If the values has been overwritten returns if (this->overrideMaxEintPrecision.hasValue() && this->overrideMaxMANP.hasValue()) { - return mlir::concretelang::V0FHEConstraint{ + auto constraint = mlir::concretelang::V0FHEConstraint{ this->overrideMaxMANP.getValue(), this->overrideMaxEintPrecision.getValue()}; + return optimizer::Description{constraint, llvm::None}; } - // Else compute constraint from FHE - llvm::Expected> - fheConstraintsOrErr = - mlir::concretelang::pipeline::getFHEConstraintsFromFHE( - mlirContext, module, enablePass); - - if (auto err = fheConstraintsOrErr.takeError()) + auto config = this->compilerOptions.optimizerConfig; + auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE( + mlirContext, module, config, enablePass); + if (auto err = descriptions.takeError()) { return std::move(err); - - return fheConstraintsOrErr.get(); + } + if (descriptions->empty()) { // The pass has not been run + return llvm::None; + } + if (this->compilerOptions.clientParametersFuncName.hasValue()) { + auto name = this->compilerOptions.clientParametersFuncName.getValue(); + auto description = descriptions->find(name); + if (description == descriptions->end()) { + std::string names; + for (auto &entry : *descriptions) { + names += "'" + entry.first + "' "; + } + return StreamStringError() + << "Could not find existing crypto parameters for function '" + << name << "' (known functions: " << names << ")"; + } + return std::move(description->second); + } + if (descriptions->size() != 1) { + llvm::errs() << "Several crypto parameters exists: the function need to be " + "specified, taking the first one"; + } + return std::move(descriptions->begin()->second); } /// set the fheContext field if the v0Constraint can be computed llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { - auto fheConstraintOrErr = getV0FHEConstraint(res); - if (auto err = fheConstraintOrErr.takeError()) + auto descrOrErr = getConcreteOptimizerDescription(res); + if (auto err = descrOrErr.takeError()) { return err; - if (!fheConstraintOrErr.get().hasValue()) { + } + // The function is non-crypto and without constraint override + if (!descrOrErr.get().hasValue()) { return llvm::Error::success(); } - llvm::Optional v0Params; - if (compilerOptions.v0Parameter.hasValue()) { - v0Params = compilerOptions.v0Parameter; - } else { - v0Params = getV0Parameter(fheConstraintOrErr.get().getValue(), - this->compilerOptions.optimizerConfig); + auto descr = std::move(descrOrErr.get().getValue()); + auto config = this->compilerOptions.optimizerConfig; - if (!v0Params) { - return StreamStringError() - << "Could not determine V0 parameters for 2-norm of " - << (*fheConstraintOrErr)->norm2 << " and p of " - << (*fheConstraintOrErr)->p; - } + auto fheParams = (compilerOptions.v0Parameter.hasValue()) + ? compilerOptions.v0Parameter + : getParameter(descr, config); + if (!fheParams) { + return StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << (*descrOrErr)->constraint.norm2 << " and p of " + << (*descrOrErr)->constraint.p; } - res.fheContext.emplace(mlir::concretelang::V0FHEContext{ - (*fheConstraintOrErr).getValue(), v0Params.getValue()}); - + res.fheContext.emplace( + mlir::concretelang::V0FHEContext{descr.constraint, fheParams.getValue()}); return llvm::Error::success(); } using OptionalLib = llvm::Optional>; -/// Compile the sources managed by the source manager `sm` to the -/// target dialect `target`. If successful, the result can be retrieved -/// using `getModule()` and `getLLVMModule()`, respectively depending -/// on the target dialect. +// Compile the sources managed by the source manager `sm` to the +// target dialect `target`. If successful, the result can be retrieved +// using `getModule()` and `getLLVMModule()`, respectively depending +// on the target dialect. llvm::Expected CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { std::unique_ptr smHandler; @@ -297,7 +316,8 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { } if (!res.fheContext.hasValue()) { return StreamStringError( - "Cannot generate client parameters, the fhe context is empty"); + "Cannot generate client parameters, the fhe context is empty for " + + options.clientParametersFuncName.getValue()); } } // Generate client parameters if requested diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 775cab7d2..74a46068b 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -27,9 +27,12 @@ #include #include +#include "concretelang/Support/CompilerEngine.h" +#include "concretelang/Support/Error.h" #include #include #include +#include #include #include #include @@ -73,11 +76,13 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr pass, } } -llvm::Expected> -getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - std::function enablePass) { +llvm::Expected>> +getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + optimizer::Config config, + std::function enablePass) { llvm::Optional oMax2norm; llvm::Optional oMaxWidth; + optimizer::FunctionsDag dags; mlir::PassManager pm(&context); @@ -109,18 +114,36 @@ getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, if (pm.run(module.getOperation()).failed()) { return llvm::make_error( "Failed to determine the maximum Arithmetic Noise Padding and maximum" - "required precision", + " required precision", llvm::inconvertibleErrorCode()); } - llvm::Optional ret; + llvm::Optional constraint = llvm::None; if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { - ret = llvm::Optional( + constraint = llvm::Optional( {/*.norm2 = */ ceilLog2(oMax2norm.getValue()), /*.p = */ oMaxWidth.getValue()}); } - - return ret; + addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags), + enablePass); + if (pm.run(module.getOperation()).failed()) { + return StreamStringError() << "Failed to create concrete-optimizer dag\n"; + } + std::map> descriptions; + for (auto &entry_dag : dags) { + if (!constraint) { + descriptions.insert( + decltype(descriptions)::value_type(entry_dag.first, llvm::None)); + continue; + } + optimizer::Description description = {*constraint, + std::move(entry_dag.second)}; + llvm::Optional opt_description{ + std::move(description)}; + descriptions.insert(decltype(descriptions)::value_type( + entry_dag.first, std::move(opt_description))); + } + return std::move(descriptions); } mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index d8c0fe362..ae05933d9 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -139,7 +139,8 @@ createClientParametersForV0(V0FHEContext fheContext, }); if (funcOp == rangeOps.end()) { return llvm::make_error( - "cannot find the function for generate client parameters", + "cannot find the function for generate client parameters '" + + functionName + "'", llvm::inconvertibleErrorCode()); } diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp index 844f9f886..92bf183ab 100644 --- a/compiler/lib/Support/V0Parameters.cpp +++ b/compiler/lib/Support/V0Parameters.cpp @@ -17,15 +17,30 @@ #include "concrete-optimizer.hpp" #include "concretelang/Support/V0Parameters.h" +#include namespace mlir { namespace concretelang { +optimizer::Solution getV0Parameter(V0FHEConstraint constraint, + optimizer::Config config) { + // the norm2 0 is equivalent to a maximum noise_factor of 2.0 + // norm2 = 0 ==> 1.0 =< noise_factor < 2.0 + // norm2 = k ==> 2^norm2 =< noise_factor < 2.0^norm2 + 1 + double noise_factor = std::exp2(constraint.norm2 + 1); + return concrete_optimizer::v0::optimize_bootstrap( + constraint.p, config.security, noise_factor, config.p_error); +} + +optimizer::Solution getV1Parameter(optimizer::Dag &dag, + optimizer::Config config) { + return dag->optimize_v0(config.security, config.p_error); +} + static void display(V0FHEConstraint constraint, - optimizer::Config optimizerConfig, - concrete_optimizer::v0::Solution sol, + optimizer::Config optimizerConfig, optimizer::Solution sol, std::chrono::milliseconds duration) { - if (!optimizerConfig.display) { + if (!optimizerConfig.display && !mlir::concretelang::isVerbose()) { return; } auto o = llvm::outs; @@ -54,19 +69,15 @@ static void display(V0FHEConstraint constraint, << "---\n"; } -llvm::Optional getV0Parameter(V0FHEConstraint constraint, - optimizer::Config optimizerConfig) { +llvm::Optional getParameter(optimizer::Description &descr, + optimizer::Config config) { namespace chrono = std::chrono; - int security = 128; - // the norm2 0 is equivalent to a maximum noise_factor of 2.0 - // norm2 = 0 ==> 1.0 =< noise_factor < 2.0 - // norm2 = k ==> 2^norm2 =< noise_factor < 2.0^norm2 + 1 - double noise_factor = std::exp2(constraint.norm2 + 1); - // https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/V0Parameters/tabulation.py#L58 - double p_error = optimizerConfig.p_error; auto start = chrono::high_resolution_clock::now(); - auto sol = concrete_optimizer::v0::optimize_bootstrap(constraint.p, security, - noise_factor, p_error); + + auto sol = (!descr.dag || config.strategy_v0) + ? getV0Parameter(descr.constraint, config) + : getV1Parameter(descr.dag.getValue(), config); + auto stop = chrono::high_resolution_clock::now(); if (sol.p_error == 1.0) { // The optimizer return a p_error = 1 if there is no solution @@ -78,7 +89,7 @@ llvm::Optional getV0Parameter(V0FHEConstraint constraint, llvm::errs() << "concrete-optimizer time: " << duration_s.count() << "s\n"; } - display(constraint, optimizerConfig, sol, duration); + display(descr.constraint, config, sol, duration); return mlir::concretelang::V0Parameter{ sol.glwe_dimension, diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 4d41d0d15..2d4a59cb9 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -186,6 +186,11 @@ llvm::cl::opt displayOptimizerChoice( llvm::cl::desc("Display the information returned by the optimizer"), llvm::cl::init(false)); +llvm::cl::opt + optimizerV0("optimizer-v0", + llvm::cl::desc("Select the v0 parameters strategy"), + llvm::cl::init(false)); + llvm::cl::list fhelinalgTileSizes( "fhelinalg-tile-sizes", llvm::cl::desc( @@ -265,8 +270,15 @@ cmdlineCompilationOptions() { cmdline::v0Parameter[6]); } + if (!cmdline::v0Constraint.empty() && !cmdline::optimizerV0) { + return llvm::make_error( + "You must use --v0-constraint with --optimizer-v0-strategy", + llvm::inconvertibleErrorCode()); + } + options.optimizerConfig.p_error = cmdline::pbsErrorProbability; options.optimizerConfig.display = cmdline::displayOptimizerChoice; + options.optimizerConfig.strategy_v0 = cmdline::optimizerV0; return options; } diff --git a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir index 120140a95..61bc3e5ff 100644 --- a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --v0-parameter=2,10,750,1,23,3,4 -v0-constraint=4,0 %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s //CHECK: func.func @main(%[[A0:.*]]: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> { //CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> diff --git a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir index d567d08ef..96333083f 100644 --- a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir @@ -1,6 +1,7 @@ // RUN: not concretecompiler --action=dump-llvm-ir %s 2>&1| FileCheck %s // CHECK-LABEL: Could not determine V0 parameters -func.func @test(%arg0: !FHE.eint<9>) { - return +func.func @test(%arg0: !FHE.eint<9>, %arg1: tensor<512xi64>) -> !FHE.eint<9> { + %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<9>, tensor<512xi64>) -> (!FHE.eint<9>) + return %1 : !FHE.eint<9> } diff --git a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp index c3bc187b8..9fc6baba4 100644 --- a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp +++ b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp @@ -165,7 +165,7 @@ template <> struct llvm::yaml::MappingTraits { } }; -LLVM_YAML_IS_SEQUENCE_VECTOR(ValueDescription); +LLVM_YAML_IS_SEQUENCE_VECTOR(ValueDescription) template <> struct llvm::yaml::MappingTraits { static void mapping(IO &io, TestDescription &desc) { @@ -174,7 +174,7 @@ template <> struct llvm::yaml::MappingTraits { } }; -LLVM_YAML_IS_SEQUENCE_VECTOR(TestDescription); +LLVM_YAML_IS_SEQUENCE_VECTOR(TestDescription) template <> struct llvm::yaml::MappingTraits { static void mapping(IO &io, EndToEndDesc &desc) { diff --git a/compiler/tests/end_to_end_tests/CMakeLists.txt b/compiler/tests/end_to_end_tests/CMakeLists.txt index d5ce1787f..3213621cb 100644 --- a/compiler/tests/end_to_end_tests/CMakeLists.txt +++ b/compiler/tests/end_to_end_tests/CMakeLists.txt @@ -7,6 +7,7 @@ function(add_concretecompiler_unittest test_name) endfunction() include_directories(${PROJECT_SOURCE_DIR}/include) +include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") link_libraries( diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index c02f3656e..00d84b295 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -23,8 +23,10 @@ internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main", auto options = mlir::concretelang::CompilationOptions(std::string(func.data())); - if (useDefaultFHEConstraints) + if (useDefaultFHEConstraints) { options.v0FHEConstraints = defaultV0Constraints; + options.optimizerConfig.strategy_v0 = true; + } // Allow loop parallelism in all cases options.loopParallelize = loopParallelize; diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 47047ce8b..5c2e6c12f 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -350,6 +350,6 @@ def test_compile_and_run_invalid_arg_number( def test_compile_invalid(mlir_input): engine = JITSupport.new() with pytest.raises( - RuntimeError, match=r"cannot find the function for generate client parameters" + RuntimeError, match=r"Could not find existing crypto parameters for" ): engine.compile(mlir_input)