enhance(compiler/support): Refactor lambda support to have factorized support for both lib and jit lambda

This commit is contained in:
Quentin Bourgerie
2022-03-10 13:51:47 +01:00
parent af4851a72b
commit f8968eb489
11 changed files with 687 additions and 304 deletions

View File

@@ -36,9 +36,20 @@ public:
template <typename... Args>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(KeySet &keySet, Args... args) {
auto arguments = std::make_unique<EncryptedArguments>();
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
return arguments;
auto encryptedArgs = std::make_unique<EncryptedArguments>();
OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...));
return std::move(encryptedArgs);
}
template <typename ArgT>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(KeySet &keySet, const llvm::ArrayRef<ArgT> args) {
auto encryptedArgs = EncryptedArguments::empty();
for (size_t i = 0; i < args.size(); i++) {
OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet));
}
OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet));
return std::move(encryptedArgs);
}
static std::unique_ptr<EncryptedArguments> empty() {

View File

@@ -168,8 +168,12 @@ public:
compile(llvm::SourceMgr &sm, Target target,
llvm::Optional<std::shared_ptr<Library>> lib = {});
template <class T>
llvm::Expected<CompilerEngine::Library> compile(std::vector<T> inputs,
llvm::Expected<CompilerEngine::Library>
compile(std::vector<std::string> inputs, std::string libraryPath);
/// Compile and emit artifact to the given libraryPath from an LLVM source
/// manager.
llvm::Expected<CompilerEngine::Library> compile(llvm::SourceMgr &sm,
std::string libraryPath);
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);

View File

@@ -11,6 +11,7 @@
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LambdaArgument.h>
#include <concretelang/Support/LambdaSupport.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
namespace mlir {
@@ -19,200 +20,6 @@ namespace concretelang {
using ::concretelang::clientlib::KeySetCache;
namespace clientlib = ::concretelang::clientlib;
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result);
// Specialization of `typedResult()` for scalar results, forwarding
// scalar value to caller
template <>
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedResult cannot get clear text vector")
<< clearResult.error().mesg;
}
if (clearResult.value().size() != 1) {
return StreamStringError("typedResult expect only one value but got ")
<< clearResult.value().size();
}
return clearResult.value()[0];
}
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (!clearResult.has_value()) {
return StreamStringError("typedVectorResult cannot get clear text vector")
<< clearResult.error().mesg;
}
return std::move(clearResult.value());
}
// Specializations of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
//
// Cannot factor out into a template template <typename T> inline
// llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
// to ambiguity with scalar template
// template <>
// inline llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint8_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint16_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint16_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint32_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint32_t>(keySet, result);
// }
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
return typedVectorResult<uint64_t>(keySet, result);
}
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(clientlib::KeySet &keySet,
clientlib::PublicResult &result) {
llvm::Expected<std::vector<T>> tensorOrError =
typedResult<std::vector<T>>(keySet, result);
if (auto err = tensorOrError.takeError())
return std::move(err);
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
result.buffers[0].sizes.end() - 1);
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
*tensorOrError, tensorDim);
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
auto gate = keySet.outputGate(0);
// scalar case
if (gate.shape.dimensions.empty()) {
auto clearResult = result.asClearTextVector(keySet, 0);
if (clearResult.has_error()) {
return StreamStringError("typedResult: ") << clearResult.error().mesg;
}
auto res = clearResult.value()[0];
return std::make_unique<IntLambdaArgument<uint64_t>>(res);
}
// tensor case
// auto width = gate.shape.width;
// if (width > 32)
return buildTensorLambdaResult<uint64_t>(keySet, result);
// else if (width > 16)
// return buildTensorLambdaResult<uint32_t>(keySet, result);
// else if (width > 8)
// return buildTensorLambdaResult<uint16_t>(keySet, result);
// else if (width <= 8)
// return buildTensorLambdaResult<uint8_t>(keySet, result);
// return StreamStringError("Cannot handle scalars with more than 64 bits");
}
} // namespace
// Adaptor class that push arguments specified as instances of
// `LambdaArgument` to `clientlib::EncryptedArguments`.
class JITLambdaArgumentAdaptor {
public:
// Checks if the argument `arg` is an plaintext / encrypted integer
// argument or a plaintext / encrypted tensor argument with a
// backing integer type `IntT` and push the argument to `encryptedArgs`.
//
// Returns `true` if `arg` has one of the types above and its value
// was successfully added to `encryptedArgs`, `false` if none of the types
// matches or an error if a type matched, but adding the argument to
// `encryptedArgs` failed.
template <typename IntT>
static inline llvm::Expected<bool>
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
} else if (auto tla = arg.dyn_cast<
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
auto res =
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
if (!res.has_value()) {
return StreamStringError(res.error().mesg);
} else {
return true;
}
}
return false;
}
// Recursive case for `tryAddArg<IntT>(...)`
template <typename IntT, typename NextIntT, typename... IntTs>
static inline llvm::Expected<bool>
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
llvm::Expected<bool> successOrError =
tryAddArg<IntT>(encryptedArgs, arg, keySet);
if (!successOrError)
return successOrError.takeError();
if (successOrError.get() == false)
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
else
return true;
}
// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
// error if either the argument type is unsupported or if the argument types
// is supported, but adding it to `encryptedArgs` failed.
static inline llvm::Error
addArgument(clientlib::EncryptedArguments &encryptedArgs,
const LambdaArgument &arg, clientlib::KeySet &keySet) {
// Try the supported integer types; size_t needs explicit
// treatment, since it may alias none of the fixed size integer
// types
llvm::Expected<bool> successOrError =
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
uint8_t, size_t>(encryptedArgs, arg,
keySet);
if (!successOrError)
return successOrError.takeError();
if (successOrError.get() == false)
return StreamStringError("Unknown argument type");
else
return llvm::Error::success();
}
};
// A compiler engine that JIT-compiles a source and produces a lambda
// object directly invocable through its call operator.
class JitCompilerEngine : public CompilerEngine {
@@ -246,32 +53,16 @@ public:
// Invocation with an dynamic list of arguments of different
// types, specified as `LambdaArgument`s
template <typename ResT = uint64_t>
llvm::Expected<ResT>
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
// Encrypt the arguments
auto encryptedArgs = clientlib::EncryptedArguments::empty();
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
args, clientParameters, *this->keySet);
for (size_t i = 0; i < lambdaArgs.size(); i++) {
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
*encryptedArgs, *lambdaArgs[i], *this->keySet)) {
return std::move(err);
}
}
auto check = encryptedArgs->checkAllArgs(*this->keySet);
if (check.has_error()) {
return StreamStringError(check.error().mesg);
}
// Export as public arguments
auto publicArguments = encryptedArgs->exportPublicArguments(
clientParameters, keySet->runtimeContext());
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);
if (auto err = publicArguments.takeError()) {
return err;
}
// Call the lambda
auto publicResult = this->innerLambda->call(*publicArguments.value());
auto publicResult = this->innerLambda->call(**publicArguments);
if (auto err = publicResult.takeError()) {
return std::move(err);
}
@@ -283,22 +74,13 @@ public:
template <typename T, typename ResT = uint64_t>
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
// Encrypt the arguments
auto encryptedArgs = clientlib::EncryptedArguments::empty();
for (size_t i = 0; i < args.size(); i++) {
auto res = encryptedArgs->pushArg(args[i], *keySet);
if (res.has_error()) {
return StreamStringError(res.error().mesg);
}
}
auto check = encryptedArgs->checkAllArgs(*this->keySet);
if (check.has_error()) {
return StreamStringError(check.error().mesg);
auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
if (encryptedArgs.has_error()) {
return StreamStringError(encryptedArgs.error().mesg);
}
// Export as public arguments
auto publicArguments = encryptedArgs->exportPublicArguments(
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
clientParameters, keySet->runtimeContext());
if (!publicArguments.has_value()) {
return StreamStringError(publicArguments.error().mesg);

View File

@@ -0,0 +1,69 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT
#define CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LambdaSupport.h>
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
/// JitCompilationResult is the result of a Jit compilation: it owns the
/// server-side JIT lambda together with the clientParameters needed to
/// encrypt arguments and decrypt results for that lambda.
struct JitCompilationResult {
  /// The JIT-compiled entry point, invocable server-side.
  std::unique_ptr<concretelang::JITLambda> lambda;
  /// Parameters the client needs to build key sets and encrypt arguments.
  clientlib::ClientParameters clientParameters;
};
/// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation.
/// The "server lambda" is a raw pointer into the JitCompilationResult, so the
/// result must outlive any lambda obtained from loadServerLambda.
class JitLambdaSupport
    : public LambdaSupport<concretelang::JITLambda *, JitCompilationResult> {
public:
  /// @param runtimeLibPath Optional path to the runtime library to link the
  ///                       JIT module against (none by default).
  /// @param llvmOptPipeline LLVM optimization pipeline applied to the module;
  ///                        defaults to -O3 via makeOptimizingTransformer.
  JitLambdaSupport(
      llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None,
      llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline =
          mlir::makeOptimizingTransformer(3, 0, nullptr))
      : runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}

  /// JIT-compiles `program` and returns the lambda plus client parameters
  /// (defined in JitLambdaSupport.cpp).
  llvm::Expected<std::unique_ptr<JitCompilationResult>>
  compile(llvm::SourceMgr &program, std::string funcname = "main") override;

  // Re-expose the StringRef/MemoryBuffer convenience overloads of the base.
  using LambdaSupport::compile;

  /// Returns a non-owning pointer to the JIT lambda held by `result`.
  llvm::Expected<concretelang::JITLambda *>
  loadServerLambda(JitCompilationResult &result) override {
    return result.lambda.get();
  }

  /// Returns a copy of the client parameters generated by the compilation.
  llvm::Expected<clientlib::ClientParameters>
  loadClientParameters(JitCompilationResult &result) override {
    return result.clientParameters;
  }

  /// Invokes the JIT lambda on already-encrypted public arguments.
  llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
  serverCall(concretelang::JITLambda *lambda,
             clientlib::PublicArguments &args) override {
    return lambda->call(args);
  }

private:
  llvm::Optional<llvm::StringRef> runtimeLibPath;
  llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,318 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT
#define CONCRETELANG_SUPPORT_LAMBDASUPPORT
#include "boost/outcome.h"
#include "concretelang/Support/LambdaArgument.h"
#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/ServerLambda.h"
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization
// Helper function for `JitCompilerEngine::Lambda::operator()`
// implementing type-dependent preparation of the result.
// Generic declaration of `typedResult`: decrypts `result` with `keySet`
// and converts it to the caller-requested type `ResT`. Only the
// specializations below are defined.
template <typename ResT>
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
                                 clientlib::PublicResult &result);

// Specialization of `typedResult()` for scalar results, forwarding
// scalar value to caller.
template <>
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
                                            clientlib::PublicResult &result) {
  // Decrypt output position 0 into a clear-text vector.
  auto clearResult = result.asClearTextVector(keySet, 0);
  if (!clearResult.has_value()) {
    return StreamStringError("typedResult cannot get clear text vector")
           << clearResult.error().mesg;
  }
  // A scalar result must decrypt to exactly one value.
  if (clearResult.value().size() != 1) {
    return StreamStringError("typedResult expect only one value but got ")
           << clearResult.value().size();
  }
  return clearResult.value()[0];
}
/// Decrypts output position 0 with `keySet` and hands the whole
/// clear-text vector back to the caller (moved out, no copy).
template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  auto decrypted = result.asClearTextVector(keySet, 0);
  if (decrypted.has_value()) {
    return std::move(decrypted.value());
  }
  return StreamStringError("typedVectorResult cannot get clear text vector")
         << decrypted.error().mesg;
}
// Specializations of `typedResult()` for vector results, initializing
// an `std::vector` of the right size with the results and forwarding
// it to the caller with move semantics.
//
// Cannot factor out into a template template <typename T> inline
// llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
// to ambiguity with scalar template
// template <>
// inline llvm::Expected<std::vector<uint8_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint8_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint16_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint16_t>(keySet, result);
// }
// template <>
// inline llvm::Expected<std::vector<uint32_t>>
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
// return typedVectorResult<uint32_t>(keySet, result);
// }
// Specialization of `typedResult()` for 64-bit vector results; the
// actual decryption is delegated to `typedVectorResult`.
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<uint64_t>(keySet, result);
}
/// Decrypts a tensor result and wraps it into a `TensorLambdaArgument`
/// carrying both the flat values and the tensor dimensions.
template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(clientlib::KeySet &keySet,
                        clientlib::PublicResult &result) {
  auto valuesOrErr = typedResult<std::vector<T>>(keySet, result);
  if (!valuesOrErr)
    return valuesOrErr.takeError();
  // Drop the last size entry when building the dimensions.
  // NOTE(review): assumption inferred from the `end() - 1`; confirm the
  // trailing entry against the result buffer layout.
  auto &sizes = result.buffers[0].sizes;
  std::vector<int64_t> dimensions(sizes.begin(), sizes.end() - 1);
  return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
      *valuesOrErr, dimensions);
}
// Specialization of `typedResult()` for a single result wrapped into
// a `LambdaArgument`: dispatches on the output gate shape between the
// scalar and the tensor case.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  // The shape of output 0 decides how the result is wrapped.
  auto gate = keySet.outputGate(0);
  // scalar case: decrypt and wrap the single value.
  if (gate.shape.dimensions.empty()) {
    auto clearResult = result.asClearTextVector(keySet, 0);
    if (clearResult.has_error()) {
      return StreamStringError("typedResult: ") << clearResult.error().mesg;
    }
    auto res = clearResult.value()[0];
    return std::make_unique<IntLambdaArgument<uint64_t>>(res);
  }
  // tensor case: every width is currently decrypted as uint64_t; the
  // width-based dispatch below is kept commented out for reference.
  // auto width = gate.shape.width;
  // if (width > 32)
  return buildTensorLambdaResult<uint64_t>(keySet, result);
  // else if (width > 16)
  //   return buildTensorLambdaResult<uint32_t>(keySet, result);
  // else if (width > 8)
  //   return buildTensorLambdaResult<uint16_t>(keySet, result);
  // else if (width <= 8)
  //   return buildTensorLambdaResult<uint8_t>(keySet, result);
  // return StreamStringError("Cannot handle scalars with more than 64 bits");
}
} // namespace
// Adaptor class that push arguments specified as instances of
// `LambdaArgument` to `clientlib::EncryptedArguments`.
class LambdaArgumentAdaptor {
public:
  // Checks if the argument `arg` is an plaintext / encrypted integer
  // argument or a plaintext / encrypted tensor argument with a
  // backing integer type `IntT` and push the argument to `encryptedArgs`.
  //
  // Returns `true` if `arg` has one of the types above and its value
  // was successfully added to `encryptedArgs`, `false` if none of the types
  // matches or an error if a type matched, but adding the argument to
  // `encryptedArgs` failed.
  template <typename IntT>
  static inline llvm::Expected<bool>
  tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
            const LambdaArgument &arg, clientlib::KeySet &keySet) {
    // Scalar integer argument of backing type IntT.
    if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
      auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
      if (!res.has_value()) {
        return StreamStringError(res.error().mesg);
      } else {
        return true;
      }
      // Tensor argument backed by IntT elements.
    } else if (auto tla = arg.dyn_cast<
                   TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
      auto res =
          encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
      if (!res.has_value()) {
        return StreamStringError(res.error().mesg);
      } else {
        return true;
      }
    }
    // Type did not match IntT: let the caller try the next candidate type.
    return false;
  }

  // Recursive case for `tryAddArg<IntT>(...)`: tries IntT first, then
  // falls through to the remaining candidate types on a non-match.
  template <typename IntT, typename NextIntT, typename... IntTs>
  static inline llvm::Expected<bool>
  tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
            const LambdaArgument &arg, clientlib::KeySet &keySet) {
    llvm::Expected<bool> successOrError =
        tryAddArg<IntT>(encryptedArgs, arg, keySet);
    if (!successOrError)
      return successOrError.takeError();
    if (successOrError.get() == false)
      return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
    else
      return true;
  }

  // Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
  // error if either the argument type is unsupported or if the argument types
  // is supported, but adding it to `encryptedArgs` failed.
  static inline llvm::Error
  addArgument(clientlib::EncryptedArguments &encryptedArgs,
              const LambdaArgument &arg, clientlib::KeySet &keySet) {
    // Try the supported integer types; size_t needs explicit
    // treatment, since it may alias none of the fixed size integer
    // types
    llvm::Expected<bool> successOrError =
        LambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t, uint8_t,
                                         size_t>(encryptedArgs, arg, keySet);
    if (!successOrError)
      return successOrError.takeError();
    if (successOrError.get() == false)
      return StreamStringError("Unknown argument type");
    else
      return llvm::Error::success();
  }

  /// Encrypts and build public arguments from lambda arguments:
  /// pushes every argument, checks completeness against the key set,
  /// then exports the encrypted arguments as public arguments.
  static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
  exportArguments(llvm::ArrayRef<LambdaArgument *> args,
                  clientlib::ClientParameters clientParameters,
                  clientlib::KeySet &keySet) {
    auto encryptedArgs = clientlib::EncryptedArguments::empty();
    for (auto arg : args) {
      if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg,
                                                        keySet)) {
        return std::move(err);
      }
    }
    // Verify all expected arguments were provided before exporting.
    auto check = encryptedArgs->checkAllArgs(keySet);
    if (check.has_error()) {
      return StreamStringError(check.error().mesg);
    }
    auto publicArguments = encryptedArgs->exportPublicArguments(
        clientParameters, keySet.runtimeContext());
    if (publicArguments.has_error()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    return std::move(publicArguments.value());
  }
};
/// Common base for JIT and library lambda support. `Lambda` is the
/// server-side callable type and `CompilationResult` the backend-specific
/// compilation artifact.
template <typename Lambda, typename CompilationResult> class LambdaSupport {
public:
  virtual ~LambdaSupport() {}

  /// Compile the mlir program and produces a compilation result if succeed.
  /// NOTE(review): declared virtual without `= 0` and without a definition
  /// in this header — subclasses are expected to override; confirm no code
  /// path calls the base version.
  llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
      llvm::SourceMgr &program, std::string funcname = "main");

  /// Convenience overload: compiles a program held in a string.
  llvm::Expected<std::unique_ptr<CompilationResult>>
  compile(llvm::StringRef program, std::string funcname = "main") {
    return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname);
  }

  /// Convenience overload: compiles a program held in a memory buffer.
  llvm::Expected<std::unique_ptr<CompilationResult>>
  compile(std::unique_ptr<llvm::MemoryBuffer> program,
          std::string funcname = "main") {
    llvm::SourceMgr sm;
    sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
    return compile(sm, funcname);
  }

  /// Load the server lambda from the compilation result.
  llvm::Expected<Lambda> virtual loadServerLambda(CompilationResult &result);

  /// Load the client parameters from the compilation result.
  llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
      CompilationResult &result);

  /// Call the lambda with the public arguments.
  llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
      Lambda lambda, clientlib::PublicArguments &args);

  /// Build the client KeySet from the client parameters, optionally going
  /// through the key-set cache.
  static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
  keySet(clientlib::ClientParameters clientParameters,
         llvm::Optional<clientlib::KeySetCache> cache) {
    std::shared_ptr<clientlib::KeySetCache> cachePtr;
    if (cache.hasValue()) {
      cachePtr = std::make_shared<clientlib::KeySetCache>(cache.getValue());
    }
    // Seeds (0, 0): fixed seeds — presumably deterministic key generation
    // for tests/cache; verify against KeySetCache::generate.
    auto keySet =
        clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0);
    if (keySet.has_error()) {
      return StreamStringError(keySet.error().mesg);
    }
    return std::move(keySet.value());
  }

  /// Encrypts `args` and exports them as public arguments; thin wrapper
  /// over LambdaArgumentAdaptor::exportArguments.
  static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
  exportArguments(clientlib::ClientParameters clientParameters,
                  clientlib::KeySet &keySet,
                  llvm::ArrayRef<LambdaArgument *> args) {
    return LambdaArgumentAdaptor::exportArguments(args, clientParameters,
                                                  keySet);
  }

  /// Calls the lambda and decrypts the result as `ResT`.
  /// NOTE(review): this member looks unfinished — `keySet` is not declared
  /// anywhere in this static function (it will not compile if instantiated),
  /// and it default-constructs LambdaSupport to dispatch `serverCall` to a
  /// virtual with no base definition. It likely needs a KeySet parameter and
  /// a concrete support instance; confirm before first use.
  template <typename ResT>
  static llvm::Expected<ResT>
  call(Lambda lambda, clientlib::PublicArguments &publicArguments) {
    // Call the lambda
    auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
        lambda, publicArguments);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    // Decrypt the result
    return typedResult<ResT>(keySet, **publicResult);
  }
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -0,0 +1,103 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT
#define CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/ExecutionEngine/ExecutionEngine.h>
#include <mlir/ExecutionEngine/OptUtils.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include <concretelang/ServerLib/ServerLambda.h>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/LambdaSupport.h>
namespace mlir {
namespace concretelang {
namespace clientlib = ::concretelang::clientlib;
namespace serverlib = ::concretelang::serverlib;
/// LibraryCompilationResult is the result of a compilation to a library.
struct LibraryCompilationResult {
  /// The output path where the compilation artifact has been generated.
  std::string libraryPath;
  /// Name of the compiled entry function, used to locate the server
  /// lambda and its client parameters inside the library.
  std::string funcName;
};
/// LambdaSupport instantiation that compiles to an on-disk library and
/// loads the server lambda back from it.
class LibraryLambdaSupport
    : public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
public:
  /// @param outputPath Directory/prefix where library artifacts are emitted.
  /// NOTE(review): "/tmp/toto" is a placeholder default — a fixed,
  /// world-shared temp path; callers should pass an explicit path.
  LibraryLambdaSupport(std::string outputPath = "/tmp/toto")
      : outputPath(outputPath) {}

