refactor(compiler): Move JIT functionality to separate source file

This commit is contained in:
Andi Drebes
2021-09-13 15:18:22 +02:00
committed by Quentin Bourgerie
parent b9e2690823
commit 6a76177a47
4 changed files with 84 additions and 44 deletions

View File

@@ -21,6 +21,7 @@
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
#include "zamalang/Support/logging.h"
#include "zamalang/Support/Jit.h"
namespace cmdline {
@@ -79,7 +80,8 @@ llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
std::function<llvm::Error(llvm::Module *)> defaultOptPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
@@ -92,48 +94,6 @@ mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module,
mlir::zamalang::KeySet &keySet,
llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda = mlir::zamalang::JITLambda::create(
cmdline::jitFuncname, module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = std::move(maybeLambda.get());
// Create the arguments of the JIT lambda
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
if (auto err = maybeArguments.takeError()) {
mlir::zamalang::log_error()
<< "Cannot create lambda arguments: " << err << "\n";
return mlir::failure();
}
// Set the arguments
auto arguments = std::move(maybeArguments.get());
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
if (auto err = arguments->setArg(i, cmdline::jitArgs[i])) {
mlir::zamalang::log_error()
<< "Cannot push argument " << i << ": " << err << "\n";
return mlir::failure();
}
}
// Invoke the lambda
if (auto err = lambda->invoke(*arguments)) {
mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
return mlir::failure();
}
uint64_t res = 0;
if (auto err = arguments->getResult(0, res)) {
mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
return mlir::failure();
}
llvm::errs() << res << "\n";
return mlir::success();
}
// Process a single source buffer
//
// If `verifyDiagnostics` is `true`, the procedure only checks if the
@@ -226,7 +186,9 @@ processInputBuffer(mlir::MLIRContext &context,
if (cmdline::runJit) {
mlir::zamalang::log_verbose() << "### JIT compile & running\n";
return runJit(module.get(), *keySet, os);
return mlir::zamalang::runJit(module.get(), cmdline::jitFuncname,
cmdline::jitArgs, *keySet,
defaultOptPipeline, os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);