diff --git a/compiler/CMakeLists.txt b/compiler/CMakeLists.txt index d722bff59..46cd15469 100644 --- a/compiler/CMakeLists.txt +++ b/compiler/CMakeLists.txt @@ -9,7 +9,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # Wouldn't be able to compile LLVM without this on Mac (using either Clang or AppleClang) if (APPLE) - add_definitions("-Wno-narrowing -Wno-dollar-in-identifier-extension") + add_definitions("-Wno-narrowing") endif() # If we are trying to build the compiler with LLVM/MLIR as libraries @@ -104,7 +104,6 @@ option(CONCRETELANG_BENCHMARK "Enables the build of benchmarks" ON) #------------------------------------------------------------------------------- # Handling sub dirs #------------------------------------------------------------------------------- -include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) add_subdirectory(include) add_subdirectory(lib) diff --git a/compiler/concrete-optimizer b/compiler/concrete-optimizer index b446d3124..bc52e3cd2 160000 --- a/compiler/concrete-optimizer +++ b/compiler/concrete-optimizer @@ -1 +1 @@ -Subproject commit b446d3124d89e0e5783df947770ecec19e7a6582 +Subproject commit bc52e3cd2185ff20d2315a496f2d043ae7a02fa7 diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt b/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt index 3bde2e875..341e51e21 100644 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt +++ b/compiler/include/concretelang/Dialect/FHE/Analysis/CMakeLists.txt @@ -4,10 +4,3 @@ mlir_tablegen(MANP.capi.h.inc -gen-pass-capi-header --prefix Analysis) mlir_tablegen(MANP.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) add_public_tablegen_target(MANPPassIncGen) add_dependencies(mlir-headers MANPPassIncGen) - -set(LLVM_TARGET_DEFINITIONS ConcreteOptimizer.td) -mlir_tablegen(ConcreteOptimizer.h.inc -gen-pass-decls -name Analysis) -mlir_tablegen(ConcreteOptimizer.capi.h.inc -gen-pass-capi-header --prefix Analysis) -mlir_tablegen(ConcreteOptimizer.capi.cpp.inc -gen-pass-capi-impl --prefix Analysis) -add_public_tablegen_target(ConcreteOptimizerPassIncGen) -add_dependencies(mlir-headers ConcreteOptimizerPassIncGen) diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h deleted file mode 100644 index 0ad7746c4..000000000 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h +++ /dev/null @@ -1,29 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H -#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER_H - -#include -#include - -#include "concrete-optimizer.hpp" - -#include "concretelang/Support/V0Parameters.h" - -namespace mlir { -namespace concretelang { - -namespace optimizer { -using FunctionsDag = std::map>; - -std::unique_ptr createDagPass(optimizer::Config config, - optimizer::FunctionsDag &dags); - -} // namespace optimizer -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td b/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td deleted file mode 100644 index 7e00b7308..000000000 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.td +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER -#define CONCRETELANG_DIALECT_FHE_ANALYSIS_CONCRETE_OPTIMIZER - -include "mlir/Pass/PassBase.td" - -def ConcreteOptimizer : Pass<"ConcreteOptmizer", "::mlir::func::FuncOp"> { - let summary = "Call concrete-optimizer"; - let description = [{ - The pass calls the concrete-optimizer to provide crypto parameter. - It construct a simplified representation of the FHE circuit and send it to the concrete optimizer. - It uses on the values from the MANP pass to indicate how noise is propagate and amplified in levelled operations. - }]; -} - -#endif diff --git a/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h b/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h deleted file mode 100644 index 0da6574c2..000000000 --- a/compiler/include/concretelang/Dialect/FHE/Analysis/utils.h +++ /dev/null @@ -1,24 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#ifndef CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H -#define CONCRETELANG_DIALECT_FHE_ANALYSIS_UTILS_H - -#include - -namespace mlir { -namespace concretelang { -namespace fhe { -namespace utils { - -bool isEncryptedValue(mlir::Value value); -unsigned int getEintPrecision(mlir::Value value); - -} // namespace utils -} // namespace fhe -} // namespace concretelang -} // namespace mlir - -#endif diff --git a/compiler/include/concretelang/Support/CompilerEngine.h b/compiler/include/concretelang/Support/CompilerEngine.h index 3b753a63e..7164641b6 100644 --- a/compiler/include/concretelang/Support/CompilerEngine.h +++ b/compiler/include/concretelang/Support/CompilerEngine.h @@ -243,8 +243,8 @@ protected: std::shared_ptr compilationContext; private: - llvm::Expected> - getConcreteOptimizerDescription(CompilationResult &res); + llvm::Expected> + getV0FHEConstraint(CompilationResult &res); llvm::Error determineFHEParameters(CompilationResult &res); }; diff --git a/compiler/include/concretelang/Support/Pipeline.h b/compiler/include/concretelang/Support/Pipeline.h index a888e5ca5..e87f89fcc 100644 --- a/compiler/include/concretelang/Support/Pipeline.h +++ b/compiler/include/concretelang/Support/Pipeline.h @@ -20,10 +20,9 @@ namespace pipeline { mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, std::function enablePass); -llvm::Expected>> -getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - optimizer::Config config, - std::function enablePass); +llvm::Expected> +getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass); mlir::LogicalResult tileMarkedFHELinalg(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/include/concretelang/Support/V0Parameters.h b/compiler/include/concretelang/Support/V0Parameters.h index 865411574..b7273077e 100644 --- a/compiler/include/concretelang/Support/V0Parameters.h +++ b/compiler/include/concretelang/Support/V0Parameters.h @@ -8,7 +8,6 @@ #include "llvm/ADT/Optional.h" -#include "concrete-optimizer.hpp" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" namespace mlir { @@ -16,30 +15,16 @@ namespace concretelang { namespace optimizer { constexpr double P_ERROR_4_SIGMA = 1.0 - 0.999936657516; -constexpr uint DEFAULT_SECURITY = 128; - struct Config { double p_error; bool display; - bool strategy_v0; - std::uint64_t security; }; -constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false, false, - DEFAULT_SECURITY}; - -using Dag = rust::Box; -using Solution = concrete_optimizer::v0::Solution; - -/* Contains any circuit description usable by the concrete-optimizer */ -struct Description { - V0FHEConstraint constraint; - llvm::Optional dag; -}; - +constexpr Config DEFAULT_CONFIG = {P_ERROR_4_SIGMA, false}; } // namespace optimizer -llvm::Optional getParameter(optimizer::Description &descr, - optimizer::Config optimizerConfig); +llvm::Optional getV0Parameter(V0FHEConstraint constraint, + optimizer::Config optimizerConfig); + } // namespace concretelang } // namespace mlir #endif diff --git a/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt b/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt index 62de96c07..057c086d4 100644 --- a/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt +++ b/compiler/lib/Dialect/FHE/Analysis/CMakeLists.txt @@ -1,6 +1,4 @@ add_mlir_library(FHEDialectAnalysis - utils.cpp - ConcreteOptimizer.cpp MANP.cpp ADDITIONAL_HEADER_DIRS diff --git a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp b/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp deleted file mode 100644 index dcf5a616a..000000000 --- a/compiler/lib/Dialect/FHE/Analysis/ConcreteOptimizer.cpp +++ /dev/null @@ -1,328 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include -#include -#include - -#include "boost/outcome.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Pass/PassManager.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/Pass.h" -#include "llvm/Support/raw_ostream.h" - -#include "concrete-optimizer.hpp" - -#include "concretelang/Common/Error.h" -#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h" -#include "concretelang/Dialect/FHE/Analysis/utils.h" -#include "concretelang/Dialect/FHE/IR/FHEOps.h" -#include "concretelang/Dialect/FHE/IR/FHETypes.h" -#include "concretelang/Dialect/FHELinalg/IR/FHELinalgOps.h" -#include "concretelang/Support/V0Parameters.h" -#include "concretelang/Support/logging.h" - -#define GEN_PASS_CLASSES -#include "concretelang/Dialect/FHE/Analysis/ConcreteOptimizer.h.inc" - -namespace mlir { -namespace concretelang { -namespace optimizer { - -namespace { - -template rust::Slice slice(const std::vector &vec) { - return rust::Slice(vec.data(), vec.size()); -} - -template rust::Slice slice(const llvm::ArrayRef &vec) { - return rust::Slice(vec.data(), vec.size()); -} - -struct FunctionToDag { - // Inputs of operators - using Inputs = std::vector; - - const double NEGLIGIBLE_COMPLEXITY = 0.0; - - mlir::func::FuncOp func; - optimizer::Config config; - llvm::DenseMap index; - - FunctionToDag(mlir::func::FuncOp func, optimizer::Config config) - : func(func), config(config) {} - -#define DEBUG(MSG) \ - if (mlir::concretelang::isVerbose()) { \ - mlir::concretelang::log_verbose() << MSG << "\n"; \ - } - - outcome::checked, - ::concretelang::error::StringError> - build() { - auto dag = concrete_optimizer::dag::empty(); - // Converting arguments as Input - for (auto &arg : func.getArguments()) { - addArg(dag, arg); - } - // Converting ops - for (auto &bb : func.getBody().getBlocks()) { - for (auto &op : bb.getOperations()) { - addOperation(dag, op); - } - } - if (index.empty()) { - // Dag is empty <=> classical function without encryption - DEBUG("!!! concrete-optimizer: nothing to do in " << func.getName() - << "\n"); - return llvm::None; - }; - DEBUG(std::string(dag->dump())); - return std::move(dag); - } - - void addArg(optimizer::Dag &dag, mlir::Value &arg) { - DEBUG("Arg " << arg << " " << arg.getType()); - if (!fhe::utils::isEncryptedValue(arg)) { - return; - } - auto precision = fhe::utils::getEintPrecision(arg); - auto shape = getShape(arg); - auto opI = dag->add_input(precision, slice(shape)); - index[arg] = opI; - } - - bool hasEncryptedResult(mlir::Operation &op) { - for (auto val : op.getResults()) { - if (fhe::utils::isEncryptedValue(val)) { - return true; - } - } - return false; - } - - void addOperation(optimizer::Dag &dag, mlir::Operation &op) { - DEBUG("Instr " << op); - - if (isReturn(op)) { - // This op has no result - return; - } - - auto encrypted_inputs = encryptedInputs(op); - if (!hasEncryptedResult(op)) { - // This op is unrelated to FHE - assert(encrypted_inputs.empty()); - return; - } - assert(op.getNumResults() == 1); - auto val = op.getResult(0); - auto precision = fhe::utils::getEintPrecision(val); - if (isLut(op)) { - addLut(dag, val, encrypted_inputs, precision); - return; - } - if (auto dot = asDot(op)) { - auto weightsOpt = dotWeights(dot); - if (weightsOpt) { - addDot(dag, val, encrypted_inputs, weightsOpt.getValue()); - return; - } - DEBUG("Replace Dot by LevelledOp on " << op); - } - // default - addLevelledOp(dag, op, encrypted_inputs); - } - - void addLut(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, - int precision) { - assert(encrypted_inputs.size() == 1); - // No need to distinguish different lut kind until we do approximate - // paradigm on outputs - auto encrypted_input = encrypted_inputs[0]; - std::vector unknowFunction; - index[val] = - dag->add_lut(encrypted_input, slice(unknowFunction), precision); - } - - void addDot(optimizer::Dag &dag, mlir::Value &val, Inputs &encrypted_inputs, - std::vector &weights_vector) { - assert(encrypted_inputs.size() == 1); - auto weights = concrete_optimizer::weights::vector(slice(weights_vector)); - index[val] = dag->add_dot(slice(encrypted_inputs), std::move(weights)); - } - - std::string loc_to_string(mlir::Location location) { - std::string loc; - llvm::raw_string_ostream loc_stream(loc); - location.print(loc_stream); - return loc; - } - - void addLevelledOp(optimizer::Dag &dag, mlir::Operation &op, Inputs &inputs) { - auto val = op.getResult(0); - auto out_shape = getShape(val); - if (inputs.empty()) { - // Trivial encrypted constants encoding - // There are converted to input + levelledop - auto precision = fhe::utils::getEintPrecision(val); - auto opI = dag->add_input(precision, slice(out_shape)); - inputs.push_back(opI); - } - // Default complexity is negligible - double fixed_cost = NEGLIGIBLE_COMPLEXITY; - double lwe_dim_cost_factor = NEGLIGIBLE_COMPLEXITY; - auto manp_int = op.getAttrOfType("MANP"); - auto loc = loc_to_string(op.getLoc()); - if (!manp_int) { - DEBUG("Cannot read manp on " << op << "\n" << loc); - } - assert(manp_int && "Missing manp value on a crypto operation"); - double manp = (double)manp_int.getValue().getZExtValue(); - auto comment = std::string(op.getName().getStringRef()) + " " + loc; - index[val] = - dag->add_levelled_op(slice(inputs), lwe_dim_cost_factor, fixed_cost, - manp, slice(out_shape), comment); - } - - Inputs encryptedInputs(mlir::Operation &op) { - Inputs inputs; - for (auto operand : op.getOperands()) { - auto entry = index.find(operand); - if (entry == index.end()) { - assert(!fhe::utils::isEncryptedValue(operand)); - DEBUG("Ignoring as input " << operand); - continue; - } - inputs.push_back(entry->getSecond()); - } - return inputs; - } - - bool isLut(mlir::Operation &op) { - return llvm::isa< - mlir::concretelang::FHE::ApplyLookupTableEintOp, - mlir::concretelang::FHELinalg::ApplyLookupTableEintOp, - mlir::concretelang::FHELinalg::ApplyMultiLookupTableEintOp, - mlir::concretelang::FHELinalg::ApplyMappedLookupTableEintOp>(op); - } - - mlir::concretelang::FHELinalg::Dot asDot(mlir::Operation &op) { - return llvm::dyn_cast(op); - } - - bool isReturn(mlir::Operation &op) { - return llvm::isa(op); - } - - bool isConst(mlir::Operation &op) { - return llvm::isa(op); - } - - bool isArg(const mlir::Value &value) { - return value.isa(); - } - - std::vector - resolveConstantVectorWeights(mlir::arith::ConstantOp &cstOp) { - std::vector values; - mlir::DenseIntElementsAttr denseVals = - cstOp->getAttrOfType("value"); - - for (llvm::APInt val : denseVals.getValues()) { - values.push_back(val.getZExtValue()); - } - return values; - } - - llvm::Optional> - resolveConstantWeights(mlir::Value &value) { - if (auto cstOp = llvm::dyn_cast_or_null( - value.getDefiningOp())) { - auto shape = getShape(value); - switch (shape.size()) { - case 1: - return resolveConstantVectorWeights(cstOp); - default: - DEBUG("High-Rank tensor: rely on MANP and levelledOp"); - return llvm::None; - } - } else { - DEBUG("Dynamic Weights: rely on MANP and levelledOp"); - return llvm::None; - } - } - - llvm::Optional> - dotWeights(mlir::concretelang::FHELinalg::Dot &dot) { - if (dot.getOperands().size() != 2) { - return llvm::None; - } - auto weights = dot.getOperands()[1]; - return resolveConstantWeights(weights); - } - - std::vector getShape(mlir::Value &value) { - return getShape(value.getType()); - } - - std::vector getShape(mlir::Type type_) { - if (auto ranked_tensor = type_.dyn_cast_or_null()) { - std::vector shape; - for (auto v : ranked_tensor.getShape()) { - shape.push_back(v); - } - return shape; - } else { - return {}; - } - } -}; - -} // namespace - -struct DagPass : ConcreteOptimizerBase { - optimizer::Config config; - optimizer::FunctionsDag &dags; - - void runOnOperation() override { - mlir::func::FuncOp func = getOperation(); - auto name = std::string(func.getName()); - DEBUG("ConcreteOptimizer Dag: " << name); - if (config.strategy_v0) { - // we avoid building the dag since it's not used in this case - // so strategy_v0 can be used to avoid dag creation issues - dags.insert(optimizer::FunctionsDag::value_type(name, llvm::None)); - } - auto dag = FunctionToDag(func, config).build(); - if (dag) { - dags.insert( - optimizer::FunctionsDag::value_type(name, std::move(dag.value()))); - } else { - this->signalPassFailure(); - } - } - - DagPass() = delete; - DagPass(optimizer::Config config, optimizer::FunctionsDag &dags) - : config(config), dags(dags) {} -}; - -// Create an instance of the ConcreteOptimizerPass pass. -// A global pass result is communicated using `dags`. -// If `debug` is true, for each operation, the pass emits a -// remark containing the squared Minimal Arithmetic Noise Padding of -// the equivalent dot operation. -std::unique_ptr createDagPass(optimizer::Config config, - optimizer::FunctionsDag &dags) { - return std::make_unique(config, dags); -} - -} // namespace optimizer -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp index 6edbc362c..39bce1179 100644 --- a/compiler/lib/Dialect/FHE/Analysis/MANP.cpp +++ b/compiler/lib/Dialect/FHE/Analysis/MANP.cpp @@ -4,7 +4,6 @@ // for license information. #include -#include #include #include #include @@ -46,7 +45,36 @@ static bool isEncryptedFunctionParameter(mlir::Value value) { return false; } - return mlir::concretelang::fhe::utils::isEncryptedValue(value); + return ( + value.getType().isa() || + (value.getType().isa() && + value.getType() + .cast() + .getElementType() + .isa())); +} + +/// Returns the bit width of `value` if `value` is an encrypted integer +/// or the bit width of the elements if `value` is a tensor of +/// encrypted integers. +static unsigned int getEintPrecision(mlir::Value value) { + if (auto ty = value.getType() + .dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedIntegerType>()) { + return ty.getWidth(); + } else if (auto tensorTy = + value.getType().dyn_cast_or_null()) { + if (auto ty = tensorTy.getElementType() + .dyn_cast_or_null< + mlir::concretelang::FHE::EncryptedIntegerType>()) + return ty.getWidth(); + } + + assert(false && + "Value is neither an encrypted integer nor a tensor of encrypted " + "integers"); + + return 0; } /// The `MANPLatticeValue` represents the squared Minimal Arithmetic @@ -1496,7 +1524,7 @@ private: } // namespace namespace { -// For documentation see MANP.td +/// For documentation see MANP.td struct MANPPass : public MANPBase { void runOnOperation() override { mlir::func::FuncOp func = getOperation(); @@ -1545,7 +1573,7 @@ protected: llvm::dyn_cast_or_null(op)) { for (mlir::BlockArgument blockArg : func.getBody().getArguments()) { if (isEncryptedFunctionParameter(blockArg)) { - unsigned int width = fhe::utils::getEintPrecision(blockArg); + unsigned int width = getEintPrecision(blockArg); if (this->maxEintWidth < width) { this->maxEintWidth = width; diff --git a/compiler/lib/Dialect/FHE/Analysis/utils.cpp b/compiler/lib/Dialect/FHE/Analysis/utils.cpp deleted file mode 100644 index d0d3b270c..000000000 --- a/compiler/lib/Dialect/FHE/Analysis/utils.cpp +++ /dev/null @@ -1,51 +0,0 @@ -// Part of the Concrete Compiler Project, under the BSD3 License with Zama -// Exceptions. See -// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt -// for license information. - -#include "concretelang/Dialect/FHE/IR/FHETypes.h" -#include - -namespace mlir { -namespace concretelang { -namespace fhe { -namespace utils { -/// Returns `true` if the given value is a scalar or tensor argument of -/// a function, for which a MANP of 1 can be assumed. -bool isEncryptedValue(mlir::Value value) { - return ( - value.getType().isa() || - (value.getType().isa() && - value.getType() - .cast() - .getElementType() - .isa())); -} - -/// Returns the bit width of `value` if `value` is an encrypted integer -/// or the bit width of the elements if `value` is a tensor of -/// encrypted integers. -unsigned int getEintPrecision(mlir::Value value) { - if (auto ty = value.getType() - .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) { - return ty.getWidth(); - } else if (auto tensorTy = - value.getType().dyn_cast_or_null()) { - if (auto ty = tensorTy.getElementType() - .dyn_cast_or_null< - mlir::concretelang::FHE::EncryptedIntegerType>()) - return ty.getWidth(); - } - - assert(false && - "Value is neither an encrypted integer nor a tensor of encrypted " - "integers"); - - return 0; -} - -} // namespace utils -} // namespace fhe -} // namespace concretelang -} // namespace mlir diff --git a/compiler/lib/Support/CMakeLists.txt b/compiler/lib/Support/CMakeLists.txt index 9cca7dbfc..836f89d9b 100644 --- a/compiler/lib/Support/CMakeLists.txt +++ b/compiler/lib/Support/CMakeLists.txt @@ -1,3 +1,6 @@ +# not working in ADDITIONAL_HEADER_DIRS +include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) + add_mlir_library(ConcretelangSupport Pipeline.cpp Jit.cpp diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 188bfcca5..489ac3223 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -21,7 +21,6 @@ #include #include -#include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include #include #include @@ -113,81 +112,63 @@ void CompilerEngine::setEnablePass( this->enablePass = enablePass; } -/// Returns the optimizer::Description -llvm::Expected> -CompilerEngine::getConcreteOptimizerDescription(CompilationResult &res) { +/// Returns the overwritten V0FHEConstraint or try to compute them from FHE +llvm::Expected> +CompilerEngine::getV0FHEConstraint(CompilationResult &res) { mlir::MLIRContext &mlirContext = *this->compilationContext->getMLIRContext(); mlir::ModuleOp module = res.mlirModuleRef->get(); // If the values has been overwritten returns if (this->overrideMaxEintPrecision.hasValue() && this->overrideMaxMANP.hasValue()) { - auto constraint = mlir::concretelang::V0FHEConstraint{ + return mlir::concretelang::V0FHEConstraint{ this->overrideMaxMANP.getValue(), this->overrideMaxEintPrecision.getValue()}; - return optimizer::Description{constraint, llvm::None}; } - auto config = this->compilerOptions.optimizerConfig; - auto descriptions = mlir::concretelang::pipeline::getFHEContextFromFHE( - mlirContext, module, config, enablePass); - if (auto err = descriptions.takeError()) { + // Else compute constraint from FHE + llvm::Expected> + fheConstraintsOrErr = + mlir::concretelang::pipeline::getFHEConstraintsFromFHE( + mlirContext, module, enablePass); + + if (auto err = fheConstraintsOrErr.takeError()) return std::move(err); - } - if (descriptions->empty()) { // The pass has not been run - return llvm::None; - } - if (this->compilerOptions.clientParametersFuncName.hasValue()) { - auto name = this->compilerOptions.clientParametersFuncName.getValue(); - auto description = descriptions->find(name); - if (description == descriptions->end()) { - std::string names; - for (auto &entry : *descriptions) { - names += "'" + entry.first + "' "; - } - return StreamStringError() - << "Could not find existing crypto parameters for function '" - << name << "' (known functions: " << names << ")"; - } - return std::move(description->second); - } - if (descriptions->size() != 1) { - llvm::errs() << "Several crypto parameters exists: the function need to be " - "specified, taking the first one"; - } - return std::move(descriptions->begin()->second); + + return fheConstraintsOrErr.get(); } /// set the fheContext field if the v0Constraint can be computed llvm::Error CompilerEngine::determineFHEParameters(CompilationResult &res) { - auto descrOrErr = getConcreteOptimizerDescription(res); - if (auto err = descrOrErr.takeError()) { + auto fheConstraintOrErr = getV0FHEConstraint(res); + if (auto err = fheConstraintOrErr.takeError()) return err; - } - // The function is non-crypto and without constraint override - if (!descrOrErr.get().hasValue()) { + if (!fheConstraintOrErr.get().hasValue()) { return llvm::Error::success(); } - auto descr = std::move(descrOrErr.get().getValue()); - auto config = this->compilerOptions.optimizerConfig; + llvm::Optional v0Params; + if (compilerOptions.v0Parameter.hasValue()) { + v0Params = compilerOptions.v0Parameter; + } else { + v0Params = getV0Parameter(fheConstraintOrErr.get().getValue(), + this->compilerOptions.optimizerConfig); - auto fheParams = (compilerOptions.v0Parameter.hasValue()) - ? compilerOptions.v0Parameter - : getParameter(descr, config); - if (!fheParams) { - return StreamStringError() - << "Could not determine V0 parameters for 2-norm of " - << (*descrOrErr)->constraint.norm2 << " and p of " - << (*descrOrErr)->constraint.p; + if (!v0Params) { + return StreamStringError() + << "Could not determine V0 parameters for 2-norm of " + << (*fheConstraintOrErr)->norm2 << " and p of " + << (*fheConstraintOrErr)->p; + } } - res.fheContext.emplace( - mlir::concretelang::V0FHEContext{descr.constraint, fheParams.getValue()}); + res.fheContext.emplace(mlir::concretelang::V0FHEContext{ + (*fheConstraintOrErr).getValue(), v0Params.getValue()}); + return llvm::Error::success(); } using OptionalLib = llvm::Optional>; -// Compile the sources managed by the source manager `sm` to the -// target dialect `target`. If successful, the result can be retrieved -// using `getModule()` and `getLLVMModule()`, respectively depending -// on the target dialect. +/// Compile the sources managed by the source manager `sm` to the +/// target dialect `target`. If successful, the result can be retrieved +/// using `getModule()` and `getLLVMModule()`, respectively depending +/// on the target dialect. llvm::Expected CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { std::unique_ptr smHandler; @@ -300,8 +281,7 @@ CompilerEngine::compile(llvm::SourceMgr &sm, Target target, OptionalLib lib) { } if (!res.fheContext.hasValue()) { return StreamStringError( - "Cannot generate client parameters, the fhe context is empty for " + - options.clientParametersFuncName.getValue()); + "Cannot generate client parameters, the fhe context is empty"); } } // Generate client parameters if requested diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 8bc7fff62..3e86dd55f 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -26,12 +26,9 @@ #include #include -#include "concretelang/Support/CompilerEngine.h" -#include "concretelang/Support/Error.h" #include #include #include -#include #include #include #include @@ -76,13 +73,11 @@ addPotentiallyNestedPass(mlir::PassManager &pm, std::unique_ptr pass, } } -llvm::Expected>> -getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, - optimizer::Config config, - std::function enablePass) { +llvm::Expected> +getFHEConstraintsFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, + std::function enablePass) { llvm::Optional oMax2norm; llvm::Optional oMaxWidth; - optimizer::FunctionsDag dags; mlir::PassManager pm(&context); @@ -114,36 +109,18 @@ getFHEContextFromFHE(mlir::MLIRContext &context, mlir::ModuleOp &module, if (pm.run(module.getOperation()).failed()) { return llvm::make_error( "Failed to determine the maximum Arithmetic Noise Padding and maximum" - " required precision", + "required precision", llvm::inconvertibleErrorCode()); } - llvm::Optional constraint = llvm::None; + llvm::Optional ret; if (oMax2norm.hasValue() && oMaxWidth.hasValue()) { - constraint = llvm::Optional( + ret = llvm::Optional( {/*.norm2 = */ ceilLog2(oMax2norm.getValue()), /*.p = */ oMaxWidth.getValue()}); } - addPotentiallyNestedPass(pm, optimizer::createDagPass(config, dags), - enablePass); - if (pm.run(module.getOperation()).failed()) { - return StreamStringError() << "Failed to create concrete-optimizer dag\n"; - } - std::map> descriptions; - for (auto &entry_dag : dags) { - if (!constraint) { - descriptions.insert( - decltype(descriptions)::value_type(entry_dag.first, llvm::None)); - continue; - } - optimizer::Description description = {*constraint, - std::move(entry_dag.second)}; - llvm::Optional opt_description{ - std::move(description)}; - descriptions.insert(decltype(descriptions)::value_type( - entry_dag.first, std::move(opt_description))); - } - return std::move(descriptions); + + return ret; } mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, diff --git a/compiler/lib/Support/V0ClientParameters.cpp b/compiler/lib/Support/V0ClientParameters.cpp index ff9145f9e..96c29e0eb 100644 --- a/compiler/lib/Support/V0ClientParameters.cpp +++ b/compiler/lib/Support/V0ClientParameters.cpp @@ -138,8 +138,7 @@ createClientParametersForV0(V0FHEContext fheContext, }); if (funcOp == rangeOps.end()) { return llvm::make_error( - "cannot find the function for generate client parameters '" + - functionName + "'", + "cannot find the function for generate client parameters", llvm::inconvertibleErrorCode()); } diff --git a/compiler/lib/Support/V0Parameters.cpp b/compiler/lib/Support/V0Parameters.cpp index 92bf183ab..844f9f886 100644 --- a/compiler/lib/Support/V0Parameters.cpp +++ b/compiler/lib/Support/V0Parameters.cpp @@ -17,30 +17,15 @@ #include "concrete-optimizer.hpp" #include "concretelang/Support/V0Parameters.h" -#include namespace mlir { namespace concretelang { -optimizer::Solution getV0Parameter(V0FHEConstraint constraint, - optimizer::Config config) { - // the norm2 0 is equivalent to a maximum noise_factor of 2.0 - // norm2 = 0 ==> 1.0 =< noise_factor < 2.0 - // norm2 = k ==> 2^norm2 =< noise_factor < 2.0^norm2 + 1 - double noise_factor = std::exp2(constraint.norm2 + 1); - return concrete_optimizer::v0::optimize_bootstrap( - constraint.p, config.security, noise_factor, config.p_error); -} - -optimizer::Solution getV1Parameter(optimizer::Dag &dag, - optimizer::Config config) { - return dag->optimize_v0(config.security, config.p_error); -} - static void display(V0FHEConstraint constraint, - optimizer::Config optimizerConfig, optimizer::Solution sol, + optimizer::Config optimizerConfig, + concrete_optimizer::v0::Solution sol, std::chrono::milliseconds duration) { - if (!optimizerConfig.display && !mlir::concretelang::isVerbose()) { + if (!optimizerConfig.display) { return; } auto o = llvm::outs; @@ -69,15 +54,19 @@ static void display(V0FHEConstraint constraint, << "---\n"; } -llvm::Optional getParameter(optimizer::Description &descr, - optimizer::Config config) { +llvm::Optional getV0Parameter(V0FHEConstraint constraint, + optimizer::Config optimizerConfig) { namespace chrono = std::chrono; + int security = 128; + // the norm2 0 is equivalent to a maximum noise_factor of 2.0 + // norm2 = 0 ==> 1.0 =< noise_factor < 2.0 + // norm2 = k ==> 2^norm2 =< noise_factor < 2.0^norm2 + 1 + double noise_factor = std::exp2(constraint.norm2 + 1); + // https://github.com/zama-ai/concrete-optimizer/blob/prototype/python/optimizer/V0Parameters/tabulation.py#L58 + double p_error = optimizerConfig.p_error; auto start = chrono::high_resolution_clock::now(); - - auto sol = (!descr.dag || config.strategy_v0) - ? getV0Parameter(descr.constraint, config) - : getV1Parameter(descr.dag.getValue(), config); - + auto sol = concrete_optimizer::v0::optimize_bootstrap(constraint.p, security, + noise_factor, p_error); auto stop = chrono::high_resolution_clock::now(); if (sol.p_error == 1.0) { // The optimizer return a p_error = 1 if there is no solution @@ -89,7 +78,7 @@ llvm::Optional getParameter(optimizer::Description &descr, llvm::errs() << "concrete-optimizer time: " << duration_s.count() << "s\n"; } - display(descr.constraint, config, sol, duration); + display(constraint, optimizerConfig, sol, duration); return mlir::concretelang::V0Parameter{ sol.glwe_dimension, diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index d49e823f9..5c8a3ebfa 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -186,11 +186,6 @@ llvm::cl::opt displayOptimizerChoice( llvm::cl::desc("Display the information returned by the optimizer"), llvm::cl::init(false)); -llvm::cl::opt - optimizerV0("optimizer-v0", - llvm::cl::desc("Select the v0 parameters strategy"), - llvm::cl::init(false)); - llvm::cl::list fhelinalgTileSizes( "fhelinalg-tile-sizes", llvm::cl::desc( @@ -270,15 +265,8 @@ cmdlineCompilationOptions() { cmdline::v0Parameter[6]); } - if (!cmdline::v0Constraint.empty() && !cmdline::optimizerV0) { - return llvm::make_error( - "You must use --v0-constraint with --optimizer-v0-strategy", - llvm::inconvertibleErrorCode()); - } - options.optimizerConfig.p_error = cmdline::pbsErrorProbability; options.optimizerConfig.display = cmdline::displayOptimizerChoice; - options.optimizerConfig.strategy_v0 = cmdline::optimizerV0; return options; } diff --git a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir index 27f048e5b..c1be85007 100644 --- a/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir +++ b/compiler/tests/check_tests/Conversion/TFHEGlobalParametrization/pbs_ks_bs.mlir @@ -1,4 +1,4 @@ -// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --optimizer-v0 --v0-parameter=2,10,750,1,23,3,4 --v0-constraint=4,0 %s 2>&1| FileCheck %s +// RUN: concretecompiler --passes tfhe-global-parametrization --action=dump-std --v0-parameter=2,10,750,1,23,3,4 -v0-constraint=4,0 %s 2>&1| FileCheck %s //CHECK: func @main(%[[A0:.*]]: !TFHE.glwe<{2048,1,64}{4}>) -> !TFHE.glwe<{2048,1,64}{4}> { //CHECK: %cst = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : tensor<16xi64> diff --git a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir index 83bf18f8b..4470c33c4 100644 --- a/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir +++ b/compiler/tests/check_tests/Dialect/FHE/eint_error_p_too_big.mlir @@ -1,7 +1,6 @@ // RUN: not concretecompiler --action=dump-llvm-ir %s 2>&1| FileCheck %s // CHECK-LABEL: Could not determine V0 parameters -func @test(%arg0: !FHE.eint<9>, %arg1: tensor<512xi64>) { - %1 = "FHE.apply_lookup_table"(%arg0, %arg1): (!FHE.eint<9>, tensor<512xi64>) -> (!FHE.eint<9>) +func @test(%arg0: !FHE.eint<9>) { return } diff --git a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp index 9fc6baba4..c3bc187b8 100644 --- a/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp +++ b/compiler/tests/end_to_end_fixture/EndToEndFixture.cpp @@ -165,7 +165,7 @@ template <> struct llvm::yaml::MappingTraits { } }; -LLVM_YAML_IS_SEQUENCE_VECTOR(ValueDescription) +LLVM_YAML_IS_SEQUENCE_VECTOR(ValueDescription); template <> struct llvm::yaml::MappingTraits { static void mapping(IO &io, TestDescription &desc) { @@ -174,7 +174,7 @@ template <> struct llvm::yaml::MappingTraits { } }; -LLVM_YAML_IS_SEQUENCE_VECTOR(TestDescription) +LLVM_YAML_IS_SEQUENCE_VECTOR(TestDescription); template <> struct llvm::yaml::MappingTraits { static void mapping(IO &io, EndToEndDesc &desc) { diff --git a/compiler/tests/end_to_end_tests/CMakeLists.txt b/compiler/tests/end_to_end_tests/CMakeLists.txt index 82e5b8b9e..6922a00d0 100644 --- a/compiler/tests/end_to_end_tests/CMakeLists.txt +++ b/compiler/tests/end_to_end_tests/CMakeLists.txt @@ -7,7 +7,6 @@ function(add_concretecompiler_unittest test_name) endfunction() include_directories(${PROJECT_SOURCE_DIR}/include) -include_directories(${CONCRETE_OPTIMIZER_DIR}/concrete-optimizer-cpp/src/cpp) if(NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") link_libraries( diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h index 00d84b295..c02f3656e 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_test.h +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_test.h @@ -23,10 +23,8 @@ internalCheckedJit(llvm::StringRef src, llvm::StringRef func = "main", auto options = mlir::concretelang::CompilationOptions(std::string(func.data())); - if (useDefaultFHEConstraints) { + if (useDefaultFHEConstraints) options.v0FHEConstraints = defaultV0Constraints; - options.optimizerConfig.strategy_v0 = true; - } // Allow loop parallelism in all cases options.loopParallelize = loopParallelize; diff --git a/compiler/tests/python/test_compilation.py b/compiler/tests/python/test_compilation.py index 29cb6ea81..833afcdf4 100644 --- a/compiler/tests/python/test_compilation.py +++ b/compiler/tests/python/test_compilation.py @@ -350,6 +350,6 @@ def test_compile_and_run_invalid_arg_number( def test_compile_invalid(mlir_input): engine = JITSupport.new() with pytest.raises( - RuntimeError, match=r"Could not find existing crypto parameters for" + RuntimeError, match=r"cannot find the function for generate client parameters" ): engine.compile(mlir_input)