refactor(compiler): Refactor CompilerEngine and related classes

This commit contains several incremental improvements towards a clear
interface for lambdas:

  - Unification of static and JIT compilation by using the static
    compilation path of `CompilerEngine` within a new subclass
    `JitCompilerEngine`.

  - Clear ownership for compilation artefacts through
    `CompilationContext`, making it impossible to destroy objects used
    directly or indirectly before destruction of their users.

  - Clear interface for lambdas generated by the compiler through
    `JitCompilerEngine::Lambda` with a templated call operator,
    encapsulating otherwise manual orchestration of `CompilerEngine`,
    `JITLambda`, and `CompilerEngine::Argument`.

  - Improved error handling through `llvm::Expected<T>` and proper
    error checking following the conventions for `llvm::Expected<T>`
    and `llvm::Error`.

Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
Andi Drebes
2021-10-18 15:38:12 +02:00
parent d738104c4b
commit 1187cfbd62
61 changed files with 1690 additions and 997 deletions

View File

@@ -1,62 +1,83 @@
#include "zamalang-c/Support/CompilerEngine.h"
#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/logging.h"
using mlir::zamalang::CompilerEngine;
// using mlir::zamalang::CompilerEngine;
using mlir::zamalang::ExecutionArgument;
using mlir::zamalang::JitCompilerEngine;
void compilerEngineCompile(compilerEngine engine, const char *module) {
auto error = engine.ptr->compile(module);
if (error) {
llvm::errs() << "Compilation failed: " << error << "\n";
llvm::consumeError(std::move(error));
mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
const char *funcName) {
mlir::zamalang::JitCompilerEngine engine;
llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
engine.buildLambda(module, funcName);
if (!lambdaOrErr) {
mlir::zamalang::log_error()
<< "Compilation failed: "
<< llvm::toString(std::move(lambdaOrErr.takeError())) << "\n";
throw std::runtime_error(
"failed compiling, see previous logs for more info");
}
return std::move(*lambdaOrErr);
}
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
auto args_size = args.size;
auto maybeArgument = engine.ptr->buildArgument();
if (auto err = maybeArgument.takeError()) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error(
"failed building arguments, see previous logs for more info");
uint64_t invokeLambda(lambda l, executionArguments args) {
mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
(mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;
if (args.size != lambda_ptr->getNumArguments()) {
throw std::invalid_argument("wrong number of arguments");
}
// Set the integer/tensor arguments
auto arguments = std::move(maybeArgument.get());
for (auto i = 0; i < args_size; i++) {
std::vector<mlir::zamalang::LambdaArgument *> lambdaArgumentsRef;
for (auto i = 0; i < args.size; i++) {
if (args.data[i].isInt()) { // integer argument
if (auto err = arguments->setArg(i, args.data[i].getIntegerArgument())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing integer argument, see "
"previous logs for more info");
}
lambdaArgumentsRef.push_back(new mlir::zamalang::IntLambdaArgument<>(
args.data[i].getIntegerArgument()));
} else { // tensor argument
assert(args.data[i].isTensor() && "should be tensor argument");
if (auto err = arguments->setArg(i, args.data[i].getTensorArgument(),
args.data[i].getTensorSize())) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed pushing tensor argument, see "
"previous logs for more info");
}
llvm::MutableArrayRef<uint8_t> tensor(args.data[i].getTensorArgument(),
args.data[i].getTensorSize());
lambdaArgumentsRef.push_back(
new mlir::zamalang::TensorLambdaArgument<
mlir::zamalang::IntLambdaArgument<uint8_t>>(tensor));
}
}
// Invoke the lambda
if (auto err = engine.ptr->invoke(*arguments)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
throw std::runtime_error("failed running, see previous logs for more info");
}
uint64_t result = 0;
if (auto err = arguments->getResult(0, result)) {
llvm::errs() << "Execution failed: " << err << "\n";
llvm::consumeError(std::move(err));
// Run lambda
llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));
// Free heap
for (size_t i = 0; i < lambdaArgumentsRef.size(); i++)
delete lambdaArgumentsRef[i];
if (!resOrError) {
mlir::zamalang::log_error()
<< "Lambda invokation failed: "
<< llvm::toString(std::move(resOrError.takeError())) << "\n";
throw std::runtime_error(
"failed getting result, see previous logs for more info");
"failed invoking lambda, see previous logs for more info");
}
return result;
}
return *resOrError;
}
std::string roundTrip(const char *module) {
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
mlir::zamalang::CompilationContext::createShared();
mlir::zamalang::JitCompilerEngine ce{ccx};
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP);
if (!retOrErr) {
mlir::zamalang::log_error()
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
throw std::runtime_error(
"mlir parsing failed, see previous logs for more info");
}
std::string result;
llvm::raw_string_ostream os(result);
retOrErr->mlirModuleRef->get().print(os);
return os.str();
}