mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
164 lines
5.9 KiB
C++
164 lines
5.9 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
|
|
|
|
#ifndef CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
|
#define CONCRETELANG_SUPPORT_JIT_COMPILER_ENGINE_H
|
|
|
|
#include "concretelang/ClientLib/KeySetCache.h"
|
|
#include <concretelang/Support/CompilerEngine.h>
|
|
#include <concretelang/Support/Error.h>
|
|
#include <concretelang/Support/Jit.h>
|
|
#include <concretelang/Support/LambdaArgument.h>
|
|
#include <concretelang/Support/LambdaSupport.h>
|
|
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
|
|
|
namespace mlir {
|
|
namespace concretelang {
|
|
|
|
using ::concretelang::clientlib::KeySetCache;
|
|
namespace clientlib = ::concretelang::clientlib;
|
|
|
|
// A compiler engine that JIT-compiles a source and produces a lambda
|
|
// object directly invocable through its call operator.
|
|
class JitCompilerEngine : public CompilerEngine {
|
|
public:
|
|
// Wrapper class around `JITLambda` and `JITLambda::Argument` that
|
|
// allows for direct invocation of a compiled function through
|
|
// `operator ()`.
|
|
class Lambda {
|
|
public:
|
|
Lambda(Lambda &&other)
|
|
: innerLambda(std::move(other.innerLambda)),
|
|
keySet(std::move(other.keySet)),
|
|
compilationContext(other.compilationContext),
|
|
clientParameters(other.clientParameters) {}
|
|
|
|
Lambda(std::shared_ptr<CompilationContext> compilationContext,
|
|
std::unique_ptr<JITLambda> lambda, std::unique_ptr<KeySet> keySet,
|
|
clientlib::ClientParameters clientParameters)
|
|
: innerLambda(std::move(lambda)), keySet(std::move(keySet)),
|
|
compilationContext(compilationContext),
|
|
clientParameters(clientParameters) {}
|
|
|
|
// Returns the number of arguments required for an invocation of
|
|
// the lambda
|
|
size_t getNumArguments() { return this->keySet->numInputs(); }
|
|
|
|
// Returns the number of results an invocation of the lambda
|
|
// produces
|
|
size_t getNumResults() { return this->keySet->numOutputs(); }
|
|
|
|
// Invocation with an dynamic list of arguments of different
|
|
// types, specified as `LambdaArgument`s
|
|
template <typename ResT = uint64_t>
|
|
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
|
|
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
|
|
args, clientParameters, *this->keySet);
|
|
|
|
if (auto err = publicArguments.takeError()) {
|
|
return err;
|
|
}
|
|
|
|
// Call the lambda
|
|
auto publicResult = this->innerLambda->call(**publicArguments);
|
|
if (auto err = publicResult.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
|
|
return typedResult<ResT>(*keySet, **publicResult);
|
|
}
|
|
|
|
// Invocation with an array of arguments of the same type
|
|
template <typename T, typename ResT = uint64_t>
|
|
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
|
// Encrypt the arguments
|
|
auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
|
|
if (encryptedArgs.has_error()) {
|
|
return StreamStringError(encryptedArgs.error().mesg);
|
|
}
|
|
|
|
// Export as public arguments
|
|
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
|
clientParameters, keySet->runtimeContext());
|
|
if (!publicArguments.has_value()) {
|
|
return StreamStringError(publicArguments.error().mesg);
|
|
}
|
|
|
|
// Call the lambda
|
|
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
|
if (auto err = publicResult.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
|
|
return typedResult<ResT>(*keySet, **publicResult);
|
|
}
|
|
|
|
// Invocation with arguments of different types
|
|
template <typename ResT = uint64_t, typename... Ts>
|
|
llvm::Expected<ResT> operator()(const Ts... ts) {
|
|
// Encrypt the arguments
|
|
auto encryptedArgs =
|
|
clientlib::EncryptedArguments::create(*keySet, ts...);
|
|
|
|
if (encryptedArgs.has_error()) {
|
|
return StreamStringError(encryptedArgs.error().mesg);
|
|
}
|
|
|
|
// Export as public arguments
|
|
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
|
clientParameters, keySet->runtimeContext());
|
|
if (!publicArguments.has_value()) {
|
|
return StreamStringError(publicArguments.error().mesg);
|
|
}
|
|
|
|
// Call the lambda
|
|
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
|
if (auto err = publicResult.takeError()) {
|
|
return std::move(err);
|
|
}
|
|
|
|
return typedResult<ResT>(*keySet, **publicResult);
|
|
}
|
|
|
|
protected:
|
|
std::unique_ptr<JITLambda> innerLambda;
|
|
std::unique_ptr<KeySet> keySet;
|
|
std::shared_ptr<CompilationContext> compilationContext;
|
|
const clientlib::ClientParameters clientParameters;
|
|
};
|
|
|
|
JitCompilerEngine(std::shared_ptr<CompilationContext> compilationContext =
|
|
CompilationContext::createShared(),
|
|
unsigned int optimizationLevel = 3);
|
|
|
|
/// Build a Lambda from a source MLIR, with `funcName` as entrypoint.
|
|
/// Use runtimeLibPath as a shared library if specified.
|
|
llvm::Expected<Lambda>
|
|
buildLambda(llvm::StringRef src, llvm::StringRef funcName = "main",
|
|
llvm::Optional<KeySetCache> cachePath = {},
|
|
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
|
|
|
llvm::Expected<Lambda>
|
|
buildLambda(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
|
llvm::StringRef funcName = "main",
|
|
llvm::Optional<KeySetCache> cachePath = {},
|
|
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
|
|
|
llvm::Expected<Lambda>
|
|
buildLambda(llvm::SourceMgr &sm, llvm::StringRef funcName = "main",
|
|
llvm::Optional<KeySetCache> cachePath = {},
|
|
llvm::Optional<llvm::StringRef> runtimeLibPath = {});
|
|
|
|
protected:
|
|
llvm::Expected<mlir::LLVM::LLVMFuncOp> findLLVMFuncOp(mlir::ModuleOp module,
|
|
llvm::StringRef name);
|
|
unsigned int optimizationLevel;
|
|
};
|
|
|
|
} // namespace concretelang
|
|
} // namespace mlir
|
|
|
|
#endif
|