Files
concrete/compiler/lib/Support/JitCompilerEngine.cpp

138 lines
5.1 KiB
C++

// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "llvm/Support/Error.h"
#include <concretelang/Support/JitCompilerEngine.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
namespace mlir {
namespace concretelang {
// Constructs a JIT compiler engine on top of `compilationContext` (handed to
// the `CompilerEngine` base) and records `optimizationLevel` in the engine.
JitCompilerEngine::JitCompilerEngine(
    std::shared_ptr<CompilationContext> compilationContext,
    unsigned int optimizationLevel)
    : CompilerEngine(compilationContext),
      optimizationLevel(optimizationLevel) {}
// Searches the top-level operations of `module` for an `LLVMFuncOp` whose
// name equals `name`.
//
// Returns the first matching function. If `module` contains no `LLVMFuncOp`
// with that name, returns an error naming the missing function.
llvm::Expected<mlir::LLVM::LLVMFuncOp>
JitCompilerEngine::findLLVMFuncOp(mlir::ModuleOp module, llvm::StringRef name) {
  for (mlir::LLVM::LLVMFuncOp candidate :
       module.getOps<mlir::LLVM::LLVMFuncOp>()) {
    if (candidate.getName() == name)
      return candidate;
  }

  return StreamStringError()
         << "Module does not contain function named '" << name.str() << "'";
}
// Builds a lambda for the function named `funcName` from the sources held
// in `buffer`.
//
// Ownership of `buffer` moves into a source manager that is local to this
// call. The actual work is delegated to the `llvm::SourceMgr &` overload,
// which receives `cache` and `runtimeLibPath` unchanged.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
                               llvm::StringRef funcName,
                               llvm::Optional<KeySetCache> cache,
                               llvm::Optional<llvm::StringRef> runtimeLibPath) {
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(buffer),
                               /*IncludeLoc=*/llvm::SMLoc());
  return buildLambda(sourceMgr, funcName, cache, runtimeLibPath);
}
// Builds a lambda for the function named `funcName` from the source text `s`.
//
// The text is copied into a buffer that owns its data.
// `MemoryBuffer::getMemBuffer` requires (by default) that the referenced data
// be NUL-terminated. A non-owning `llvm::StringRef`, such as a substring of a
// larger string, does not guarantee that. `getMemBufferCopy` always appends
// the terminator, so any `StringRef` is accepted safely.
//
// The actual work is delegated to the `std::unique_ptr<llvm::MemoryBuffer>`
// overload, which receives `cache` and `runtimeLibPath` unchanged.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
                               llvm::Optional<KeySetCache> cache,
                               llvm::Optional<llvm::StringRef> runtimeLibPath) {
  std::unique_ptr<llvm::MemoryBuffer> mb =
      llvm::MemoryBuffer::getMemBufferCopy(s);
  return this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath);
}
// Builds a lambda for the function named `funcName` from the sources managed
// by the source manager `sm`.
//
// Steps:
//   1. Compile the sources down to the LLVM dialect, with client parameter
//      generation enabled for `funcName`.
//   2. JIT-compile the function, optimized at the level given to this
//      engine's constructor.
//   3. Generate a key set from the client parameters. When `cache` is
//      provided, the key set goes through that cache.
//
// Side effects: enables client parameter generation on this engine and sets
// its client parameters function name to `funcName`.
//
// Returns an error if:
//   - compilation fails;
//   - `funcName` does not name an `LLVMFuncOp` in the compiled module;
//   - no client parameters were computed;
//   - the native target cannot be initialized;
//   - JIT compilation fails;
//   - key set generation fails.
llvm::Expected<JitCompilerEngine::Lambda>
JitCompilerEngine::buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName,
                               llvm::Optional<KeySetCache> cache,
                               llvm::Optional<llvm::StringRef> runtimeLibPath) {
  MLIRContext &mlirContext = *this->compilationContext->getMLIRContext();

  this->setGenerateClientParameters(true);
  this->setClientParametersFuncName(funcName);

  // First, compile to LLVM Dialect
  llvm::Expected<CompilerEngine::CompilationResult> compResOrErr =
      this->compile(sm, Target::LLVM_IR);
  if (!compResOrErr)
    return compResOrErr.takeError();

  auto compRes = std::move(compResOrErr.get());
  mlir::ModuleOp module = compRes.mlirModuleRef->get();

  // Locate function to JIT-compile
  llvm::Expected<mlir::LLVM::LLVMFuncOp> funcOrError =
      this->findLLVMFuncOp(module, funcName);
  if (!funcOrError)
    return StreamStringError() << "Cannot find function \"" << funcName
                               << "\": " << std::move(funcOrError.takeError());

  // The key set below is derived from the client parameters. Check for them
  // before JIT compilation, so a missing-parameters error does not first pay
  // for the expensive JIT step.
  if (!compRes.clientParameters.hasValue()) {
    return StreamStringError("Cannot generate the keySet since client "
                             "parameters has not been computed");
  }

  // Prepare LLVM infrastructure for JIT compilation. Both initializers
  // return true when the corresponding native component is unavailable.
  if (llvm::InitializeNativeTarget() ||
      llvm::InitializeNativeTargetAsmPrinter()) {
    return StreamStringError(
        "Cannot initialize the native target for JIT compilation");
  }
  mlir::registerLLVMDialectTranslation(mlirContext);

  // Use the optimization level configured at construction rather than a
  // hard-coded -O3. Size level is 0 and no explicit target machine is given.
  auto optPipeline =
      mlir::makeOptimizingTransformer(this->optimizationLevel, 0, nullptr);

  llvm::Expected<std::unique_ptr<JITLambda>> lambdaOrErr =
      mlir::concretelang::JITLambda::create(funcName, module, optPipeline,
                                            runtimeLibPath);
  if (!lambdaOrErr) {
    return StreamStringError()
           << "Cannot create lambda: " << lambdaOrErr.takeError();
  }
  auto lambda = std::move(lambdaOrErr.get());

  // Generate the KeySet for encrypting lambda arguments and decrypting
  // lambda results.
  std::shared_ptr<KeySetCache> cachePtr;
  if (cache.hasValue()) {
    cachePtr = std::make_shared<KeySetCache>(cache.getValue());
  }

  // NOTE(review): the two trailing arguments are hard-coded to 0. They are
  // presumably key-generation seeds, which would make key generation
  // deterministic. Confirm this is intended (e.g. JIT/testing only) before
  // relying on these keys for security.
  auto keySetOrErr =
      KeySetCache::generate(cachePtr, *compRes.clientParameters, 0, 0);
  if (!keySetOrErr) {
    return StreamStringError(keySetOrErr.error().mesg);
  }
  auto keySet = std::move(keySetOrErr.value());

  return Lambda{this->compilationContext, std::move(lambda), std::move(keySet),
                *compRes.clientParameters};
}
} // namespace concretelang
} // namespace mlir