feat(compiler): Add passes to lower mlir to mlir llvm ir and run jit and emit llvm code (#63)

This commit is contained in:
Quentin Bourgerie
2021-07-29 16:08:32 +02:00
committed by Ayoub Benaissa
parent c58abe6565
commit b4e57984b1
13 changed files with 382 additions and 30 deletions

View File

@@ -5,18 +5,19 @@
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
namespace cmdline {
@@ -49,24 +50,58 @@ llvm::cl::opt<bool> splitInputFile(
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::list<int>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
void addPassCmdLineFiltered(mlir::PassManager &pm,
std::unique_ptr<mlir::Pass> pass) {
if (cmdline::roundTrip)
return;
auto passName = pass->getName();
if (cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; })) {
if (*pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
context, module, defaultOptPipeline);
if (!llvmModule) {
return mlir::failure();
}
return;
os << **llvmModule;
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = maybeLambda.get().get();
// Create buffer to copy argument
std::vector<int64_t> dummy(cmdline::jitArgs.size());
llvm::SmallVector<void *> llvmArgs;
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
dummy[i] = cmdline::jitArgs[i];
llvmArgs.push_back(&dummy[i]);
}
// Add the result pointer
uint64_t res = 0;
llvmArgs.push_back(&res);
// Invoke the lambda
if (lambda->invokeRaw(llvmArgs)) {
return mlir::failure();
}
std::cerr << res << "\n";
return mlir::success();
}
// Process a single source buffer
@@ -82,14 +117,11 @@ mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
mlir::PassManager pm(&context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
auto module = mlir::parseSourceFile(sourceMgr, &context);
if (verifyDiagnostics)
@@ -98,17 +130,30 @@ processInputBuffer(mlir::MLIRContext &context,
if (!module)
return mlir::failure();
addPassCmdLineFiltered(pm,
mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
addPassCmdLineFiltered(pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass());
if (cmdline::roundTrip) {
module->print(os);
return mlir::success();
}
if (pm.run(*module).failed()) {
llvm::errs() << "Could not run passes!\n";
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
context, *module,
[](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
})
.failed()) {
return mlir::failure();
}
if (cmdline::runJit) {
return runJit(module.get(), os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);
}
module->print(os);
return mlir::success();
}