mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(compiler/support): Refactor lambda support to have fatorized supoort for both lib and jit lambda
This commit is contained in:
@@ -36,9 +36,20 @@ public:
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(KeySet &keySet, Args... args) {
|
||||
auto arguments = std::make_unique<EncryptedArguments>();
|
||||
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
|
||||
return arguments;
|
||||
auto encryptedArgs = std::make_unique<EncryptedArguments>();
|
||||
OUTCOME_TRYV(encryptedArgs->pushArgs(keySet, args...));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
template <typename ArgT>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(KeySet &keySet, const llvm::ArrayRef<ArgT> args) {
|
||||
auto encryptedArgs = EncryptedArguments::empty();
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
OUTCOME_TRYV(encryptedArgs->pushArg(args[i], keySet));
|
||||
}
|
||||
OUTCOME_TRYV(encryptedArgs->checkAllArgs(keySet));
|
||||
return std::move(encryptedArgs);
|
||||
}
|
||||
|
||||
static std::unique_ptr<EncryptedArguments> empty() {
|
||||
|
||||
@@ -168,8 +168,12 @@ public:
|
||||
compile(llvm::SourceMgr &sm, Target target,
|
||||
llvm::Optional<std::shared_ptr<Library>> lib = {});
|
||||
|
||||
template <class T>
|
||||
llvm::Expected<CompilerEngine::Library> compile(std::vector<T> inputs,
|
||||
llvm::Expected<CompilerEngine::Library>
|
||||
compile(std::vector<std::string> inputs, std::string libraryPath);
|
||||
|
||||
/// Compile and emit artifact to the given libraryPath from an LLVM source
|
||||
/// manager.
|
||||
llvm::Expected<CompilerEngine::Library> compile(llvm::SourceMgr &sm,
|
||||
std::string libraryPath);
|
||||
|
||||
void setFHEConstraints(const mlir::concretelang::V0FHEConstraint &c);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LambdaArgument.h>
|
||||
#include <concretelang/Support/LambdaSupport.h>
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
|
||||
namespace mlir {
|
||||
@@ -19,200 +20,6 @@ namespace concretelang {
|
||||
using ::concretelang::clientlib::KeySetCache;
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
// type template specialization
|
||||
|
||||
// Helper function for `JitCompilerEngine::Lambda::operator()`
|
||||
// implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result);
|
||||
|
||||
// Specialization of `typedResult()` for scalar results, forwarding
|
||||
// scalar value to caller
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
if (clearResult.value().size() != 1) {
|
||||
return StreamStringError("typedResult expect only one value but got ")
|
||||
<< clearResult.value().size();
|
||||
}
|
||||
return clearResult.value()[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedVectorResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
return std::move(clearResult.value());
|
||||
}
|
||||
|
||||
// Specializations of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
//
|
||||
// Cannot factor out into a template template <typename T> inline
|
||||
// llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
|
||||
// to ambiguity with scalar template
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint8_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint16_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint16_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint32_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint32_t>(keySet, result);
|
||||
// }
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint64_t>(keySet, result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildTensorLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(keySet, result);
|
||||
|
||||
if (auto err = tensorOrError.takeError())
|
||||
return std::move(err);
|
||||
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
|
||||
result.buffers[0].sizes.end() - 1);
|
||||
|
||||
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, tensorDim);
|
||||
}
|
||||
|
||||
// Specialization of `typedResult()` for a single result wrapped into
|
||||
// a `LambdaArgument`.
|
||||
template <>
|
||||
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto gate = keySet.outputGate(0);
|
||||
// scalar case
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (clearResult.has_error()) {
|
||||
return StreamStringError("typedResult: ") << clearResult.error().mesg;
|
||||
}
|
||||
auto res = clearResult.value()[0];
|
||||
|
||||
return std::make_unique<IntLambdaArgument<uint64_t>>(res);
|
||||
}
|
||||
// tensor case
|
||||
// auto width = gate.shape.width;
|
||||
|
||||
// if (width > 32)
|
||||
return buildTensorLambdaResult<uint64_t>(keySet, result);
|
||||
// else if (width > 16)
|
||||
// return buildTensorLambdaResult<uint32_t>(keySet, result);
|
||||
// else if (width > 8)
|
||||
// return buildTensorLambdaResult<uint16_t>(keySet, result);
|
||||
// else if (width <= 8)
|
||||
// return buildTensorLambdaResult<uint8_t>(keySet, result);
|
||||
|
||||
// return StreamStringError("Cannot handle scalars with more than 64 bits");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Adaptor class that push arguments specified as instances of
|
||||
// `LambdaArgument` to `clientlib::EncryptedArguments`.
|
||||
class JITLambdaArgumentAdaptor {
|
||||
public:
|
||||
// Checks if the argument `arg` is an plaintext / encrypted integer
|
||||
// argument or a plaintext / encrypted tensor argument with a
|
||||
// backing integer type `IntT` and push the argument to `encryptedArgs`.
|
||||
//
|
||||
// Returns `true` if `arg` has one of the types above and its value
|
||||
// was successfully added to `encryptedArgs`, `false` if none of the types
|
||||
// matches or an error if a type matched, but adding the argument to
|
||||
// `encryptedArgs` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
auto res =
|
||||
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
tryAddArg<IntT>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
|
||||
// error if either the argument type is unsupported or if the argument types
|
||||
// is supported, but adding it to `encryptedArgs` failed.
|
||||
static inline llvm::Error
|
||||
addArgument(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
// Try the supported integer types; size_t needs explicit
|
||||
// treatment, since it may alias none of the fixed size integer
|
||||
// types
|
||||
llvm::Expected<bool> successOrError =
|
||||
JITLambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t,
|
||||
uint8_t, size_t>(encryptedArgs, arg,
|
||||
keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return StreamStringError("Unknown argument type");
|
||||
else
|
||||
return llvm::Error::success();
|
||||
}
|
||||
};
|
||||
|
||||
// A compiler engine that JIT-compiles a source and produces a lambda
|
||||
// object directly invocable through its call operator.
|
||||
class JitCompilerEngine : public CompilerEngine {
|
||||
@@ -246,32 +53,16 @@ public:
|
||||
// Invocation with an dynamic list of arguments of different
|
||||
// types, specified as `LambdaArgument`s
|
||||
template <typename ResT = uint64_t>
|
||||
llvm::Expected<ResT>
|
||||
operator()(llvm::ArrayRef<LambdaArgument *> lambdaArgs) {
|
||||
// Encrypt the arguments
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
|
||||
auto publicArguments = LambdaArgumentAdaptor::exportArguments(
|
||||
args, clientParameters, *this->keySet);
|
||||
|
||||
for (size_t i = 0; i < lambdaArgs.size(); i++) {
|
||||
if (llvm::Error err = JITLambdaArgumentAdaptor::addArgument(
|
||||
*encryptedArgs, *lambdaArgs[i], *this->keySet)) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
|
||||
auto check = encryptedArgs->checkAllArgs(*this->keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
|
||||
// Export as public arguments
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
if (auto err = publicArguments.takeError()) {
|
||||
return err;
|
||||
}
|
||||
|
||||
// Call the lambda
|
||||
auto publicResult = this->innerLambda->call(*publicArguments.value());
|
||||
auto publicResult = this->innerLambda->call(**publicArguments);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
@@ -283,22 +74,13 @@ public:
|
||||
template <typename T, typename ResT = uint64_t>
|
||||
llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
|
||||
// Encrypt the arguments
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
|
||||
for (size_t i = 0; i < args.size(); i++) {
|
||||
auto res = encryptedArgs->pushArg(args[i], *keySet);
|
||||
if (res.has_error()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
}
|
||||
}
|
||||
|
||||
auto check = encryptedArgs->checkAllArgs(*this->keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
|
||||
if (encryptedArgs.has_error()) {
|
||||
return StreamStringError(encryptedArgs.error().mesg);
|
||||
}
|
||||
|
||||
// Export as public arguments
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
auto publicArguments = encryptedArgs.value()->exportPublicArguments(
|
||||
clientParameters, keySet->runtimeContext());
|
||||
if (!publicArguments.has_value()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
|
||||
69
compiler/include/concretelang/Support/JitLambdaSupport.h
Normal file
69
compiler/include/concretelang/Support/JitLambdaSupport.h
Normal file
@@ -0,0 +1,69 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT
|
||||
#define CONCRETELANG_SUPPORT_JITLAMBDA_SUPPORT
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LambdaSupport.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
/// JitCompilationResult is the result of a Jit compilation, the server JIT
|
||||
/// lambda and the clientParameters.
|
||||
struct JitCompilationResult {
|
||||
std::unique_ptr<concretelang::JITLambda> lambda;
|
||||
clientlib::ClientParameters clientParameters;
|
||||
};
|
||||
|
||||
/// JitLambdaSupport is the instantiated LambdaSupport for the Jit Compilation.
|
||||
class JitLambdaSupport
|
||||
: public LambdaSupport<concretelang::JITLambda *, JitCompilationResult> {
|
||||
|
||||
public:
|
||||
JitLambdaSupport(
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath = llvm::None,
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline =
|
||||
mlir::makeOptimizingTransformer(3, 0, nullptr))
|
||||
: runtimeLibPath(runtimeLibPath), llvmOptPipeline(llvmOptPipeline) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override;
|
||||
using LambdaSupport::compile;
|
||||
|
||||
llvm::Expected<concretelang::JITLambda *>
|
||||
loadServerLambda(JitCompilationResult &result) override {
|
||||
return result.lambda.get();
|
||||
}
|
||||
|
||||
llvm::Expected<clientlib::ClientParameters>
|
||||
loadClientParameters(JitCompilationResult &result) override {
|
||||
return result.clientParameters;
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(concretelang::JITLambda *lambda,
|
||||
clientlib::PublicArguments &args) override {
|
||||
return lambda->call(args);
|
||||
}
|
||||
|
||||
private:
|
||||
llvm::Optional<llvm::StringRef> runtimeLibPath;
|
||||
llvm::function_ref<llvm::Error(llvm::Module *)> llvmOptPipeline;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
318
compiler/include/concretelang/Support/LambdaSupport.h
Normal file
318
compiler/include/concretelang/Support/LambdaSupport.h
Normal file
@@ -0,0 +1,318 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT
|
||||
#define CONCRETELANG_SUPPORT_LAMBDASUPPORT
|
||||
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/Support/LambdaArgument.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientLambda.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/Serializers.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/ServerLib/ServerLambda.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
|
||||
namespace {
|
||||
// Generic function template as well as specializations of
|
||||
// `typedResult` must be declared at namespace scope due to return
|
||||
// type template specialization
|
||||
|
||||
// Helper function for `JitCompilerEngine::Lambda::operator()`
|
||||
// implementing type-dependent preparation of the result.
|
||||
template <typename ResT>
|
||||
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result);
|
||||
|
||||
// Specialization of `typedResult()` for scalar results, forwarding
|
||||
// scalar value to caller
|
||||
template <>
|
||||
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
if (clearResult.value().size() != 1) {
|
||||
return StreamStringError("typedResult expect only one value but got ")
|
||||
<< clearResult.value().size();
|
||||
}
|
||||
return clearResult.value()[0];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline llvm::Expected<std::vector<T>>
|
||||
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (!clearResult.has_value()) {
|
||||
return StreamStringError("typedVectorResult cannot get clear text vector")
|
||||
<< clearResult.error().mesg;
|
||||
}
|
||||
return std::move(clearResult.value());
|
||||
}
|
||||
|
||||
// Specializations of `typedResult()` for vector results, initializing
|
||||
// an `std::vector` of the right size with the results and forwarding
|
||||
// it to the caller with move semantics.
|
||||
//
|
||||
// Cannot factor out into a template template <typename T> inline
|
||||
// llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
|
||||
// to ambiguity with scalar template
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint8_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint8_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint16_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint16_t>(keySet, result);
|
||||
// }
|
||||
// template <>
|
||||
// inline llvm::Expected<std::vector<uint32_t>>
|
||||
// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
// return typedVectorResult<uint32_t>(keySet, result);
|
||||
// }
|
||||
template <>
|
||||
inline llvm::Expected<std::vector<uint64_t>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
return typedVectorResult<uint64_t>(keySet, result);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
buildTensorLambdaResult(clientlib::KeySet &keySet,
|
||||
clientlib::PublicResult &result) {
|
||||
llvm::Expected<std::vector<T>> tensorOrError =
|
||||
typedResult<std::vector<T>>(keySet, result);
|
||||
|
||||
if (auto err = tensorOrError.takeError())
|
||||
return std::move(err);
|
||||
std::vector<int64_t> tensorDim(result.buffers[0].sizes.begin(),
|
||||
result.buffers[0].sizes.end() - 1);
|
||||
|
||||
return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
|
||||
*tensorOrError, tensorDim);
|
||||
}
|
||||
|
||||
// pecialization of `typedResult()` for a single result wrapped into
|
||||
// a `LambdaArgument`.
|
||||
template <>
|
||||
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
|
||||
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
|
||||
auto gate = keySet.outputGate(0);
|
||||
// scalar case
|
||||
if (gate.shape.dimensions.empty()) {
|
||||
auto clearResult = result.asClearTextVector(keySet, 0);
|
||||
if (clearResult.has_error()) {
|
||||
return StreamStringError("typedResult: ") << clearResult.error().mesg;
|
||||
}
|
||||
auto res = clearResult.value()[0];
|
||||
|
||||
return std::make_unique<IntLambdaArgument<uint64_t>>(res);
|
||||
}
|
||||
// tensor case
|
||||
// auto width = gate.shape.width;
|
||||
|
||||
// if (width > 32)
|
||||
return buildTensorLambdaResult<uint64_t>(keySet, result);
|
||||
// else if (width > 16)
|
||||
// return buildTensorLambdaResult<uint32_t>(keySet, result);
|
||||
// else if (width > 8)
|
||||
// return buildTensorLambdaResult<uint16_t>(keySet, result);
|
||||
// else if (width <= 8)
|
||||
// return buildTensorLambdaResult<uint8_t>(keySet, result);
|
||||
|
||||
// return StreamStringError("Cannot handle scalars with more than 64 bits");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Adaptor class that push arguments specified as instances of
|
||||
// `LambdaArgument` to `clientlib::EncryptedArguments`.
|
||||
class LambdaArgumentAdaptor {
|
||||
public:
|
||||
// Checks if the argument `arg` is an plaintext / encrypted integer
|
||||
// argument or a plaintext / encrypted tensor argument with a
|
||||
// backing integer type `IntT` and push the argument to `encryptedArgs`.
|
||||
//
|
||||
// Returns `true` if `arg` has one of the types above and its value
|
||||
// was successfully added to `encryptedArgs`, `false` if none of the types
|
||||
// matches or an error if a type matched, but adding the argument to
|
||||
// `encryptedArgs` failed.
|
||||
template <typename IntT>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
|
||||
auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
} else if (auto tla = arg.dyn_cast<
|
||||
TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
|
||||
auto res =
|
||||
encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
|
||||
if (!res.has_value()) {
|
||||
return StreamStringError(res.error().mesg);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Recursive case for `tryAddArg<IntT>(...)`
|
||||
template <typename IntT, typename NextIntT, typename... IntTs>
|
||||
static inline llvm::Expected<bool>
|
||||
tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
llvm::Expected<bool> successOrError =
|
||||
tryAddArg<IntT>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
|
||||
else
|
||||
return true;
|
||||
}
|
||||
|
||||
// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
|
||||
// error if either the argument type is unsupported or if the argument types
|
||||
// is supported, but adding it to `encryptedArgs` failed.
|
||||
static inline llvm::Error
|
||||
addArgument(clientlib::EncryptedArguments &encryptedArgs,
|
||||
const LambdaArgument &arg, clientlib::KeySet &keySet) {
|
||||
// Try the supported integer types; size_t needs explicit
|
||||
// treatment, since it may alias none of the fixed size integer
|
||||
// types
|
||||
llvm::Expected<bool> successOrError =
|
||||
LambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t, uint8_t,
|
||||
size_t>(encryptedArgs, arg, keySet);
|
||||
|
||||
if (!successOrError)
|
||||
return successOrError.takeError();
|
||||
|
||||
if (successOrError.get() == false)
|
||||
return StreamStringError("Unknown argument type");
|
||||
else
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
/// Encrypts and build public arguments from lambda arguments
|
||||
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
|
||||
exportArguments(llvm::ArrayRef<LambdaArgument *> args,
|
||||
clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet) {
|
||||
|
||||
auto encryptedArgs = clientlib::EncryptedArguments::empty();
|
||||
for (auto arg : args) {
|
||||
if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg,
|
||||
keySet)) {
|
||||
return std::move(err);
|
||||
}
|
||||
}
|
||||
auto check = encryptedArgs->checkAllArgs(keySet);
|
||||
if (check.has_error()) {
|
||||
return StreamStringError(check.error().mesg);
|
||||
}
|
||||
auto publicArguments = encryptedArgs->exportPublicArguments(
|
||||
clientParameters, keySet.runtimeContext());
|
||||
if (publicArguments.has_error()) {
|
||||
return StreamStringError(publicArguments.error().mesg);
|
||||
}
|
||||
return std::move(publicArguments.value());
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Lambda, typename CompilationResult> class LambdaSupport {
|
||||
|
||||
public:
|
||||
virtual ~LambdaSupport() {}
|
||||
|
||||
/// Compile the mlir program and produces a compilation result if succeed.
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
|
||||
llvm::SourceMgr &program, std::string funcname = "main");
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(llvm::StringRef program, std::string funcname = "main") {
|
||||
return compile(llvm::MemoryBuffer::getMemBuffer(program), funcname);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<CompilationResult>>
|
||||
compile(std::unique_ptr<llvm::MemoryBuffer> program,
|
||||
std::string funcname = "main") {
|
||||
llvm::SourceMgr sm;
|
||||
sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
|
||||
return compile(sm, funcname);
|
||||
}
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<Lambda> virtual loadServerLambda(CompilationResult &result);
|
||||
|
||||
/// Load the client parameters from the compilation result.
|
||||
llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
|
||||
CompilationResult &result);
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
|
||||
Lambda lambda, clientlib::PublicArguments &args);
|
||||
|
||||
/// Build the client KeySet from the client parameters.
|
||||
static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
|
||||
keySet(clientlib::ClientParameters clientParameters,
|
||||
llvm::Optional<clientlib::KeySetCache> cache) {
|
||||
std::shared_ptr<clientlib::KeySetCache> cachePtr;
|
||||
if (cache.hasValue()) {
|
||||
cachePtr = std::make_shared<clientlib::KeySetCache>(cache.getValue());
|
||||
}
|
||||
auto keySet =
|
||||
clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0);
|
||||
if (keySet.has_error()) {
|
||||
return StreamStringError(keySet.error().mesg);
|
||||
}
|
||||
return std::move(keySet.value());
|
||||
}
|
||||
|
||||
static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
|
||||
exportArguments(clientlib::ClientParameters clientParameters,
|
||||
clientlib::KeySet &keySet,
|
||||
llvm::ArrayRef<LambdaArgument *> args) {
|
||||
return LambdaArgumentAdaptor::exportArguments(args, clientParameters,
|
||||
keySet);
|
||||
}
|
||||
|
||||
template <typename ResT>
|
||||
static llvm::Expected<ResT>
|
||||
call(Lambda lambda, clientlib::PublicArguments &publicArguments) {
|
||||
// Call the lambda
|
||||
auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
|
||||
lambda, publicArguments);
|
||||
if (auto err = publicResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Decrypt the result
|
||||
return typedResult<ResT>(keySet, **publicResult);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
103
compiler/include/concretelang/Support/LibraryLambdaSupport.h
Normal file
103
compiler/include/concretelang/Support/LibraryLambdaSupport.h
Normal file
@@ -0,0 +1,103 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT
|
||||
#define CONCRETELANG_SUPPORT_LIBRARY_LAMBDA_SUPPORT
|
||||
|
||||
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
|
||||
#include <mlir/ExecutionEngine/ExecutionEngine.h>
|
||||
#include <mlir/ExecutionEngine/OptUtils.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include <concretelang/ServerLib/ServerLambda.h>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/LambdaSupport.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
namespace clientlib = ::concretelang::clientlib;
|
||||
namespace serverlib = ::concretelang::serverlib;
|
||||
|
||||
/// LibraryCompilationResult is the result of a compilation to a library.
|
||||
struct LibraryCompilationResult {
|
||||
/// The output path where the compilation artifact has been generated.
|
||||
std::string libraryPath;
|
||||
std::string funcName;
|
||||
};
|
||||
|
||||
class LibraryLambdaSupport
|
||||
: public LambdaSupport<serverlib::ServerLambda, LibraryCompilationResult> {
|
||||
|
||||
public:
|
||||
LibraryLambdaSupport(std::string outputPath = "/tmp/toto")
|
||||
: outputPath(outputPath) {}
|
||||
|
||||
llvm::Expected<std::unique_ptr<LibraryCompilationResult>>
|
||||
compile(llvm::SourceMgr &program, std::string funcname = "main") override {
|
||||
// Setup the compiler engine
|
||||
auto context = CompilationContext::createShared();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
engine.setClientParametersFuncName(funcname);
|
||||
|
||||
// Compile to a library
|
||||
auto library = engine.compile(program, outputPath);
|
||||
if (auto err = library.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
auto result = std::make_unique<LibraryCompilationResult>();
|
||||
result->libraryPath = outputPath;
|
||||
result->funcName = funcname;
|
||||
return std::move(result);
|
||||
}
|
||||
using LambdaSupport::compile;
|
||||
|
||||
/// Load the server lambda from the compilation result.
|
||||
llvm::Expected<serverlib::ServerLambda>
|
||||
loadServerLambda(LibraryCompilationResult &result) override {
|
||||
auto lambda =
|
||||
serverlib::ServerLambda::load(result.funcName, result.libraryPath);
|
||||
if (lambda.has_error()) {
|
||||
return StreamStringError(lambda.error().mesg);
|
||||
}
|
||||
return lambda.value();
|
||||
}
|
||||
|
||||
/// Load the client parameters from the compilation result.
|
||||
llvm::Expected<clientlib::ClientParameters>
|
||||
loadClientParameters(LibraryCompilationResult &result) override {
|
||||
auto path = ClientParameters::getClientParametersPath(result.libraryPath);
|
||||
auto params = ClientParameters::load(path);
|
||||
if (params.has_error()) {
|
||||
return StreamStringError(params.error().mesg);
|
||||
}
|
||||
auto param = llvm::find_if(params.value(), [&](ClientParameters param) {
|
||||
return param.functionName == result.funcName;
|
||||
});
|
||||
if (param == params.value().end()) {
|
||||
return StreamStringError("ClientLambda: cannot find function(")
|
||||
<< result.funcName << ") in client parameters path(" << path
|
||||
<< ")";
|
||||
}
|
||||
return *param;
|
||||
}
|
||||
|
||||
/// Call the lambda with the public arguments.
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
serverCall(serverlib::ServerLambda lambda,
|
||||
clientlib::PublicArguments &args) override {
|
||||
return lambda.call(args);
|
||||
}
|
||||
|
||||
private:
|
||||
std::string outputPath;
|
||||
};
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -3,6 +3,7 @@ add_mlir_library(ConcretelangSupport
|
||||
Jit.cpp
|
||||
CompilerEngine.cpp
|
||||
JitCompilerEngine.cpp
|
||||
JitLambdaSupport.cpp
|
||||
LambdaArgument.cpp
|
||||
V0Parameters.cpp
|
||||
V0Curves.cpp
|
||||
@@ -32,4 +33,5 @@ add_mlir_library(ConcretelangSupport
|
||||
|
||||
ConcretelangRuntime
|
||||
ConcretelangClientLib
|
||||
ConcretelangServerLib
|
||||
)
|
||||
|
||||
@@ -383,19 +383,17 @@ llvm::Expected<CompilerEngine::CompilationResult>
|
||||
CompilerEngine::compile(std::unique_ptr<llvm::MemoryBuffer> buffer,
|
||||
Target target, OptionalLib lib) {
|
||||
llvm::SourceMgr sm;
|
||||
|
||||
sm.AddNewSourceBuffer(std::move(buffer), llvm::SMLoc());
|
||||
|
||||
return this->compile(sm, target, lib);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
llvm::Expected<CompilerEngine::Library>
|
||||
CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
|
||||
CompilerEngine::compile(std::vector<std::string> inputs,
|
||||
std::string libraryPath) {
|
||||
using Library = mlir::concretelang::CompilerEngine::Library;
|
||||
auto outputLib = std::make_shared<Library>(libraryPath);
|
||||
auto target = CompilerEngine::Target::LIBRARY;
|
||||
|
||||
for (auto input : inputs) {
|
||||
auto compilation = compile(input, target, outputLib);
|
||||
if (!compilation) {
|
||||
@@ -403,6 +401,24 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
|
||||
<< llvm::toString(compilation.takeError());
|
||||
}
|
||||
}
|
||||
if (auto err = outputLib->emitArtifacts()) {
|
||||
return StreamStringError("Can't emit artifacts: ")
|
||||
<< llvm::toString(std::move(err));
|
||||
}
|
||||
return *outputLib.get();
|
||||
}
|
||||
|
||||
llvm::Expected<CompilerEngine::Library>
|
||||
CompilerEngine::compile(llvm::SourceMgr &sm, std::string libraryPath) {
|
||||
using Library = mlir::concretelang::CompilerEngine::Library;
|
||||
auto outputLib = std::make_shared<Library>(libraryPath);
|
||||
auto target = CompilerEngine::Target::LIBRARY;
|
||||
|
||||
auto compilation = compile(sm, target, outputLib);
|
||||
if (!compilation) {
|
||||
return StreamStringError("Can't compile: ")
|
||||
<< llvm::toString(compilation.takeError());
|
||||
}
|
||||
|
||||
if (auto err = outputLib->emitArtifacts()) {
|
||||
return StreamStringError("Can't emit artifacts: ")
|
||||
@@ -411,11 +427,6 @@ CompilerEngine::compile(std::vector<T> inputs, std::string libraryPath) {
|
||||
return *outputLib.get();
|
||||
}
|
||||
|
||||
// explicit instantiation for a vector of string (for linking with lib/CAPI)
|
||||
template llvm::Expected<CompilerEngine::Library>
|
||||
CompilerEngine::compile(std::vector<std::string> inputs,
|
||||
std::string libraryPath);
|
||||
|
||||
/** Returns the path of the shared library */
|
||||
std::string CompilerEngine::Library::getSharedLibraryPath(std::string path) {
|
||||
return path + DOT_SHARED_LIB_EXT;
|
||||
|
||||
49
compiler/lib/Support/JitLambdaSupport.cpp
Normal file
49
compiler/lib/Support/JitLambdaSupport.cpp
Normal file
@@ -0,0 +1,49 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Support/JitLambdaSupport.h>
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
llvm::Expected<std::unique_ptr<JitCompilationResult>>
|
||||
JitLambdaSupport::compile(llvm::SourceMgr &program, std::string funcname) {
|
||||
// Setup the compiler engine
|
||||
auto context = std::make_shared<CompilationContext>();
|
||||
concretelang::CompilerEngine engine(context);
|
||||
|
||||
// We need client parameters to be generated
|
||||
engine.setGenerateClientParameters(true);
|
||||
engine.setClientParametersFuncName(funcname);
|
||||
|
||||
// Compile to LLVM Dialect
|
||||
auto compilationResult =
|
||||
engine.compile(program, CompilerEngine::Target::LLVM_IR);
|
||||
|
||||
if (auto err = compilationResult.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
// Compile from LLVM Dialect to JITLambda
|
||||
auto mlirModule = compilationResult.get().mlirModuleRef->get();
|
||||
auto lambda = concretelang::JITLambda::create(
|
||||
funcname, mlirModule, llvmOptPipeline, runtimeLibPath);
|
||||
if (auto err = lambda.takeError()) {
|
||||
return std::move(err);
|
||||
}
|
||||
|
||||
if (!compilationResult.get().clientParameters.hasValue()) {
|
||||
// i.e. that should not occurs
|
||||
return StreamStringError("No client parameters has been generated");
|
||||
}
|
||||
auto result = std::make_unique<JitCompilationResult>();
|
||||
result->lambda = std::move(*lambda);
|
||||
result->clientParameters =
|
||||
compilationResult.get().clientParameters.getValue();
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
@@ -61,8 +61,8 @@ valueDescriptionToLambdaArgument(ValueDescription desc) {
|
||||
}
|
||||
|
||||
llvm::Error checkResult(ScalarDesc &desc,
|
||||
mlir::concretelang::LambdaArgument *res) {
|
||||
auto res64 = res->dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
mlir::concretelang::LambdaArgument &res) {
|
||||
auto res64 = res.dyn_cast<mlir::concretelang::IntLambdaArgument<uint64_t>>();
|
||||
if (res64 == nullptr) {
|
||||
return StreamStringError("invocation result is not a scalar");
|
||||
}
|
||||
@@ -85,7 +85,7 @@ checkTensorResult(TensorDescription &desc,
|
||||
<< resShape.size() << " expected " << desc.shape.size();
|
||||
}
|
||||
for (size_t i = 0; i < desc.shape.size(); i++) {
|
||||
if ((uint64_t)resShape[i] != desc.shape[i]) {
|
||||
if (resShape[i] != desc.shape[i]) {
|
||||
return StreamStringError("shape differs at pos ")
|
||||
<< i << ", got " << resShape[i] << " expected " << desc.shape[i];
|
||||
}
|
||||
@@ -112,37 +112,36 @@ checkTensorResult(TensorDescription &desc,
|
||||
}
|
||||
|
||||
llvm::Error checkResult(TensorDescription &desc,
|
||||
mlir::concretelang::LambdaArgument *res) {
|
||||
mlir::concretelang::LambdaArgument &res) {
|
||||
switch (desc.width) {
|
||||
case 8:
|
||||
return checkTensorResult<uint8_t>(
|
||||
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint8_t>>>());
|
||||
case 16:
|
||||
return checkTensorResult<uint16_t>(
|
||||
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint16_t>>>());
|
||||
case 32:
|
||||
return checkTensorResult<uint32_t>(
|
||||
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint32_t>>>());
|
||||
case 64:
|
||||
return checkTensorResult<uint64_t>(
|
||||
desc, res->dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
desc, res.dyn_cast<mlir::concretelang::TensorLambdaArgument<
|
||||
mlir::concretelang::IntLambdaArgument<uint64_t>>>());
|
||||
default:
|
||||
return StreamStringError("Unsupported width");
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Error
|
||||
checkResult(ValueDescription &desc,
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument> &res) {
|
||||
llvm::Error checkResult(ValueDescription &desc,
|
||||
mlir::concretelang::LambdaArgument &res) {
|
||||
switch (desc.tag) {
|
||||
case ValueDescription::SCALAR:
|
||||
return checkResult(desc.scalar, res.get());
|
||||
return checkResult(desc.scalar, res);
|
||||
case ValueDescription::TENSOR:
|
||||
return checkResult(desc.tensor, res.get());
|
||||
return checkResult(desc.tensor, res);
|
||||
}
|
||||
assert(false);
|
||||
}
|
||||
@@ -190,7 +189,7 @@ template <> struct llvm::yaml::MappingTraits<EndToEndDesc> {
|
||||
}
|
||||
};
|
||||
|
||||
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc);
|
||||
LLVM_YAML_IS_DOCUMENT_LIST_VECTOR(EndToEndDesc)
|
||||
|
||||
std::vector<EndToEndDesc> loadEndToEndDesc(std::string path) {
|
||||
std::ifstream file(path);
|
||||
|
||||
@@ -4,53 +4,88 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "EndToEndFixture.h"
|
||||
#include "concretelang/Support/JitLambdaSupport.h"
|
||||
#include "concretelang/Support/LibraryLambdaSupport.h"
|
||||
|
||||
class EndToEndJitTest : public testing::TestWithParam<EndToEndDesc> {};
|
||||
|
||||
TEST_P(EndToEndJitTest, compile_and_run) {
|
||||
EndToEndDesc desc = GetParam();
|
||||
|
||||
// Compile program
|
||||
// mlir::concretelang::JitCompilerEngine::Lambda lambda =
|
||||
checkedJit(lambda, desc.program);
|
||||
|
||||
// Prepare arguments
|
||||
for (auto test : desc.tests) {
|
||||
std::vector<mlir::concretelang::LambdaArgument *> inputArguments;
|
||||
inputArguments.reserve(test.inputs.size());
|
||||
for (auto input : test.inputs) {
|
||||
auto arg = valueDescriptionToLambdaArgument(input);
|
||||
ASSERT_EXPECTED_SUCCESS(arg);
|
||||
inputArguments.push_back(arg.get());
|
||||
}
|
||||
|
||||
// Call the lambda
|
||||
auto res =
|
||||
lambda.operator()<std::unique_ptr<mlir::concretelang::LambdaArgument>>(
|
||||
llvm::ArrayRef<mlir::concretelang::LambdaArgument *>(
|
||||
inputArguments));
|
||||
ASSERT_EXPECTED_SUCCESS(res);
|
||||
if (test.outputs.size() != 1) {
|
||||
FAIL() << "Only one result function are supported.";
|
||||
}
|
||||
ASSERT_LLVM_ERROR(checkResult(test.outputs[0], res.get()));
|
||||
|
||||
// Free arguments
|
||||
for (auto arg : inputArguments) {
|
||||
delete arg;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(prefix, path) \
|
||||
namespace prefix { \
|
||||
auto valuesVector = loadEndToEndDesc(path); \
|
||||
auto values = testing::ValuesIn<std::vector<EndToEndDesc>>(valuesVector); \
|
||||
INSTANTIATE_TEST_SUITE_P(prefix, EndToEndJitTest, values, \
|
||||
printEndToEndDesc); \
|
||||
// Macro to define and end to end TestSuite that run test thanks the
|
||||
// LambdaSupport according a EndToEndDesc
|
||||
#define INSTANTIATE_END_TO_END_COMPILE_AND_RUN(TestSuite, LambdaSupport) \
|
||||
TEST_P(TestSuite, compile_and_run) { \
|
||||
\
|
||||
auto desc = GetParam(); \
|
||||
\
|
||||
LambdaSupport support; \
|
||||
\
|
||||
/* 1 - Compile the program */ \
|
||||
auto compilationResult = support.compile(desc.program); \
|
||||
ASSERT_EXPECTED_SUCCESS(compilationResult); \
|
||||
\
|
||||
/* 2 - Load the client parameters and build the keySet */ \
|
||||
auto clientParameters = support.loadClientParameters(**compilationResult); \
|
||||
ASSERT_EXPECTED_SUCCESS(clientParameters); \
|
||||
\
|
||||
auto keySet = support.keySet(*clientParameters, getTestKeySetCache()); \
|
||||
ASSERT_EXPECTED_SUCCESS(keySet); \
|
||||
\
|
||||
/* 3 - Load the server lambda */ \
|
||||
auto serverLambda = support.loadServerLambda(**compilationResult); \
|
||||
ASSERT_EXPECTED_SUCCESS(serverLambda); \
|
||||
\
|
||||
/* For each test entries */ \
|
||||
for (auto test : desc.tests) { \
|
||||
std::vector<mlir::concretelang::LambdaArgument *> inputArguments; \
|
||||
inputArguments.reserve(test.inputs.size()); \
|
||||
for (auto input : test.inputs) { \
|
||||
auto arg = valueDescToLambdaArgument(input); \
|
||||
ASSERT_EXPECTED_SUCCESS(arg); \
|
||||
inputArguments.push_back(arg.get()); \
|
||||
} \
|
||||
/* 4 - Create the public arguments */ \
|
||||
auto publicArguments = support.exportArguments( \
|
||||
*clientParameters, **keySet, inputArguments); \
|
||||
ASSERT_EXPECTED_SUCCESS(publicArguments); \
|
||||
\
|
||||
/* 5 - Call the server lambda */ \
|
||||
auto publicResult = \
|
||||
support.serverCall(*serverLambda, **publicArguments); \
|
||||
ASSERT_EXPECTED_SUCCESS(publicResult); \
|
||||
\
|
||||
/* 6 - Decrypt the public result */ \
|
||||
auto result = mlir::concretelang::typedResult< \
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>( \
|
||||
**keySet, **publicResult); \
|
||||
\
|
||||
ASSERT_EXPECTED_SUCCESS(result); \
|
||||
\
|
||||
for (auto arg : inputArguments) { \
|
||||
delete arg; \
|
||||
} \
|
||||
} \
|
||||
}
|
||||
|
||||
INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(
|
||||
FHE, "tests/unittest/end_to_end_fhe.yaml")
|
||||
INSTANTIATE_END_TO_END_JIT_TEST_SUITE_FROM_FILE(
|
||||
EncryptedTensor, "tests/unittest/end_to_end_encrypted_tensor.yaml")
|
||||
#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE(prefix, suite, \
|
||||
lambdasupport, path) \
|
||||
namespace prefix##suite { \
|
||||
auto valuesVector = loadEndToEndDesc(path); \
|
||||
auto values = testing::ValuesIn<std::vector<EndToEndDesc>>(valuesVector); \
|
||||
INSTANTIATE_TEST_SUITE_P(prefix, suite, values, printEndToEndDesc); \
|
||||
}
|
||||
|
||||
#define INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(suite, \
|
||||
lambdasupport) \
|
||||
\
|
||||
class suite : public testing::TestWithParam<EndToEndDesc> {}; \
|
||||
INSTANTIATE_END_TO_END_COMPILE_AND_RUN(suite, lambdasupport) \
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE( \
|
||||
FHE, suite, lambdasupport, "tests/unittest/end_to_end_fhe.yaml") \
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_FILE( \
|
||||
EncryptedTensor, suite, lambdasupport, \
|
||||
"tests/unittest/end_to_end_encrypted_tensor.yaml")
|
||||
|
||||
/// Instantiate the test suite for Jit
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
|
||||
JitTest, mlir::concretelang::JitLambdaSupport)
|
||||
|
||||
/// Instantiate the test suite for Jit
|
||||
INSTANTIATE_END_TO_END_TEST_SUITE_FROM_ALL_TEST_FILES(
|
||||
LibraryTest, mlir::concretelang::LibraryLambdaSupport)
|
||||
Reference in New Issue
Block a user