refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -0,0 +1,105 @@
#include "llvm/Support/Error.h"
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <zamalang/Support/JitCompilerEngine.h>
namespace mlir {
namespace zamalang {
JitCompilerEngine::JitCompilerEngine(
std::shared_ptr<CompilationContext> compilationContext,
unsigned int optimizationLevel)
: CompilerEngine(compilationContext), optimizationLevel(optimizationLevel) {
}
// Returns the `LLVMFuncOp` operation in the compiled module with the
// specified name. If no LLVMFuncOp with that name exists or if there
// was no prior call to `compile()` resulting in an MLIR module in the
// LLVM dialect, an error is returned.
llvm::Expected<mlir::LLVM::LLVMFuncOp>
JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
auto funcOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
auto funcOp = llvm::find_if(
funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; });
if (funcOp == funcOps.end()) {
return StreamStringError()
<< "Module does not contain function named '" << name.str() << "'";
}
return *funcOp;
}
// Build a lambda from the function with the name given in
// `funcName` from the sources in `buffer`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
llvm::StringRef funcName) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(sm, funcName);
return std::move(res);
}
// Build a lambda from the function with the name given in `funcName`
// from the source string `s`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
llvm::Expected<JitCompilerEngine::Lambda> res =
this->buildLambda(std::move(mb), funcName);
return std::move(res);
}
// Build a lambda from the function with the name given in
// `funcName` from the sources managed by the source manager `sm`.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
this->setGenerateKeySet(true);
this->setGenerateClientParameters(true);
this->setClientParametersFuncName(funcName);
// First, compile to LLVM Dialect
llvm::Expected<CompilerEngine::CompilationResult> compResOrErr =
this->compile(sm, Target::LLVM_IR);
if (!compResOrErr)
return std::move(compResOrErr.takeError());
mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();
// Locate function to JIT-compile
llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);
if (!funcOrError)
return std::move(funcOrError.takeError());
// Prepare LLVM infrastructure for JIT compilation
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
mlir::registerLLVMDialectTranslation(mlirContext);
std::function<llvm::Error(llvm::Module *)> optPipeline =
mlir::makeOptimizingTransformer(3, 0, nullptr);
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
if (!lambdaOrErr)
return std::move(lambdaOrErr.takeError());
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
std::move(compResOrErr->keySet)};
}
} // namespace zamalang
} // namespace mlir