This patch adds support for scalar results to the client/server protocol and tests. In addition to `TensorData`, a new type `ScalarData` is added. Previous representations of scalar values as one-dimensional `TensorData` instances have been replaced with proper instantiations of `ScalarData`. The generic use of `TensorData` for scalar and tensor values has been replaced with uses of a new variant, `ScalarOrTensorData`, which can hold either a `TensorData` or a `ScalarData` instance.
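For illustration only, here is a minimal sketch of what such a `ScalarOrTensorData` variant could look like, assuming a simple `std::variant`-based layout; the member layouts of `ScalarData` and `TensorData` below are hypothetical and may differ from the actual client-library definitions.

// Hypothetical sketch only; the real ScalarData/TensorData definitions live in
// the client library and may also carry width, signedness, and encoding info.
#include <cstddef>
#include <cstdint>
#include <variant>
#include <vector>

struct ScalarData {
  uint64_t value; // assumed: a single decrypted word
};

struct TensorData {
  std::vector<uint64_t> values;   // assumed: flat element storage
  std::vector<size_t> dimensions; // assumed: tensor shape
};

// A result is either a true scalar or a tensor; callers dispatch with
// std::visit instead of special-casing rank-1 tensors.
using ScalarOrTensorData = std::variant<ScalarData, TensorData>;

Compared with encoding scalars as one-dimensional tensors, a variant like this lets the protocol round-trip scalar results without an artificial shape.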
517 lines
19 KiB
C++
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

#ifndef CONCRETELANG_SUPPORT_LAMBDASUPPORT
#define CONCRETELANG_SUPPORT_LAMBDASUPPORT

#include "boost/outcome.h"

#include "concretelang/Support/LambdaArgument.h"

#include "concretelang/ClientLib/ClientLambda.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Common/Error.h"
#include "concretelang/ServerLib/ServerLambda.h"

namespace mlir {
namespace concretelang {

namespace clientlib = ::concretelang::clientlib;

namespace {
// Generic function template as well as specializations of
// `typedResult` must be declared at namespace scope due to return
// type template specialization

/// Helper function for implementing type-dependent preparation of the result.
template <typename ResT>
llvm::Expected<ResT> typedResult(clientlib::KeySet &keySet,
                                 clientlib::PublicResult &result);

template <typename T>
inline llvm::Expected<T> typedScalarResult(clientlib::KeySet &keySet,
                                           clientlib::PublicResult &result) {
  auto clearResult = result.asClearTextScalar<T>(keySet, 0);
  if (!clearResult.has_value()) {
    return StreamStringError("typedResult cannot get clear text scalar")
           << clearResult.error().mesg;
  }

  return clearResult.value();
}

/// Specializations of `typedResult()` for scalar results, forwarding
/// scalar value to caller.

template <>
inline llvm::Expected<uint64_t> typedResult(clientlib::KeySet &keySet,
                                            clientlib::PublicResult &result) {
  return std::move(typedScalarResult<uint64_t>(keySet, result));
}
template <>
inline llvm::Expected<int64_t> typedResult(clientlib::KeySet &keySet,
                                           clientlib::PublicResult &result) {
  return std::move(typedScalarResult<int64_t>(keySet, result));
}
template <>
inline llvm::Expected<uint32_t> typedResult(clientlib::KeySet &keySet,
                                            clientlib::PublicResult &result) {
  return std::move(typedScalarResult<uint32_t>(keySet, result));
}
template <>
inline llvm::Expected<int32_t> typedResult(clientlib::KeySet &keySet,
                                           clientlib::PublicResult &result) {
  return std::move(typedScalarResult<int32_t>(keySet, result));
}
template <>
inline llvm::Expected<uint16_t> typedResult(clientlib::KeySet &keySet,
                                            clientlib::PublicResult &result) {
  return std::move(typedScalarResult<uint16_t>(keySet, result));
}
template <>
inline llvm::Expected<int16_t> typedResult(clientlib::KeySet &keySet,
                                           clientlib::PublicResult &result) {
  return std::move(typedScalarResult<int16_t>(keySet, result));
}
template <>
inline llvm::Expected<uint8_t> typedResult(clientlib::KeySet &keySet,
                                           clientlib::PublicResult &result) {
  return std::move(typedScalarResult<uint8_t>(keySet, result));
}
template <>
inline llvm::Expected<int8_t> typedResult(clientlib::KeySet &keySet,
                                          clientlib::PublicResult &result) {
  return std::move(typedScalarResult<int8_t>(keySet, result));
}

template <typename T>
inline llvm::Expected<std::vector<T>>
typedVectorResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  auto clearResult = result.asClearTextVector<T>(keySet, 0);
  if (!clearResult.has_value()) {
    return StreamStringError("typedVectorResult cannot get clear text vector")
           << clearResult.error().mesg;
  }
  return std::move(clearResult.value());
}

/// Specializations of `typedResult()` for vector results, initializing
/// an `std::vector` of the right size with the results and forwarding
/// it to the caller with move semantics.
/// Cannot factor out into a template template <typename T> inline
/// llvm::Expected<std::vector<uint8_t>>
/// typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result); due
/// to ambiguity with scalar template
template <>
inline llvm::Expected<std::vector<uint8_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<uint8_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int8_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<int8_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint16_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<uint16_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int16_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<int16_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint32_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<uint32_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int32_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<int32_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<uint64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<uint64_t>(keySet, result);
}
template <>
inline llvm::Expected<std::vector<int64_t>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  return typedVectorResult<int64_t>(keySet, result);
}

template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildTensorLambdaResult(clientlib::KeySet &keySet,
                        clientlib::PublicResult &result) {
  llvm::Expected<std::vector<T>> tensorOrError =
      typedResult<std::vector<T>>(keySet, result);
  if (auto err = tensorOrError.takeError())
    return std::move(err);

  auto tensorDim = result.asClearTextShape(0);
  if (tensorDim.has_error())
    return StreamStringError(tensorDim.error().mesg);

  return std::make_unique<TensorLambdaArgument<IntLambdaArgument<T>>>(
      *tensorOrError, tensorDim.value());
}

template <typename T>
llvm::Expected<std::unique_ptr<LambdaArgument>>
buildScalarLambdaResult(clientlib::KeySet &keySet,
                        clientlib::PublicResult &result) {
  llvm::Expected<T> scalarOrError = typedResult<T>(keySet, result);
  if (auto err = scalarOrError.takeError())
    return std::move(err);

  return std::make_unique<IntLambdaArgument<T>>(*scalarOrError);
}

/// Specialization of `typedResult()` for a single result wrapped into
/// a `LambdaArgument`.
template <>
inline llvm::Expected<std::unique_ptr<LambdaArgument>>
typedResult(clientlib::KeySet &keySet, clientlib::PublicResult &result) {
  auto gate = keySet.outputGate(0);
  auto width = gate.shape.width;
  bool sign = gate.shape.sign;

  if (width > 64)
    return StreamStringError("Cannot handle values with more than 64 bits");

  // By convention, decrypted integers are always 64 bits wide
  if (gate.isEncrypted())
    width = 64;

  if (gate.shape.dimensions.empty()) {
    // scalar case
    if (width > 32) {
      return (sign) ? buildScalarLambdaResult<int64_t>(keySet, result)
                    : buildScalarLambdaResult<uint64_t>(keySet, result);
    } else if (width > 16) {
      return (sign) ? buildScalarLambdaResult<int32_t>(keySet, result)
                    : buildScalarLambdaResult<uint32_t>(keySet, result);
    } else if (width > 8) {
      return (sign) ? buildScalarLambdaResult<int16_t>(keySet, result)
                    : buildScalarLambdaResult<uint16_t>(keySet, result);
    } else if (width <= 8) {
      return (sign) ? buildScalarLambdaResult<int8_t>(keySet, result)
                    : buildScalarLambdaResult<uint8_t>(keySet, result);
    }
  } else {
    // tensor case
    if (width > 32) {
      return (sign) ? buildTensorLambdaResult<int64_t>(keySet, result)
                    : buildTensorLambdaResult<uint64_t>(keySet, result);
    } else if (width > 16) {
      return (sign) ? buildTensorLambdaResult<int32_t>(keySet, result)
                    : buildTensorLambdaResult<uint32_t>(keySet, result);
    } else if (width > 8) {
      return (sign) ? buildTensorLambdaResult<int16_t>(keySet, result)
                    : buildTensorLambdaResult<uint16_t>(keySet, result);
    } else if (width <= 8) {
      return (sign) ? buildTensorLambdaResult<int8_t>(keySet, result)
                    : buildTensorLambdaResult<uint8_t>(keySet, result);
    }
  }

  assert(false && "Cannot happen");
}
} // namespace

/// Adaptor class that pushes arguments specified as instances of
/// `LambdaArgument` to `clientlib::EncryptedArguments`.
class LambdaArgumentAdaptor {
public:
  /// Checks if the argument `arg` is a plaintext / encrypted integer
  /// argument or a plaintext / encrypted tensor argument with a
  /// backing integer type `IntT` and pushes the argument to `encryptedArgs`.
  ///
  /// Returns `true` if `arg` has one of the types above and its value
  /// was successfully added to `encryptedArgs`, `false` if none of the types
  /// matches, or an error if a type matched, but adding the argument to
  /// `encryptedArgs` failed.
  template <typename IntT>
  static inline llvm::Expected<bool>
  tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
            const LambdaArgument &arg, clientlib::KeySet &keySet) {
    if (auto ila = arg.dyn_cast<IntLambdaArgument<IntT>>()) {
      auto res = encryptedArgs.pushArg(ila->getValue(), keySet);
      if (!res.has_value()) {
        return StreamStringError(res.error().mesg);
      } else {
        return true;
      }
    } else if (auto tla = arg.dyn_cast<
                   TensorLambdaArgument<IntLambdaArgument<IntT>>>()) {
      auto res =
          encryptedArgs.pushArg(tla->getValue(), tla->getDimensions(), keySet);
      if (!res.has_value()) {
        return StreamStringError(res.error().mesg);
      } else {
        return true;
      }
    }
    return false;
  }

  /// Recursive case for `tryAddArg<IntT>(...)`
  template <typename IntT, typename NextIntT, typename... IntTs>
  static inline llvm::Expected<bool>
  tryAddArg(clientlib::EncryptedArguments &encryptedArgs,
            const LambdaArgument &arg, clientlib::KeySet &keySet) {
    llvm::Expected<bool> successOrError =
        tryAddArg<IntT>(encryptedArgs, arg, keySet);

    if (!successOrError)
      return successOrError.takeError();

    if (successOrError.get() == false)
      return tryAddArg<NextIntT, IntTs...>(encryptedArgs, arg, keySet);
    else
      return true;
  }

  /// Attempts to push a single argument `arg` to `encryptedArgs`. Returns an
  /// error if either the argument type is unsupported or if the argument type
  /// is supported, but adding it to `encryptedArgs` failed.
  static inline llvm::Error
  addArgument(clientlib::EncryptedArguments &encryptedArgs,
              const LambdaArgument &arg, clientlib::KeySet &keySet) {
    // Try the supported integer types; size_t needs explicit
    // treatment, since it may alias none of the fixed size integer
    // types
    llvm::Expected<bool> successOrError =
        LambdaArgumentAdaptor::tryAddArg<uint64_t, uint32_t, uint16_t, uint8_t,
                                         size_t>(encryptedArgs, arg, keySet);

    if (!successOrError)
      return successOrError.takeError();

    if (successOrError.get() == false)
      return StreamStringError("Unknown argument type");
    else
      return llvm::Error::success();
  }

  /// Encrypts and builds public arguments from lambda arguments.
  static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
  exportArguments(llvm::ArrayRef<const LambdaArgument *> args,
                  clientlib::ClientParameters clientParameters,
                  clientlib::KeySet &keySet) {

    auto encryptedArgs = clientlib::EncryptedArguments::empty();
    for (auto arg : args) {
      if (auto err = LambdaArgumentAdaptor::addArgument(*encryptedArgs, *arg,
                                                        keySet)) {
        return std::move(err);
      }
    }
    auto check = encryptedArgs->checkAllArgs(keySet);
    if (check.has_error()) {
      return StreamStringError(check.error().mesg);
    }
    auto publicArguments = encryptedArgs->exportPublicArguments(
        clientParameters, keySet.runtimeContext());
    if (publicArguments.has_error()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    return std::move(publicArguments.value());
  }
};

template <typename Lambda, typename CompilationResult> class LambdaSupport {
public:
  typedef Lambda lambda;
  typedef CompilationResult compilationResult;

  virtual ~LambdaSupport() {}

  /// Compile the MLIR program and produce a compilation result on success.
  llvm::Expected<std::unique_ptr<CompilationResult>> virtual compile(
      llvm::SourceMgr &program,
      CompilationOptions options = CompilationOptions("main")) = 0;

  llvm::Expected<std::unique_ptr<CompilationResult>>
  compile(llvm::StringRef program,
          CompilationOptions options = CompilationOptions("main")) {
    return compile(llvm::MemoryBuffer::getMemBuffer(program), options);
  }

  llvm::Expected<std::unique_ptr<CompilationResult>>
  compile(std::unique_ptr<llvm::MemoryBuffer> program,
          CompilationOptions options = CompilationOptions("main")) {
    llvm::SourceMgr sm;
    sm.AddNewSourceBuffer(std::move(program), llvm::SMLoc());
    return compile(sm, options);
  }

  /// Load the server lambda from the compilation result.
  llvm::Expected<Lambda> virtual loadServerLambda(
      CompilationResult &result) = 0;

  /// Load the client parameters from the compilation result.
  llvm::Expected<clientlib::ClientParameters> virtual loadClientParameters(
      CompilationResult &result) = 0;

  /// Load the compilation feedback from the compilation result.
  llvm::Expected<CompilationFeedback> virtual loadCompilationFeedback(
      CompilationResult &result) = 0;

  /// Call the lambda with the public arguments.
  llvm::Expected<std::unique_ptr<clientlib::PublicResult>> virtual serverCall(
      Lambda lambda, clientlib::PublicArguments &args,
      clientlib::EvaluationKeys &evaluationKeys) = 0;

  /// Build the client KeySet from the client parameters.
  static llvm::Expected<std::unique_ptr<clientlib::KeySet>>
  keySet(clientlib::ClientParameters clientParameters,
         llvm::Optional<clientlib::KeySetCache> cache) {
    std::shared_ptr<clientlib::KeySetCache> cachePtr;
    if (cache.hasValue()) {
      cachePtr = std::make_shared<clientlib::KeySetCache>(cache.getValue());
    }
    auto keySet =
        clientlib::KeySetCache::generate(cachePtr, clientParameters, 0, 0);
    if (keySet.has_error()) {
      return StreamStringError(keySet.error().mesg);
    }
    return std::move(keySet.value());
  }

  static llvm::Expected<std::unique_ptr<clientlib::PublicArguments>>
  exportArguments(clientlib::ClientParameters clientParameters,
                  clientlib::KeySet &keySet,
                  llvm::ArrayRef<const LambdaArgument *> args) {
    return LambdaArgumentAdaptor::exportArguments(args, clientParameters,
                                                  keySet);
  }

  template <typename ResT>
  static llvm::Expected<ResT> call(Lambda lambda,
                                   clientlib::PublicArguments &publicArguments,
                                   clientlib::EvaluationKeys &evaluationKeys) {
    // Call the lambda
    auto publicResult = LambdaSupport<Lambda, CompilationResult>().serverCall(
        lambda, publicArguments, evaluationKeys);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }

    // Decrypt the result
    return typedResult<ResT>(keySet, **publicResult);
  }
};

template <class LambdaSupport> class ClientServer {
public:
  static llvm::Expected<ClientServer>
  create(llvm::StringRef program,
         CompilationOptions options = CompilationOptions("main"),
         llvm::Optional<clientlib::KeySetCache> cache = {},
         LambdaSupport support = LambdaSupport()) {
    auto compilationResult = support.compile(program, options);
    if (auto err = compilationResult.takeError()) {
      return std::move(err);
    }
    auto lambda = support.loadServerLambda(**compilationResult);
    if (auto err = lambda.takeError()) {
      return std::move(err);
    }
    auto clientParameters = support.loadClientParameters(**compilationResult);
    if (auto err = clientParameters.takeError()) {
      return std::move(err);
    }
    auto keySet = support.keySet(*clientParameters, cache);
    if (auto err = keySet.takeError()) {
      return std::move(err);
    }
    auto f = ClientServer();
    f.lambda = *lambda;
    f.compilationResult = std::move(*compilationResult);
    f.keySet = std::move(*keySet);
    f.clientParameters = *clientParameters;
    f.support = support;
    return std::move(f);
  }

  template <typename ResT = uint64_t>
  llvm::Expected<ResT> operator()(llvm::ArrayRef<LambdaArgument *> args) {
    auto publicArguments = LambdaArgumentAdaptor::exportArguments(
        args, clientParameters, *this->keySet);

    if (auto err = publicArguments.takeError()) {
      return std::move(err);
    }

    auto evaluationKeys = this->keySet->evaluationKeys();
    auto publicResult =
        support.serverCall(lambda, **publicArguments, evaluationKeys);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

  template <typename T, typename ResT = uint64_t>
  llvm::Expected<ResT> operator()(const llvm::ArrayRef<T> args) {
    auto encryptedArgs = clientlib::EncryptedArguments::create(*keySet, args);
    if (encryptedArgs.has_error()) {
      return StreamStringError(encryptedArgs.error().mesg);
    }
    auto publicArguments = encryptedArgs.value()->exportPublicArguments(
        clientParameters, keySet->runtimeContext());
    if (!publicArguments.has_value()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    auto evaluationKeys = keySet->evaluationKeys();
    auto publicResult =
        support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

  template <typename ResT = uint64_t, typename... Args>
  llvm::Expected<ResT> operator()(const Args... args) {
    auto encryptedArgs =
        clientlib::EncryptedArguments::create(*keySet, args...);
    if (encryptedArgs.has_error()) {
      return StreamStringError(encryptedArgs.error().mesg);
    }
    auto publicArguments = encryptedArgs.value()->exportPublicArguments(
        clientParameters, keySet->runtimeContext());
    if (publicArguments.has_error()) {
      return StreamStringError(publicArguments.error().mesg);
    }
    auto evaluationKeys = keySet->evaluationKeys();
    auto publicResult =
        support.serverCall(lambda, *publicArguments.value(), evaluationKeys);
    if (auto err = publicResult.takeError()) {
      return std::move(err);
    }
    return typedResult<ResT>(*keySet, **publicResult);
  }

private:
  typename LambdaSupport::lambda lambda;
  std::unique_ptr<typename LambdaSupport::compilationResult> compilationResult;
  std::unique_ptr<clientlib::KeySet> keySet;
  clientlib::ClientParameters clientParameters;
  LambdaSupport support;
};

} // namespace concretelang
} // namespace mlir

#endif
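For reference, a hedged usage sketch of the `ClientServer` wrapper defined above; `Support` stands in for a concrete implementation of the `LambdaSupport` interface (the project provides its own elsewhere), and `mlirSource` is whatever MLIR program the caller wants to compile and run with a `main` entry point:

// Usage sketch only: Support and compileAndRunOnce are placeholders, not part
// of this header.
template <class Support>
llvm::Expected<uint64_t> compileAndRunOnce(llvm::StringRef mlirSource) {
  // Compiles the program, generates (or loads) keys, and wires the client and
  // server halves together.
  auto cs = mlir::concretelang::ClientServer<Support>::create(mlirSource);
  if (auto err = cs.takeError())
    return std::move(err);

  // operator() encrypts the scalar arguments, performs the server call, and
  // decrypts the result as uint64_t (the default ResT).
  return (*cs)(uint64_t(3), uint64_t(4));
}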