mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
refactor(compiler): Refactor CompilerEngine and related classes
This commit contains several incremental improvements towards a clear
interface for lambdas:
- Unification of static and JIT compilation by using the static
compilation path of `CompilerEngine` within a new subclass
`JitCompilerEngine`.
- Clear ownership for compilation artefacts through
`CompilationContext`, making it impossible to destroy objects used
directly or indirectly before destruction of their users.
- Clear interface for lambdas generated by the compiler through
`JitCompilerEngine::Lambda` with a templated call operator,
encapsulating otherwise manual orchestration of `CompilerEngine`,
`JITLambda`, and `CompilerEngine::Argument`.
- Improved error handling through `llvm::Expected<T>` and proper
error checking following the conventions for `llvm::Expected<T>`
and `llvm::Error`.
Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include "llvm/Support/Error.h"
|
||||
#include <llvm/ADT/ArrayRef.h>
|
||||
#include <llvm/ADT/SmallVector.h>
|
||||
#include <llvm/ADT/StringRef.h>
|
||||
@@ -12,56 +13,6 @@
|
||||
namespace mlir {
|
||||
namespace zamalang {
|
||||
|
||||
// JIT-compiles `module` invokes `func` with the arguments passed in
|
||||
// `jitArguments` and `keySet`
|
||||
mlir::LogicalResult
|
||||
runJit(mlir::ModuleOp module, llvm::StringRef func,
|
||||
llvm::ArrayRef<uint64_t> funcArgs, mlir::zamalang::KeySet &keySet,
|
||||
std::function<llvm::Error(llvm::Module *)> optPipeline,
|
||||
llvm::raw_ostream &os) {
|
||||
// Create the JIT lambda
|
||||
auto maybeLambda =
|
||||
mlir::zamalang::JITLambda::create(func, module, optPipeline);
|
||||
if (!maybeLambda) {
|
||||
return mlir::failure();
|
||||
}
|
||||
auto lambda = std::move(maybeLambda.get());
|
||||
|
||||
// Create the arguments of the JIT lambda
|
||||
auto maybeArguments = mlir::zamalang::JITLambda::Argument::create(keySet);
|
||||
if (auto err = maybeArguments.takeError()) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot create lambda arguments: " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
// Set the arguments
|
||||
auto arguments = std::move(maybeArguments.get());
|
||||
for (size_t i = 0; i < funcArgs.size(); i++) {
|
||||
if (auto err = arguments->setArg(i, funcArgs[i])) {
|
||||
::mlir::zamalang::log_error()
|
||||
<< "Cannot push argument " << i << ": " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
}
|
||||
// Invoke the lambda
|
||||
if (auto err = lambda->invoke(*arguments)) {
|
||||
::mlir::zamalang::log_error() << "Cannot invoke : " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
uint64_t res = 0;
|
||||
if (auto err = arguments->getResult(0, res)) {
|
||||
::mlir::zamalang::log_error() << "Cannot get result : " << err << "\n";
|
||||
llvm::consumeError(std::move(err));
|
||||
return mlir::failure();
|
||||
}
|
||||
llvm::errs() << res << "\n";
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITLambda>>
|
||||
JITLambda::create(llvm::StringRef name, mlir::ModuleOp &module,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> optPipeline) {
|
||||
|
||||
Reference in New Issue
Block a user