cleanup(compiler): Remove references to JitCompilerEngine to the main

This commit is contained in:
Quentin Bourgerie
2022-03-17 15:51:26 +01:00
parent 1620259807
commit 3fccc98e68

View File

@@ -20,6 +20,7 @@
#include <sstream>
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Conversion/Passes.h"
#include "concretelang/Conversion/Utils/GlobalFHEContext.h"
#include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h"
@@ -30,13 +31,16 @@
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/JitCompilerEngine.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include "concretelang/Support/LLVMEmitFile.h"
#include "concretelang/Support/Pipeline.h"
#include "concretelang/Support/logging.h"
#include "mlir/IR/BuiltinOps.h"
namespace clientlib = concretelang::clientlib;
enum Action {
ROUND_TRIP,
DUMP_FHE,
@@ -153,7 +157,7 @@ llvm::cl::opt<bool> dataflowParallelize(
llvm::cl::opt<std::string>
funcName("funcname",
llvm::cl::desc("Name of the function to compile, default 'main'"),
llvm::cl::init<std::string>("main"));
llvm::cl::init<std::string>(""));
llvm::cl::list<uint64_t>
jitArgs("jit-args",
@@ -262,46 +266,35 @@ mlir::LogicalResult processInputBuffer(
llvm::Optional<size_t> overrideMaxMANP, bool verifyDiagnostics,
llvm::Optional<llvm::ArrayRef<int64_t>> fhelinalgTileSizes,
bool autoParallelize, bool loopParallelize, bool dataflowParallelize,
llvm::Optional<mlir::concretelang::KeySetCache> keySetCache,
llvm::raw_ostream &os,
llvm::Optional<clientlib::KeySetCache> keySetCache, llvm::raw_ostream &os,
std::shared_ptr<mlir::concretelang::CompilerEngine::Library> outputLib) {
std::shared_ptr<mlir::concretelang::CompilationContext> ccx =
mlir::concretelang::CompilationContext::createShared();
mlir::concretelang::JitCompilerEngine ce{ccx};
mlir::concretelang::CompilationOptions options;
ce.setVerifyDiagnostics(verifyDiagnostics);
ce.setAutoParallelize(autoParallelize);
ce.setLoopParallelize(loopParallelize);
ce.setDataflowParallelize(dataflowParallelize);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {
return std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; });
});
}
options.verifyDiagnostics = verifyDiagnostics;
options.autoParallelize = autoParallelize;
options.loopParallelize = loopParallelize;
options.dataflowParallelize = dataflowParallelize;
if (overrideMaxEintPrecision.hasValue())
ce.setMaxEintPrecision(overrideMaxEintPrecision.getValue());
if (overrideMaxEintPrecision.hasValue() && overrideMaxMANP.hasValue())
options.v0FHEConstraints = {
overrideMaxMANP.hasValue(),
overrideMaxEintPrecision.hasValue(),
};
if (overrideMaxMANP.hasValue())
ce.setMaxMANP(overrideMaxMANP.getValue());
if (!funcName.empty())
options.clientParametersFuncName = funcName;
ce.setClientParametersFuncName(funcName);
if (fhelinalgTileSizes.hasValue())
ce.setFHELinalgTileSizes(*fhelinalgTileSizes);
options.fhelinalgTileSizes = *fhelinalgTileSizes;
if (action == Action::JIT_INVOKE) {
llvm::Expected<mlir::concretelang::JitCompilerEngine::Lambda> lambdaOrErr =
ce.buildLambda(std::move(buffer), funcName, keySetCache);
if (!lambdaOrErr) {
mlir::concretelang::log_error()
<< "Failed to JIT-compile " << funcName << ": "
<< llvm::toString(lambdaOrErr.takeError());
return mlir::failure();
}
auto lambdaOrErr =
mlir::concretelang::ClientServer<mlir::concretelang::JitLambdaSupport>::
create(buffer->getBuffer(), options, keySetCache,
mlir::concretelang::JitLambdaSupport());
llvm::Expected<uint64_t> resOrErr = (*lambdaOrErr)(jitArgs);
@@ -314,6 +307,16 @@ mlir::LogicalResult processInputBuffer(
os << *resOrErr << "\n";
} else {
mlir::concretelang::CompilerEngine ce{ccx};
ce.setCompilationOptions(options);
if (cmdline::passes.size() != 0) {
ce.setEnablePass([](mlir::Pass *pass) {
return std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; });
});
}
enum mlir::concretelang::CompilerEngine::Target target;
switch (action) {
@@ -412,10 +415,9 @@ mlir::LogicalResult compilerMain(int argc, char **argv) {
if (!cmdline::fhelinalgTileSizes.empty())
fhelinalgTileSizes.emplace(cmdline::fhelinalgTileSizes);
llvm::Optional<mlir::concretelang::KeySetCache> jitKeySetCache;
llvm::Optional<clientlib::KeySetCache> jitKeySetCache;
if (!cmdline::jitKeySetCachePath.empty()) {
jitKeySetCache =
mlir::concretelang::KeySetCache(cmdline::jitKeySetCachePath);
jitKeySetCache = clientlib::KeySetCache(cmdline::jitKeySetCachePath);
}
// In case of compilation to library, the real output is the library.