From 3fccc98e68bd351a1b01fd2d9fe1ad4b7211b0c7 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Thu, 17 Mar 2022 15:51:26 +0100 Subject: [PATCH] cleanup(compiler): Remove references to JitCompilerEngine to the main --- compiler/src/main.cpp | 70 ++++++++++++++++++++++--------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 301fd038b..b1d5a5ec2 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -20,6 +20,7 @@ #include #include "concretelang/ClientLib/KeySet.h" +#include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/Utils/GlobalFHEContext.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" @@ -30,13 +31,16 @@ #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" #include "concretelang/Runtime/runtime_api.h" +#include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" -#include "concretelang/Support/JitCompilerEngine.h" +#include "concretelang/Support/JitLambdaSupport.h" #include "concretelang/Support/LLVMEmitFile.h" #include "concretelang/Support/Pipeline.h" #include "concretelang/Support/logging.h" #include "mlir/IR/BuiltinOps.h" +namespace clientlib = concretelang::clientlib; + enum Action { ROUND_TRIP, DUMP_FHE, @@ -153,7 +157,7 @@ llvm::cl::opt dataflowParallelize( llvm::cl::opt funcName("funcname", llvm::cl::desc("Name of the function to compile, default 'main'"), - llvm::cl::init("main")); + llvm::cl::init("")); llvm::cl::list jitArgs("jit-args", @@ -262,46 +266,35 @@ mlir::LogicalResult processInputBuffer( llvm::Optional overrideMaxMANP, bool verifyDiagnostics, llvm::Optional> fhelinalgTileSizes, bool autoParallelize, bool loopParallelize, bool dataflowParallelize, - llvm::Optional keySetCache, - llvm::raw_ostream &os, + llvm::Optional keySetCache, llvm::raw_ostream &os, std::shared_ptr outputLib) { std::shared_ptr ccx = mlir::concretelang::CompilationContext::createShared(); - mlir::concretelang::JitCompilerEngine ce{ccx}; + mlir::concretelang::CompilationOptions options; - ce.setVerifyDiagnostics(verifyDiagnostics); - ce.setAutoParallelize(autoParallelize); - ce.setLoopParallelize(loopParallelize); - ce.setDataflowParallelize(dataflowParallelize); - if (cmdline::passes.size() != 0) { - ce.setEnablePass([](mlir::Pass *pass) { - return std::any_of( - cmdline::passes.begin(), cmdline::passes.end(), - [&](const std::string &p) { return pass->getArgument() == p; }); - }); - } + options.verifyDiagnostics = verifyDiagnostics; + options.autoParallelize = autoParallelize; + options.loopParallelize = loopParallelize; + options.dataflowParallelize = dataflowParallelize; - if (overrideMaxEintPrecision.hasValue()) - ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue()); + if (overrideMaxEintPrecision.hasValue() && overrideMaxMANP.hasValue()) + options.v0FHEConstraints = { + overrideMaxMANP.hasValue(), + overrideMaxEintPrecision.hasValue(), + }; - if (overrideMaxMANP.hasValue()) - ce.setMaxMANP(overrideMaxMANP.getValue()); + if (!funcName.empty()) + options.clientParametersFuncName = funcName; - ce.setClientParametersFuncName(funcName); if (fhelinalgTileSizes.hasValue()) - ce.setFHELinalgTileSizes(*fhelinalgTileSizes); + options.fhelinalgTileSizes = *fhelinalgTileSizes; if (action == Action::JIT_INVOKE) { - llvm::Expected lambdaOrErr = - ce.buildLambda(std::move(buffer), funcName, keySetCache); - - if (!lambdaOrErr) { - mlir::concretelang::log_error() - << "Failed to JIT-compile " << funcName << ": " - << llvm::toString(lambdaOrErr.takeError()); - return mlir::failure(); - } + auto lambdaOrErr = + mlir::concretelang::ClientServer:: + create(buffer->getBuffer(), options, keySetCache, + mlir::concretelang::JitLambdaSupport()); llvm::Expected resOrErr = (*lambdaOrErr)(jitArgs); @@ -314,6 +307,16 @@ mlir::LogicalResult processInputBuffer( os << *resOrErr << "\n"; } else { + mlir::concretelang::CompilerEngine ce{ccx}; + ce.setCompilationOptions(options); + + if (cmdline::passes.size() != 0) { + ce.setEnablePass([](mlir::Pass *pass) { + return std::any_of( + cmdline::passes.begin(), cmdline::passes.end(), + [&](const std::string &p) { return pass->getArgument() == p; }); + }); + } enum mlir::concretelang::CompilerEngine::Target target; switch (action) { @@ -412,10 +415,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) { if (!cmdline::fhelinalgTileSizes.empty()) fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes); - llvm::Optional jitKeySetCache; + llvm::Optional jitKeySetCache; if (!cmdline::jitKeySetCachePath.empty()) { - jitKeySetCache = - mlir::concretelang::KeySetCache(cmdline::jitKeySetCachePath); + jitKeySetCache = clientlib::KeySetCache(cmdline::jitKeySetCachePath); } // In case of compilation to library, the real output is the library.