mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
This commit contains several incremental improvements towards a clear
interface for lambdas:
- Unification of static and JIT compilation by using the static
compilation path of `CompilerEngine` within a new subclass
`JitCompilerEngine`.
- Clear ownership for compilation artefacts through
`CompilationContext`, making it impossible to destroy objects used
directly or indirectly before destruction of their users.
- Clear interface for lambdas generated by the compiler through
`JitCompilerEngine::Lambda` with a templated call operator,
encapsulating otherwise manual orchestration of `CompilerEngine`,
`JITLambda`, and `CompilerEngine::Argument`.
- Improved error handling through `llvm::Expected<T>` and proper
error checking following the conventions for `llvm::Expected<T>`
and `llvm::Error`.
Co-authored-by: youben11 <ayoub.benaissa@zama.ai>
106 lines
3.7 KiB
C++
#include "llvm/Support/Error.h"
|
|
#include <llvm/ADT/STLExtras.h>
|
|
#include <llvm/Support/TargetSelect.h>
|
|
#include <mlir/ExecutionEngine/OptUtils.h>
|
|
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
|
#include <zamalang/Support/JitCompilerEngine.h>
|
|
|
|
namespace mlir {
|
|
namespace zamalang {
|
|
|
|
// Creates a JIT-capable compiler engine.
//
// `compilationContext` is shared with the base `CompilerEngine`, and
// `optimizationLevel` is stored in the `optimizationLevel` member.
//
// `compilationContext` is taken by value as a sink parameter and moved into
// the base. Copying it would cost an extra atomic reference-count
// increment/decrement pair for no benefit.
JitCompilerEngine::JitCompilerEngine(
    std::shared_ptr<CompilationContext> compilationContext,
    unsigned int optimizationLevel)
    : CompilerEngine(std::move(compilationContext)),
      optimizationLevel(optimizationLevel) {}
|
|
|
|
// Returns the `LLVMFuncOp` operation in the compiled module with the
|
|
// specified name. If no LLVMFuncOp with that name exists or if there
|
|
// was no prior call to `compile()` resulting in an MLIR module in the
|
|
// LLVM dialect, an error is returned.
|
|
llvm::Expected<mlir::LLVM::LLVMFuncOp>
|
|
JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
|
|
auto funcOps = module.getOps<mlir::LLVM::LLVMFuncOp>();
|
|
auto funcOp = llvm::find_if(
|
|
funcOps, [&](mlir::LLVM::LLVMFuncOp op) { return op.getName() == name; });
|
|
|
|
if (funcOp == funcOps.end()) {
|
|
return StreamStringError()
|
|
<< "Module does not contain function named '" << name.str() << "'";
|
|
}
|
|
|
|
return *funcOp;
|
|
}
|
|
|
|
// Build a lambda from the function with the name given in
|
|
// `funcName` from the sources in `buffer`.
|
|
llvm::Expected<JitCompilerEngine::Lambda>
|
|
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
|
llvm::StringRef funcName) {
|
|
llvm::SourceMgr sm;
|
|
|
|
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
|
|
|
llvm::Expected<JitCompilerEngine::Lambda> res =
|
|
this->buildLambda(sm, funcName);
|
|
|
|
return std::move(res);
|
|
}
|
|
|
|
// Build a lambda from the function with the name given in `funcName`
|
|
// from the source string `s`.
|
|
llvm::Expected<JitCompilerEngine::Lambda>
|
|
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName) {
|
|
std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
|
|
llvm::Expected<JitCompilerEngine::Lambda> res =
|
|
this->buildLambda(std::move(mb), funcName);
|
|
|
|
return std::move(res);
|
|
}
|
|
|
|
// Build a lambda from the function with the name given in
|
|
// `funcName` from the sources managed by the source manager `sm`.
|
|
llvm::Expected<JitCompilerEngine::Lambda>
|
|
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName) {
|
|
MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();
|
|
|
|
this->setGenerateKeySet(true);
|
|
this->setGenerateClientParameters(true);
|
|
this->setClientParametersFuncName(funcName);
|
|
|
|
// First, compile to LLVM Dialect
|
|
llvm::Expected<CompilerEngine::CompilationResult> compResOrErr =
|
|
this->compile(sm, Target::LLVM_IR);
|
|
|
|
if (!compResOrErr)
|
|
return std::move(compResOrErr.takeError());
|
|
|
|
mlir::ModuleOp module = compResOrErr->mlirModuleRef->get();
|
|
|
|
// Locate function to JIT-compile
|
|
llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
|
|
this->findLLVMFuncOp(compResOrErr->mlirModuleRef->get(), funcName);
|
|
|
|
if (!funcOrError)
|
|
return std::move(funcOrError.takeError());
|
|
|
|
// Prepare LLVM infrastructure for JIT compilation
|
|
llvm::InitializeNativeTarget();
|
|
llvm::InitializeNativeTargetAsmPrinter();
|
|
mlir::registerLLVMDialectTranslation(mlirContext);
|
|
|
|
std::function<llvm::Error(llvm::Module *)> optPipeline =
|
|
mlir::makeOptimizingTransformer(3, 0, nullptr);
|
|
|
|
llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
|
|
mlir::zamalang::JITLambda::create(funcName, module, optPipeline);
|
|
|
|
if (!lambdaOrErr)
|
|
return std::move(lambdaOrErr.takeError());
|
|
|
|
return Lambda{this->compilationContext, std::move(lambdaOrErr.get()),
|
|
std::move(compResOrErr->keySet)};
|
|
}
|
|
} // namespace zamalang
|
|
} // namespace mlir
|