  /// Compiles `program` to a library at `outputPath` and returns the
  /// path/function pair needed to load it back.
  llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
  compile(llvm::SourceMgr &program, std::string funcname = "main") override {
    // Setup the compiler engine
    auto context = CompilationContext::createShared();
    concretelang::CompilerEngine engine(context);
    engine.setClientParametersFuncName(funcname);
    // Compile to a library
    auto library = engine.compile(program, outputPath);
    if (auto err = library.takeError()) {
      return std::move(err);
    }
    auto result = std::make_unique<LibraryCompilationResult>();
    result->libraryPath = outputPath;
    result->funcName = funcname;
    return std::move(result);
  }

  // Re-expose the StringRef/MemoryBuffer convenience overloads of the base.
  using LambdaSupport::compile;

  /// Load the server lambda from the compilation result.
  llvm::Expected<serverlib::ServerLambda>
  loadServerLambda(LibraryCompilationResult &result) override {
    auto lambda =
        serverlib::ServerLambda::load(result.funcName, result.libraryPath);
    if (lambda.has_error()) {
      return StreamStringError(lambda.error().mesg);
    }
    return lambda.value();
  }

  /// Load the client parameters from the compilation result.
  llvm::Expected<clientlib::ClientParameters>
  loadClientParameters(LibraryCompilationResult &result) override {
    auto path = ClientParameters::getClientParametersPath(result.libraryPath);
    auto params = ClientParameters::load(path);
    if (params.has_error()) {
      return StreamStringError(params.error().mesg);
    }
    // The file may hold parameters for several functions; pick ours.
    auto param = llvm::find_if(params.value(), [&](ClientParameters param) {
      return param.functionName == result.funcName;
    });
    if (param == params.value().end()) {
      return StreamStringError("ClientLambda: cannot find function(")
             << result.funcName << ") in client parameters path(" << path
             << ")";
    }
    return *param;
  }

