mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
refactor(compiler): Refactor CompilerEngine and related classes
This commit contains several incremental improvements towards a clear
interface for lambdas:
- Unification of static and JIT compilation by using the static
compilation path of `CompilerEngine` within a new subclass
`JitCompilerEngine`.
- Clear ownership of compilation artefacts through
  `CompilationContext`, making it impossible to destroy an object that
  is still used, directly or indirectly, by another object.
- Clear interface for lambdas generated by the compiler through
`JitCompilerEngine::Lambda` with a templated call operator,
encapsulating otherwise manual orchestration of `CompilerEngine`,
`JITLambda`, and `CompilerEngine::Argument`.
- Improved error handling through `llvm::Expected<T>` and proper
error checking following the conventions for `llvm::Expected<T>`
and `llvm::Error`.
Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
@@ -1,62 +1,83 @@
|
||||
#include <memory>
#include <stdexcept>
#include <vector>

#include "zamalang-c/Support/CompilerEngine.h"

#include "zamalang/Support/CompilerEngine.h"
#include "zamalang/Support/ExecutionArgument.h"
#include "zamalang/Support/Jit.h"
#include "zamalang/Support/JitCompilerEngine.h"
#include "zamalang/Support/logging.h"
|
||||
|
||||
using mlir::zamalang::CompilerEngine;
|
||||
// using mlir::zamalang::CompilerEngine;
|
||||
using mlir::zamalang::ExecutionArgument;
|
||||
using mlir::zamalang::JitCompilerEngine;
|
||||
|
||||
void compilerEngineCompile(compilerEngine engine, const char *module) {
|
||||
auto error = engine.ptr->compile(module);
|
||||
if (error) {
|
||||
llvm::errs() << "Compilation failed: " << error << "\n";
|
||||
llvm::consumeError(std::move(error));
|
||||
// Compiles `module` with a JIT compiler engine and returns a callable
// lambda for the function named `funcName`.
//
// On compilation failure the underlying error is written to
// mlir::zamalang::log_error() and a std::runtime_error is thrown.
mlir::zamalang::JitCompilerEngine::Lambda buildLambda(const char *module,
                                                      const char *funcName) {
  mlir::zamalang::JitCompilerEngine engine;

  llvm::Expected<mlir::zamalang::JitCompilerEngine::Lambda> lambdaOrErr =
      engine.buildLambda(module, funcName);

  if (!lambdaOrErr) {
    // takeError() already yields an rvalue llvm::Error; wrapping it in
    // std::move is redundant.
    mlir::zamalang::log_error() << "Compilation failed: "
                                << llvm::toString(lambdaOrErr.takeError())
                                << "\n";
    throw std::runtime_error(
        "failed compiling, see previous logs for more info");
  }

  return std::move(*lambdaOrErr);
}
|
||||
|
||||
uint64_t compilerEngineRun(compilerEngine engine, exectuionArguments args) {
|
||||
auto args_size = args.size;
|
||||
auto maybeArgument = engine.ptr->buildArgument();
|
||||
if (auto err = maybeArgument.takeError()) {
|
||||
llvm::errs() << "Execution failed: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
throw std::runtime_error(
|
||||
"failed building arguments, see previous logs for more info");
|
||||
// Invokes a previously built lambda (passed through the opaque `lambda`
// handle) with the given execution arguments and returns the scalar result.
//
// Throws std::invalid_argument when the argument count does not match the
// lambda's arity, and std::runtime_error when the invocation itself fails
// (details are logged through mlir::zamalang::log_error() first).
uint64_t invokeLambda(lambda l, executionArguments args) {
  mlir::zamalang::JitCompilerEngine::Lambda *lambda_ptr =
      (mlir::zamalang::JitCompilerEngine::Lambda *)l.ptr;

  if (args.size != lambda_ptr->getNumArguments()) {
    throw std::invalid_argument("wrong number of arguments");
  }

  // Set the integer/tensor arguments. Ownership lives in `ownedArguments`
  // (RAII) so the heap-allocated LambdaArgument objects are released on
  // every exit path — the previous raw new/delete leaked them when the
  // invocation below threw. `lambdaArgumentsRef` is a non-owning view in
  // the pointer shape the call operator expects.
  std::vector<std::unique_ptr<mlir::zamalang::LambdaArgument>> ownedArguments;
  std::vector<mlir::zamalang::LambdaArgument *> lambdaArgumentsRef;
  for (auto i = 0; i < args.size; i++) {
    if (args.data[i].isInt()) { // integer argument
      ownedArguments.push_back(
          std::make_unique<mlir::zamalang::IntLambdaArgument<>>(
              args.data[i].getIntegerArgument()));
    } else { // tensor argument
      assert(args.data[i].isTensor() && "should be tensor argument");
      llvm::MutableArrayRef<uint8_t> tensor(args.data[i].getTensorArgument(),
                                            args.data[i].getTensorSize());
      ownedArguments.push_back(
          std::make_unique<mlir::zamalang::TensorLambdaArgument<
              mlir::zamalang::IntLambdaArgument<uint8_t>>>(tensor));
    }
    lambdaArgumentsRef.push_back(ownedArguments.back().get());
  }

  // Run lambda
  llvm::Expected<uint64_t> resOrError = (*lambda_ptr)(
      llvm::ArrayRef<mlir::zamalang::LambdaArgument *>(lambdaArgumentsRef));

  if (!resOrError) {
    // takeError() already yields an rvalue; no std::move needed.
    // (Also fixes the "invokation" typo in the log message.)
    mlir::zamalang::log_error()
        << "Lambda invocation failed: "
        << llvm::toString(resOrError.takeError()) << "\n";
    throw std::runtime_error(
        "failed invoking lambda, see previous logs for more info");
  }
  return *resOrError;
}
|
||||
|
||||
std::string roundTrip(const char *module) {
|
||||
std::shared_ptr<mlir::zamalang::CompilationContext> ccx =
|
||||
mlir::zamalang::CompilationContext::createShared();
|
||||
mlir::zamalang::JitCompilerEngine ce{ccx};
|
||||
|
||||
llvm::Expected<mlir::zamalang::CompilerEngine::CompilationResult> retOrErr =
|
||||
ce.compile(module, mlir::zamalang::CompilerEngine::Target::ROUND_TRIP);
|
||||
if (!retOrErr) {
|
||||
mlir::zamalang::log_error()
|
||||
<< llvm::toString(std::move(retOrErr.takeError())) << "\n";
|
||||
throw std::runtime_error(
|
||||
"mlir parsing failed, see previous logs for more info");
|
||||
}
|
||||
|
||||
std::string result;
|
||||
llvm::raw_string_ostream os(result);
|
||||
retOrErr->mlirModuleRef->get().print(os);
|
||||
return os.str();
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user