feat(compiler): Add passes to lower mlir to mlir llvm ir and run jit and emit llvm code (#63)

This commit is contained in:
Quentin Bourgerie
2021-07-29 16:08:32 +02:00
committed by Ayoub Benaissa
parent c58abe6565
commit b4e57984b1
13 changed files with 382 additions and 30 deletions

View File

@@ -0,0 +1,16 @@
#ifndef ZAMALANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
#define ZAMALANG_CONVERSION_MLIRLOWERABLEDIALECTSTOLLVM_PASS_H_
#include "mlir/Pass/Pass.h"
namespace mlir {
template <typename T> class OperationPass;
namespace zamalang {
/// Create a pass to convert MLIR lowerable dialects to LLVM.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMLIRLowerableDialectsToLLVMPass();
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -6,6 +6,9 @@
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/MLIRLowerableDialectsToLLVM/Pass.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#define GEN_PASS_CLASSES

View File

@@ -17,4 +17,11 @@ def HLFHEToMidLFHE : Pass<"hlfhe-to-midlfhe", "mlir::ModuleOp"> {
let dependentDialects = ["mlir::linalg::LinalgDialect"];
}
def MLIRLowerableDialectsToLLVM : Pass<"mlir-lowerable-dialects-to-llvm", "mlir::ModuleOp"> {
let summary = "Lowers operations from MLIR lowerable dialects to LLVM";
let constructor = "mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass()";
let dependentDialects = ["mlir::StandardOpsDialect", "mlir::scf::SCFDialect", "mlir::LLVM::LLVMDialect"];
let options = [];
}
#endif

View File

@@ -0,0 +1,56 @@
#ifndef ZAMALANG_SUPPORT_COMPILERTOOLS_H_
#define ZAMALANG_SUPPORT_COMPILERTOOLS_H_
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/Pass/PassManager.h>
namespace mlir {
namespace zamalang {
class CompilerTools {
public:
/// lowerHLFHEToMlirLLVMDialect run all passes to lower FHE dialects to mlir
/// LLVM dialect.
static mlir::LogicalResult lowerHLFHEToMlirLLVMDialect(
mlir::MLIRContext &context, mlir::Operation *module,
llvm::function_ref<bool(std::string)> enablePass = [](std::string pass) {
return true;
});
static llvm::Expected<std::unique_ptr<llvm::Module>>
toLLVMModule(llvm::LLVMContext &context, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
};
/// JITLambda is a tool to JIT compile an mlir module and to invoke a function
/// of the module.
class JITLambda {
public:
JITLambda(mlir::LLVM::LLVMFunctionType type, llvm::StringRef name)
: type(type), name(name){};
/// create a JITLambda that point to the function name of the given module.
static llvm::Expected<std::unique_ptr<JITLambda>>
create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline);
/// invokeRaw execute the jit lambda with a lits of arguments, the last one is
/// used to store the result of the computation.
/// Example:
/// uin64_t arg0 = 1;
/// uin64_t res;
/// llvm::SmallVector<void *> args{&arg1, &res};
/// lambda.invokeRaw(args);
llvm::Error invokeRaw(llvm::MutableArrayRef<void *> args);
private:
mlir::LLVM::LLVMFunctionType type;
llvm::StringRef name;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
} // namespace zamalang
} // namespace mlir
#endif

View File

@@ -1,5 +1,6 @@
add_subdirectory(Dialect)
add_subdirectory(Conversion)
add_subdirectory(Support)
# CAPI needed only for python bindings
if (ZAMALANG_BINDINGS_PYTHON_ENABLED)

View File

@@ -1,2 +1,3 @@
add_subdirectory(HLFHEToMidLFHE)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(HLFHETensorOpsToLinalg)
add_subdirectory(MLIRLowerableDialectsToLLVM)

View File

@@ -119,7 +119,7 @@ void HLFHETensorOpsToLinalg::runOnFunction() {
namespace mlir {
namespace zamalang {
std::unique_ptr<mlir::Pass> createConvertHLFHETensorOpsToLinalg() {
std::unique_ptr<mlir::FunctionPass> createConvertHLFHETensorOpsToLinalg() {
return std::make_unique<HLFHETensorOpsToLinalg>();
}
} // namespace zamalang

View File

@@ -0,0 +1,21 @@
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
add_mlir_dialect_library(MLIRLowerableDialectsToLLVM
MLIRLowerableDialectsToLLVM.cpp
ADDITIONAL_HEADER_DIRS
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
${dialect_libs}
${conversion_libs}
MLIRIR
MLIRTransforms
MLIRLLVMIR
MLIRLLVMToLLVMIRTranslation
MLIRMath)
target_link_libraries(MLIRLowerableDialectsToLLVM PUBLIC MLIRIR)

View File

@@ -0,0 +1,61 @@
#include <iostream>
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
namespace {
struct MLIRLowerableDialectsToLLVMPass
: public MLIRLowerableDialectsToLLVMBase<MLIRLowerableDialectsToLLVMPass> {
void runOnOperation() final;
};
} // namespace
void MLIRLowerableDialectsToLLVMPass::runOnOperation() {
// Setup the conversion target. We reuse the LLVMConversionTarget that
// legalize LLVM dialect.
mlir::LLVMConversionTarget target(getContext());
target.addLegalOp<mlir::ModuleOp>();
// Setup the LLVMTypeConverter (that converts `std` types to `llvm` types) and
// add our types conversion to `llvm` compatible type.
mlir::LLVMTypeConverter typeConverter(&getContext());
// Setup the set of the patterns rewriter. At this point we want to
// convert the `scf` operations to `std` and `std` operations to `llvm`.
mlir::RewritePatternSet patterns(&getContext());
mlir::populateLoopToStdConversionPatterns(patterns);
mlir::populateStdToLLVMConversionPatterns(typeConverter, patterns);
// Apply a `FullConversion` to `llvm`.
auto module = getOperation();
if (mlir::applyFullConversion(module, target, std::move(patterns)).failed()) {
signalPassFailure();
}
}
namespace mlir {
namespace zamalang {
/// Create a pass for lowering operations the remaining mlir dialects
/// operations, to the LLVM dialect for codegen.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertMLIRLowerableDialectsToLLVMPass() {
return std::make_unique<MLIRLowerableDialectsToLLVMPass>();
}
} // namespace zamalang
} // namespace mlir

View File

@@ -0,0 +1,16 @@
add_mlir_library(ZamalangSupport
CompilerTools.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/zamalang/Support
DEPENDS
MLIRConversionPassIncGen
LINK_LIBS PUBLIC
HLFHETensorOpsToLinalg
HLFHEToMidLFHE
MLIRLowerableDialectsToLLVM
MLIRExecutionEngine
${LLVM_PTHREAD_LIB})

View File

@@ -0,0 +1,115 @@
#include <llvm/Support/TargetSelect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <mlir/Target/LLVMIR/Export.h>
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Support/CompilerTools.h"
namespace mlir {
namespace zamalang {
void initLLVMNativeTarget() {
// Initialize LLVM targets.
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
}
void addFilteredPassToPassManager(
mlir::PassManager &pm, std::unique_ptr<mlir::Pass> pass,
llvm::function_ref<bool(std::string)> enablePass) {
if (!enablePass(pass->getArgument().str())) {
return;
}
if (*pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
};
mlir::LogicalResult CompilerTools::lowerHLFHEToMlirLLVMDialect(
mlir::MLIRContext &context, mlir::Operation *module,
llvm::function_ref<bool(std::string)> enablePass) {
mlir::PassManager pm(&context);
// Add all passes to lower from HLFHE to LLVM Dialect
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHETensorOpsToLinalg(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass(), enablePass);
addFilteredPassToPassManager(
pm, mlir::zamalang::createConvertMLIRLowerableDialectsToLLVMPass(),
enablePass);
// Run the passes
if (pm.run(module).failed()) {
return mlir::failure();
}
return mlir::success();
}
llvm::Expected<std::unique_ptr<llvm::Module>> CompilerTools::toLLVMModule(
llvm::LLVMContext &context, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
initLLVMNativeTarget();
mlir::registerLLVMDialectTranslation(*module->getContext());
auto llvmModule = mlir::translateModuleToLLVMIR(module, context);
if (!llvmModule) {
return llvm::make_error<llvm::StringError>(
"failed to translate MLIR to LLVM IR", llvm::inconvertibleErrorCode());
}
if (auto err = optPipeline(llvmModule.get())) {
return llvm::make_error<llvm::StringError>("failed to optimize LLVM IR",
llvm::inconvertibleErrorCode());
}
return std::move(llvmModule);
}
llvm::Expected<std::unique_ptr<JITLambda>>
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
// Looking for the function
auto rangeOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(rangeOps, [&](mlir::LLVM::LLVMFuncOp op) {
return op.getName() == name;
});
if (funcOp == rangeOps.end()) {
return llvm::make_error<llvm::StringError>(
"cannot find the function to JIT", llvm::inconvertibleErrorCode());
}
initLLVMNativeTarget();
mlir::registerLLVMDialectTranslation(*module->getContext());
// Create an MLIR execution engine. The execution engine eagerly
// JIT-compiles the module.
auto maybeEngine = mlir::ExecutionEngine::create(
module, /*llvmModuleBuilder=*/nullptr, optPipeline);
if (!maybeEngine) {
return llvm::make_error<llvm::StringError>(
"failed to construct the MLIR ExecutionEngine",
llvm::inconvertibleErrorCode());
}
auto &engine = maybeEngine.get();
auto lambda = std::make_unique<JITLambda>((*funcOp).getType(), name);
lambda->engine = std::move(engine);
return std::move(lambda);
}
llvm::Error JITLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
if (this->type.getNumParams() == args.size() - 1 /*For the result*/) {
return this->engine->invokePacked(this->name, args);
}
return llvm::make_error<llvm::StringError>(
"wrong number of argument when invoke the JIT lambda",
llvm::inconvertibleErrorCode());
}
} // namespace zamalang
} // namespace mlir

View File

@@ -1,14 +1,24 @@
add_llvm_tool(zamacompiler main.cpp)
target_compile_options(zamacompiler PRIVATE -fexceptions)
llvm_update_compile_flags(zamacompiler)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
target_link_libraries(zamacompiler
PRIVATE
${dialect_libs}
${conversion_libs}
MLIRTransforms
MidLFHEDialect
HLFHEDialect
HLFHETensorOpsToLinalg
HLFHEToMidLFHE
MLIRIR
MLIRLLVMIR
MLIRLLVMToLLVMIRTranslation
ZamalangSupport
)
mlir_check_all_link_libraries(zamacompiler)

View File

@@ -5,18 +5,19 @@
#include <llvm/Support/ToolOutputFile.h>
#include <mlir/Dialect/Linalg/IR/LinalgOps.h>
#include <mlir/Dialect/MemRef/IR/MemRef.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Parser.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <mlir/Support/ToolUtilities.h>
#include "zamalang/Conversion/HLFHETensorOpsToLinalg/Pass.h"
#include "zamalang/Conversion/HLFHEToMidLFHE/Pass.h"
#include "zamalang/Conversion/Passes.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHEDialect.h"
#include "zamalang/Dialect/HLFHE/IR/HLFHETypes.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHEDialect.h"
#include "zamalang/Dialect/MidLFHE/IR/MidLFHETypes.h"
#include "zamalang/Support/CompilerTools.h"
namespace cmdline {
@@ -49,24 +50,58 @@ llvm::cl::opt<bool> splitInputFile(
llvm::cl::desc("Split the input file into pieces and process each "
"chunk independently"),
llvm::cl::init(false));
llvm::cl::opt<bool> runJit("run-jit", llvm::cl::desc("JIT the code and run it"),
llvm::cl::init<bool>(false));
llvm::cl::list<int>
jitArgs("jit-args",
llvm::cl::desc("Value of arguments to pass to the main func"),
llvm::cl::value_desc("passname"), llvm::cl::ZeroOrMore);
llvm::cl::opt<bool> toLLVM("to-llvm", llvm::cl::desc("Compile to llvm and "),
llvm::cl::init<bool>(false));
}; // namespace cmdline
void addPassCmdLineFiltered(mlir::PassManager &pm,
std::unique_ptr<mlir::Pass> pass) {
if (cmdline::roundTrip)
return;
auto passName = pass->getName();
if (cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return pass->getArgument() == p; })) {
if (*pass->getOpName() == "module") {
pm.addPass(std::move(pass));
} else {
pm.nest(*pass->getOpName()).addPass(std::move(pass));
}
auto defaultOptPipeline = mlir::makeOptimizingTransformer(3, 0, nullptr);
mlir::LogicalResult dumpLLVMIR(mlir::ModuleOp module, llvm::raw_ostream &os) {
llvm::LLVMContext context;
auto llvmModule = mlir::zamalang::CompilerTools::toLLVMModule(
context, module, defaultOptPipeline);
if (!llvmModule) {
return mlir::failure();
}
return;
os << **llvmModule;
return mlir::success();
}
mlir::LogicalResult runJit(mlir::ModuleOp module, llvm::raw_ostream &os) {
// Create the JIT lambda
auto maybeLambda =
mlir::zamalang::JITLambda::create("main", module, defaultOptPipeline);
if (!maybeLambda) {
return mlir::failure();
}
auto lambda = maybeLambda.get().get();
// Create buffer to copy argument
std::vector<int64_t> dummy(cmdline::jitArgs.size());
llvm::SmallVector<void *> llvmArgs;
for (auto i = 0; i < cmdline::jitArgs.size(); i++) {
dummy[i] = cmdline::jitArgs[i];
llvmArgs.push_back(&dummy[i]);
}
// Add the result pointer
uint64_t res = 0;
llvmArgs.push_back(&res);
// Invoke the lambda
if (lambda->invokeRaw(llvmArgs)) {
return mlir::failure();
}
std::cerr << res << "\n";
return mlir::success();
}
// Process a single source buffer
@@ -82,14 +117,11 @@ mlir::LogicalResult
processInputBuffer(mlir::MLIRContext &context,
std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::raw_ostream &os, bool verifyDiagnostics) {
mlir::PassManager pm(&context);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
mlir::SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr,
&context);
auto module = mlir::parseSourceFile(sourceMgr, &context);
if (verifyDiagnostics)
@@ -98,17 +130,30 @@ processInputBuffer(mlir::MLIRContext &context,
if (!module)
return mlir::failure();
addPassCmdLineFiltered(pm,
mlir::zamalang::createConvertHLFHETensorOpsToLinalg());
addPassCmdLineFiltered(pm, mlir::zamalang::createConvertHLFHEToMidLFHEPass());
if (cmdline::roundTrip) {
module->print(os);
return mlir::success();
}
if (pm.run(*module).failed()) {
llvm::errs() << "Could not run passes!\n";
if (mlir::zamalang::CompilerTools::lowerHLFHEToMlirLLVMDialect(
context, *module,
[](std::string passName) {
return cmdline::passes.size() == 0 ||
std::any_of(
cmdline::passes.begin(), cmdline::passes.end(),
[&](const std::string &p) { return passName == p; });
})
.failed()) {
return mlir::failure();
}
if (cmdline::runJit) {
return runJit(module.get(), os);
}
if (cmdline::toLLVM) {
return dumpLLVMIR(module.get(), os);
}
module->print(os);
return mlir::success();
}