  /// Call the lambda with the public arguments.
  llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
  serverCall(serverlib::ServerLambda lambda,
             clientlib::PublicArguments &args) override {
    return lambda.call(args);
  }

private:
  std::string outputPath;
};
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -3,6 +3,7 @@ add_mlir_library(ConcretelangSupport
Jit.cpp
CompilerEngine.cpp
JitCompilerEngine.cpp
JitLambdaSupport.cpp
LambdaArgument.cpp
V0Parameters.cpp
V0Curves.cpp
@@ -32,4 +33,5 @@ add_mlir_library(ConcretelangSupport
ConcretelangRuntime
ConcretelangClientLib
ConcretelangServerLib
)

View File

@@ -383,19 +383,17 @@ llvm::Expected<CompilerEngine::CompilationResult>
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
Target target, OptionalLib lib) {
llvm::SourceMgr sm;
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
return this->compile(sm, target, lib);
}
template <class T>
llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
CompilerEngine::compile(std::vector<std::string> inputs,
std::string libraryPath) {
using Library = mlir::concretelang::CompilerEngine::Library;
auto outputLib = std::make_shared<Library>(libraryPath);
auto target = CompilerEngine::Target::LIBRARY;
for (auto input : inputs) {
auto compilation = compile(input, target, outputLib);
if (!compilation) {
@@ -403,6 +401,24 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
<< llvm::toString(compilation.takeError());
}
}
if (auto err = outputLib->emitArtifacts()) {
return StreamStringError("Can't emit artifacts: ")
<< llvm::toString(std::move(err));
}
return *outputLib.get();
}
llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(llvm::SourceMgr &sm, std::string libraryPath) {
using Library = mlir::concretelang::CompilerEngine::Library;
auto outputLib = std::make_shared<Library>(libraryPath);
auto target = CompilerEngine::Target::LIBRARY;
auto compilation = compile(sm, target, outputLib);
if (!compilation) {
return StreamStringError("Can't compile: ")
<< llvm::toString(compilation.takeError());
}
if (auto err = outputLib->emitArtifacts()) {
return StreamStringError("Can't emit artifacts: ")
@@ -411,11 +427,6 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
return *outputLib.get();
}
// explicit instantiation for a vector of string (for linking with lib/CAPI)
template llvm::Expected<CompilerEngine::Library>
CompilerEngine::compile(std::vector<std::string> inputs,
std::string libraryPath);
/** Returns the path of the shared library */
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
return path + DOT_SHARED_LIB_EXT;

View File

@@ -0,0 +1,49 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include <concretelang/Support/JitLambdaSupport.h>
namespace mlir {
namespace concretelang {
/// JIT-compiles `program` down to a `JITLambda` and bundles it with the
/// generated client parameters.
///
/// @param program  Source manager holding the MLIR program to compile.
/// @param funcname Entry function to compile; client parameters are
///                 generated for it.
/// @return A JitCompilationResult owning the JIT lambda and the client
///         parameters, or an error if any compilation stage fails.
llvm::Expected<std::unique_ptr<JitCompilationResult>>
JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
  // Setup the compiler engine. Use the `createShared` factory (rather
  // than a bare std::make_shared) for consistency with how
  // LibraryLambdaSupport builds its CompilationContext.
  auto context = CompilationContext::createShared();
  concretelang::CompilerEngine engine(context);
  // We need client parameters to be generated alongside the code.
  engine.setGenerateClientParameters(true);
  engine.setClientParametersFuncName(funcname);
  // Compile to LLVM Dialect
  auto compilationResult =
      engine.compile(program, CompilerEngine::Target::LLVM_IR);
  if (auto err = compilationResult.takeError()) {
    return std::move(err);
  }
  // Compile from LLVM Dialect to JITLambda
  auto mlirModule = compilationResult.get().mlirModuleRef->get();
  auto lambda = concretelang::JITLambda::create(
      funcname, mlirModule, llvmOptPipeline, runtimeLibPath);
  if (auto err = lambda.takeError()) {
    return std::move(err);
  }
  if (!compilationResult.get().clientParameters.hasValue()) {
    // Should never happen: generation was explicitly requested above.
    return StreamStringError("No client parameters has been generated");
  }
  auto result = std::make_unique<JitCompilationResult>();
  result->lambda = std::move(*lambda);
  result->clientParameters =
      compilationResult.get().clientParameters.getValue();
  return std::move(result);
}
} // namespace concretelang
} // namespace mlir

View File

@@ -61,8 +61,8 @@ valueDescriptionToLambdaArgument(ValueDescription desc) {
}
llvm::Error checkResult(ScalarDesc &desc,
mlir::concretelang::LambdaArgument *res) {
auto res64 = res->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
mlir::concretelang::LambdaArgument &res) {
auto res64 = res.dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
if (res64 == nullptr) {
return StreamStringError("invocation result is not a scalar");
}
@@ -85,7 +85,7 @@ checkTensorResult(TensorDescription &desc,
<< resShape.size() << " expected " << desc.shape.size();
}
for (size_t i = 0; i < desc.shape.size(); i++) {
if ((uint64_t)resShape[i] != desc.shape[i]) {
if (resShape[i] != desc.shape[i]) {
return StreamStringError("shape differs at pos ")
<< i << ", got " << resShape[i] << " expected " << desc.shape[i];
}
@@ -112,37 +112,36 @@ checkTensorResult(TensorDescription &desc,
}
llvm::Error checkResult(TensorDescription &desc,
mlir::concretelang::LambdaArgument *res) {
mlir::concretelang::LambdaArgument &res) {
switch (desc.width) {
case 8:
return checkTensorResult<uint8_t>(
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint8_t>>>());
case 16:
return checkTensorResult<uint16_t>(
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint16_t>>>());
case 32:
return checkTensorResult<uint32_t>(
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint32_t>>>());
case 64:
return checkTensorResult<uint64_t>(
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
mlir::concretelang::IntLambdaArgument<uint64_t>>>());
default:
return StreamStringError("Unsupported width");
}
}
llvm::Error
checkResult(ValueDescription &desc,
std::unique_ptr<mlir::concretelang::LambdaArgument> &res) {
llvm::Error checkResult(ValueDescription &desc,
mlir::concretelang::LambdaArgument &res) {
switch (desc.tag) {
case ValueDescription::SCALAR:
return checkResult(desc.scalar, res.get());
return checkResult(desc.scalar, res);
case ValueDescription::TENSOR:
return checkResult(desc.tensor, res.get());
return checkResult(desc.tensor, res);
}
assert(false);
}
@@ -190,7 +189,7 @@ template <> struct llvm::yaml::MappingTraits<EndToEndDesc> {
}
};
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc);
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc)
std::vector<EndToEndDesc> loadEndToEndDesc(std::string path) {
std::ifstream file(path);

View File

@@ -4,53 +4,88 @@
#include <type_traits>
#include "EndToEndFixture.h"
#include "concretelang/Support/JitLambdaSupport.h"
#include "concretelang/Support/LibraryLambdaSupport.h"
/// Parameterized end-to-end fixture: each test instance receives one
/// EndToEndDesc loaded from a YAML description file.
class EndToEndJitTest : public testing::TestWithParam<EndToEndDesc> {};

/// JIT-compiles the described program, then for every test entry builds the
/// input arguments, invokes the compiled lambda, and checks the single
/// expected output against the actual result.
TEST_P(EndToEndJitTest, compile_and_run) {
  EndToEndDesc desc = GetParam();
  // Compile program
  // mlir::concretelang::JitCompilerEngine::Lambda lambda =
  checkedJit(lambda, desc.program);
  // Prepare arguments
  for (auto test : desc.tests) {
    // Raw owning pointers, released by the `delete` loop at the end of this
    // iteration.
    // NOTE(review): any ASSERT_* failure before that loop returns early and
    // leaks the arguments — unique_ptr ownership would be safer; confirm.
    std::vector<mlir::concretelang::LambdaArgument *> inputArguments;
    inputArguments.reserve(test.inputs.size());
    for (auto input : test.inputs) {
      auto arg = valueDescriptionToLambdaArgument(input);
      ASSERT_EXPECTED_SUCCESS(arg);
      inputArguments.push_back(arg.get());
    }
    // Call the lambda
    auto res =
        lambda.operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>(
            llvm::ArrayRef<mlir::concretelang::LambdaArgument *>(
                inputArguments));
    ASSERT_EXPECTED_SUCCESS(res);
    // The checker below only supports a single result value.
    if (test.outputs.size() != 1) {
      FAIL() << "Only one result function are supported.";
    }
    ASSERT_LLVM_ERROR(checkResult(test.outputs[0], res.get()));
    // Free arguments
    for (auto arg : inputArguments) {
      delete arg;
    }
  }
}
#define INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(prefix, path) \
namespace prefix { \
auto valuesVector = loadEndToEndDesc(path); \
auto values = testing::ValuesIn<std::vector<EndToEndDesc>>(valuesVector); \
INSTANTIATE_TEST_SUITE_P(prefix, EndToEndJitTest, values, \
printEndToEndDesc); \
// Macro to define and end to end TestSuite that run test thanks the
// LambdaSupport according a EndToEndDesc
#define INSTANTIATE_END_TO_END_COMPILE_AND_RUN(TestSuite, LambdaSupport) \
TEST_P(TestSuite, compile_and_run) { \
\
auto desc = GetParam(); \
\
LambdaSupport support; \
\
/* 1 - Compile the program */ \
auto compilationResult = support.compile(desc.program); \
ASSERT_EXPECTED_SUCCESS(compilationResult); \
\
/* 2 - Load the client parameters and build the keySet */ \
auto clientParameters = support.loadClientParameters(**compilationResult); \
ASSERT_EXPECTED_SUCCESS(clientParameters); \
\
auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); \
ASSERT_EXPECTED_SUCCESS(keySet); \
\
/* 3 - Load the server lambda */ \
auto serverLambda = support.loadServerLambda(**compilationResult); \
ASSERT_EXPECTED_SUCCESS(serverLambda); \
\
/* For each test entries */ \
for (auto test : desc.tests) { \
std::vector<mlir::concretelang::LambdaArgument *> inputArguments; \
inputArguments.reserve(test.inputs.size()); \
for (auto input : test.inputs) { \
auto arg = valueDescToLambdaArgument(input); \
ASSERT_EXPECTED_SUCCESS(arg); \
inputArguments.push_back(arg.get()); \
} \
/* 4 - Create the public arguments */ \
auto publicArguments = support.exportArguments( \
*clientParameters, **keySet, inputArguments); \
ASSERT_EXPECTED_SUCCESS(publicArguments); \
\
/* 5 - Call the server lambda */ \
auto publicResult = \
support.serverCall(*serverLambda, **publicArguments); \
ASSERT_EXPECTED_SUCCESS(publicResult); \
\
/* 6 - Decrypt the public result */ \
auto result = mlir::concretelang::typedResult< \
std::unique_ptr<mlir::concretelang::LambdaArgument>>( \
**keySet, **publicResult); \
\
ASSERT_EXPECTED_SUCCESS(result); \
\
for (auto arg : inputArguments) { \
delete arg; \
} \
} \
}
INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(
FHE, "tests/unittest/end_to_end_fhe.yaml")
INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(
EncryptedTensor, "tests/unittest/end_to_end_encrypted_tensor.yaml")
// Macro instantiating the parameterized `suite` with the EndToEndDesc entries
// loaded from the YAML file at `path`. The `prefix##suite` namespace isolates
// the per-file `valuesVector`/`values` globals between instantiations.
// NOTE(review): the `lambdasupport` parameter is unused here — kept for
// signature symmetry with the other macros? Confirm.
#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(prefix, suite,             \
                                                    lambdasupport, path)       \
  namespace prefix##suite {                                                    \
  auto valuesVector = loadEndToEndDesc(path);                                  \
  auto values = testing::ValuesIn<std::vector<EndToEndDesc>>(valuesVector);    \
  INSTANTIATE_TEST_SUITE_P(prefix, suite, values, printEndToEndDesc);          \
  }
// Macro declaring the parameterized fixture `suite`, defining its
// compile_and_run test body for the given `lambdasupport`, and instantiating
// it with every known end-to-end YAML description file.
#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(suite,           \
                                                              lambdasupport)   \
                                                                               \
  class suite : public testing::TestWithParam<EndToEndDesc> {};                \
  INSTANTIATE_END_TO_END_COMPILE_AND_RUN(suite, lambdasupport)                 \
  INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(                                 \
      FHE, suite, lambdasupport, "tests/unittest/end_to_end_fhe.yaml")         \
  INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(                                 \
      EncryptedTensor, suite, lambdasupport,                                   \
      "tests/unittest/end_to_end_encrypted_tensor.yaml")
/// Instantiate the test suite for Jit
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
JitTest, mlir::concretelang::JitLambdaSupport)
/// Instantiate the test suite for Library
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
LibraryTest, mlir::concretelang::LibraryLambdaSupport)