mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
enhance(client/server): Don't decrypt directly from istream use a intermediate container to represent public result
This commit is contained in:
@@ -9,7 +9,7 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/KeySet.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
@@ -33,15 +33,10 @@ class ClientLambda {
|
||||
/// Low-level class to create the client side view of a FHE function.
|
||||
public:
|
||||
virtual ~ClientLambda() = default;
|
||||
static outcome::checked<ClientLambda, StringError>
|
||||
/// Construct a ClientLambda from a ClientParameter file.
|
||||
load(std::string funcName, std::string jsonPath);
|
||||
|
||||
/// Emit a call to the given ostream, no meta-date are include, so it's the
|
||||
/// responsability of the the caller/callee to verify the add/verify the
|
||||
/// function to be called.
|
||||
outcome::checked<void, StringError>
|
||||
untypedSerializeCall(PublicArguments &publicArguments, std::ostream &ostream);
|
||||
/// Construct a ClientLambda from a ClientParameter file.
|
||||
static outcome::checked<ClientLambda, StringError> load(std::string funcName,
|
||||
std::string jsonPath);
|
||||
|
||||
/// Generate or get from cache a KeySet suitable for this ClientLambda
|
||||
outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
@@ -49,19 +44,19 @@ public:
|
||||
uint64_t seed_lsb);
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptReturnedValues(KeySet &keySet, std::istream &istream);
|
||||
decryptReturnedValues(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
decryptReturnedScalar(KeySet &keySet, std::istream &istream);
|
||||
decryptReturnedScalar(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
decryptReturnedTensor1(KeySet &keySet, std::istream &istream);
|
||||
decryptReturnedTensor1(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
decryptReturnedTensor2(KeySet &keySet, std::istream &istream);
|
||||
decryptReturnedTensor2(KeySet &keySet, PublicResult &result);
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
decryptReturnedTensor3(KeySet &keySet, std::istream &istream);
|
||||
decryptReturnedTensor3(KeySet &keySet, PublicResult &result);
|
||||
|
||||
public:
|
||||
ClientParameters clientParameters;
|
||||
@@ -70,7 +65,7 @@ public:
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
PublicResult &result);
|
||||
|
||||
template <typename Result, typename... Args>
|
||||
class TypedClientLambda : public ClientLambda {
|
||||
@@ -90,19 +85,21 @@ public:
|
||||
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
|
||||
std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
|
||||
return ClientLambda::untypedSerializeCall(publicArguments, ostream);
|
||||
return publicArguments->serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
|
||||
OUTCOME_TRY(auto clientArguments, EncryptedArgs::create(keySet, args...));
|
||||
return clientArguments->asPublicArguments(clientParameters,
|
||||
keySet->runtimeContext());
|
||||
OUTCOME_TRY(auto clientArguments,
|
||||
EncryptedArguments::create(keySet, args...));
|
||||
|
||||
return clientArguments->exportPublicArguments(clientParameters,
|
||||
keySet->runtimeContext());
|
||||
}
|
||||
|
||||
outcome::checked<Result, StringError> decryptReturned(KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return topLevelDecryptResult<Result>((*this), keySet, istream);
|
||||
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
|
||||
PublicResult &result) {
|
||||
return topLevelDecryptResult<Result>((*this), keySet, result);
|
||||
}
|
||||
|
||||
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
|
||||
@@ -115,25 +112,25 @@ protected:
|
||||
template <typename Result_>
|
||||
friend outcome::checked<Result_, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
PublicResult &result);
|
||||
};
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream);
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream);
|
||||
PublicResult &result);
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream);
|
||||
PublicResult &result);
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -23,39 +23,52 @@ using concretelang::error::StringError;
|
||||
|
||||
class PublicArguments;
|
||||
|
||||
class EncryptedArgs {
|
||||
class EncryptedArguments {
|
||||
/// Temporary object used to hold and encrypt parameters before calling a
|
||||
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
|
||||
/// Otherwise convert it to a PublicArguments and use
|
||||
/// serializeCall(PublicArguments, KeySet).
|
||||
public:
|
||||
// Create EncryptedArgument that use the given KeySet to perform
|
||||
// encryption/decryption operations.
|
||||
EncryptedArguments() : currentPos(0) {}
|
||||
|
||||
/// Encrypts args thanks the given KeySet and pack the encrypted arguments to
|
||||
/// an EncryptedArguments
|
||||
template <typename... Args>
|
||||
static outcome::checked<std::shared_ptr<EncryptedArgs>, StringError>
|
||||
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
|
||||
create(std::shared_ptr<KeySet> keySet, Args... args) {
|
||||
auto arguments = std::make_shared<EncryptedArgs>();
|
||||
auto arguments = std::make_unique<EncryptedArguments>();
|
||||
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
|
||||
return arguments;
|
||||
}
|
||||
|
||||
/** Low level interface */
|
||||
/// Export encrypted arguments as public arguments, reset the encrypted
|
||||
/// arguments, i.e. move all buffers to the PublicArguments and reset the
|
||||
/// positional counter.
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext);
|
||||
|
||||
public:
|
||||
// Add a scalar argument.
|
||||
/// Add a uint8_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint8_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
// Add a vector-tensor argument.
|
||||
// Add a uint64_t scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
/// Add a vector-tensor argument.
|
||||
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
/// Add a 1D tensor argument.
|
||||
template <size_t size>
|
||||
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {size}, keySet);
|
||||
}
|
||||
|
||||
// Add a matrix-tensor argument.
|
||||
/// Add a 2D tensor argument.
|
||||
template <size_t size0, size_t size1>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
|
||||
@@ -63,7 +76,7 @@ public:
|
||||
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
|
||||
}
|
||||
|
||||
// Add a rank3 tensor.
|
||||
/// Add a 3D tensor argument.
|
||||
template <size_t size0, size_t size1, size_t size2>
|
||||
outcome::checked<void, StringError>
|
||||
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
|
||||
@@ -92,6 +105,7 @@ public:
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
/// Push a variadic list of arguments.
|
||||
template <typename Arg0, typename... OtherArgs>
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
|
||||
Arg0 arg0, OtherArgs... others) {
|
||||
@@ -99,26 +113,18 @@ public:
|
||||
return pushArgs(keySet, others...);
|
||||
}
|
||||
|
||||
// Terminal case of pushArgs
|
||||
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
|
||||
return checkAllArgs(keySet);
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
asPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext);
|
||||
|
||||
EncryptedArgs();
|
||||
~EncryptedArgs();
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError>
|
||||
checkPushTooManyArgs(std::shared_ptr<KeySet> keySetPtr);
|
||||
checkPushTooManyArgs(std::shared_ptr<KeySet> keySet);
|
||||
outcome::checked<void, StringError>
|
||||
checkAllArgs(std::shared_ptr<KeySet> keySet);
|
||||
// Add a scalar argument.
|
||||
outcome::checked<void, StringError> pushArg(uint64_t arg,
|
||||
std::shared_ptr<KeySet> keySet);
|
||||
|
||||
private:
|
||||
// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
@@ -29,6 +29,7 @@ class KeySet {
|
||||
public:
|
||||
KeySet();
|
||||
~KeySet();
|
||||
KeySet(KeySet &other) = delete;
|
||||
|
||||
// allocate a KeySet according the ClientParameters.
|
||||
static outcome::checked<std::unique_ptr<KeySet>, StringError>
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "boost/outcome.h"
|
||||
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/Types.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
@@ -26,7 +26,7 @@ namespace clientlib {
|
||||
|
||||
using concretelang::error::StringError;
|
||||
|
||||
class EncryptedArgs;
|
||||
class EncryptedArguments;
|
||||
class PublicArguments {
|
||||
/// PublicArguments will be sended to the server. It includes encrypted
|
||||
/// arguments and public keys.
|
||||
@@ -35,12 +35,9 @@ public:
|
||||
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
|
||||
bool clearRuntimeContext, std::vector<void *> &&preparedArgs,
|
||||
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers);
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
// to have proper owership transfer (outcome and local object)
|
||||
PublicArguments(PublicArguments &&other);
|
||||
~PublicArguments();
|
||||
|
||||
void freeIfNotOwned(std::vector<encrypted_scalar_t> res);
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
PublicArguments(PublicArguments &&other) = delete;
|
||||
|
||||
static outcome::checked<std::shared_ptr<PublicArguments>, StringError>
|
||||
unserialize(ClientParameters &expectedParams, std::istream &istream);
|
||||
@@ -57,9 +54,44 @@ private:
|
||||
std::vector<void *> preparedArgs;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
|
||||
|
||||
// Indicates if this public argument own the runtime keys.
|
||||
bool clearRuntimeContext;
|
||||
};
|
||||
|
||||
struct PublicResult {
|
||||
/// PublicResult is a result of a ServerLambda call which contains encrypted
|
||||
/// results.
|
||||
|
||||
PublicResult(const ClientParameters &clientParameters,
|
||||
std::vector<encrypted_scalars_and_sizes_t> buffers = {})
|
||||
: clientParameters(clientParameters), buffers(buffers){};
|
||||
|
||||
PublicResult(PublicResult &) = delete;
|
||||
|
||||
/// Create a public result from buffers.
|
||||
static std::unique_ptr<PublicResult>
|
||||
fromBuffers(const ClientParameters &clientParameters,
|
||||
std::vector<encrypted_scalars_and_sizes_t> buffers) {
|
||||
return std::make_unique<PublicResult>(clientParameters, buffers);
|
||||
}
|
||||
|
||||
/// Unserialize from a input stream.
|
||||
outcome::checked<void, StringError> unserialize(std::istream &istream);
|
||||
|
||||
/// Serialize into an output stream.
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
/// Decrypt the result at `pos` as a vector.
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
decryptVector(KeySet &keySet, size_t pos);
|
||||
|
||||
private:
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
ClientParameters clientParameters;
|
||||
std::vector<encrypted_scalars_and_sizes_t> buffers;
|
||||
};
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
|
||||
@@ -27,17 +27,23 @@ encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
|
||||
size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned,
|
||||
size_t offset, size_t *sizes, size_t *strides);
|
||||
|
||||
/// ServerLambda is a utility class that allows to call a function of a
|
||||
/// compilation result.
|
||||
class ServerLambda {
|
||||
|
||||
public:
|
||||
/// Load the symbol `funcName` of the compilation result located at the path
|
||||
/// `outputLib`.
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
load(std::string funcName, std::string outputLib);
|
||||
|
||||
/// Load the symbol `funcName` of the dynamic loaded library
|
||||
static outcome::checked<ServerLambda, concretelang::error::StringError>
|
||||
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
|
||||
|
||||
outcome::checked<void, concretelang::error::StringError>
|
||||
read_call_write(std::istream &istream, std::ostream &ostream);
|
||||
/// Call the ServerLambda with public arguments.
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
call(clientlib::PublicArguments &args);
|
||||
|
||||
protected:
|
||||
ClientParameters clientParameters;
|
||||
|
||||
@@ -63,41 +63,17 @@ public:
|
||||
keySet(keySet) {}
|
||||
|
||||
outcome::checked<Result, StringError> call(Args... args) {
|
||||
// client
|
||||
auto BINARY = std::ios::binary;
|
||||
std::string message;
|
||||
{
|
||||
// client
|
||||
std::ostringstream clientOuput(BINARY);
|
||||
OUTCOME_TRYV(this->serializeCall(args..., keySet, clientOuput));
|
||||
if (clientOuput.fail()) {
|
||||
return StringError("Error on clientOuput");
|
||||
}
|
||||
message = clientOuput.str();
|
||||
}
|
||||
{
|
||||
// server
|
||||
std::istringstream serverInput(message, BINARY);
|
||||
freeStringMemory(message);
|
||||
assert(serverInput.tellg() == 0);
|
||||
std::ostringstream serverOutput(BINARY);
|
||||
OUTCOME_TRYV(serverLambda.read_call_write(serverInput, serverOutput));
|
||||
if (serverInput.fail()) {
|
||||
return StringError("Error on serverInput");
|
||||
}
|
||||
if (serverOutput.fail()) {
|
||||
return StringError("Error on serverOutput");
|
||||
}
|
||||
message = serverOutput.str();
|
||||
}
|
||||
{
|
||||
// client
|
||||
std::istringstream clientInput(message, BINARY);
|
||||
freeStringMemory(message);
|
||||
OUTCOME_TRY(auto result, this->decryptReturned(*keySet, clientInput));
|
||||
assert(clientInput.good());
|
||||
return result;
|
||||
}
|
||||
// client argument encryption
|
||||
OUTCOME_TRY(auto encryptedArgs,
|
||||
clientlib::EncryptedArguments::create(keySet, args...));
|
||||
OUTCOME_TRY(auto publicArgument,
|
||||
encryptedArgs->exportPublicArguments(this->clientParameters,
|
||||
keySet->runtimeContext()));
|
||||
// server function call
|
||||
auto publicResult = serverLambda.call(*publicArgument);
|
||||
|
||||
// client result decryption
|
||||
return this->decryptResult(*keySet, *publicResult);
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
@@ -14,7 +14,7 @@ add_mlir_library(
|
||||
ConcretelangClientLib
|
||||
ClientLambda.cpp
|
||||
ClientParameters.cpp
|
||||
EncryptedArgs.cpp
|
||||
EncryptedArguments.cpp
|
||||
KeySet.cpp
|
||||
KeySetCache.cpp
|
||||
PublicArguments.cpp
|
||||
|
||||
@@ -46,35 +46,15 @@ ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
|
||||
seed_lsb);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
|
||||
std::ostream &ostream) {
|
||||
return serverArguments.serialize(ostream);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
|
||||
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
|
||||
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
|
||||
return v[0];
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
|
||||
auto lweSize =
|
||||
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
|
||||
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
|
||||
sizes.push_back(lweSize);
|
||||
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
|
||||
if (istream.fail()) {
|
||||
return StringError("Encrypted scalars has not the right size");
|
||||
}
|
||||
auto len = encryptedValues.length();
|
||||
decrypted_tensor_1_t decryptedValues(len / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
|
||||
}
|
||||
return decryptedValues;
|
||||
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
|
||||
return result.decryptVector(keySet, 0);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError> errorResultRank(size_t expected,
|
||||
@@ -128,7 +108,7 @@ decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
|
||||
|
||||
template <typename DecryptedTensor>
|
||||
outcome::checked<DecryptedTensor, StringError>
|
||||
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
|
||||
ClientParameters ¶ms, size_t expectedRank,
|
||||
KeySet &keySet) {
|
||||
auto shape = params.outputs[0].shape;
|
||||
@@ -137,7 +117,7 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
return StringError("Function returns a tensor of rank ")
|
||||
<< expectedRank << " which cannot be decrypted to rank " << rank;
|
||||
}
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
|
||||
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
|
||||
llvm::SmallVector<size_t, 6> sizes;
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
sizes.push_back(shape.dimensions[dim]);
|
||||
@@ -146,27 +126,27 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_1_t>(
|
||||
istream, *this, this->clientParameters, 1, keySet);
|
||||
result, *this, this->clientParameters, 1, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_2_t>(
|
||||
istream, *this, this->clientParameters, 2, keySet);
|
||||
result, *this, this->clientParameters, 2, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
|
||||
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
|
||||
return decryptReturnedTensor<decrypted_tensor_3_t>(
|
||||
istream, *this, this->clientParameters, 3, keySet);
|
||||
result, *this, this->clientParameters, 3, keySet);
|
||||
}
|
||||
|
||||
template <typename Result>
|
||||
outcome::checked<Result, StringError>
|
||||
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
PublicResult &result) {
|
||||
// compile time error if used
|
||||
using COMPATIBLE_RESULT_TYPE = void;
|
||||
return (Result)(COMPATIBLE_RESULT_TYPE)0;
|
||||
@@ -175,32 +155,32 @@ topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
|
||||
template <>
|
||||
outcome::checked<decrypted_scalar_t, StringError>
|
||||
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedScalar(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedScalar(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_1_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor1(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor1(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_2_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor2(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor2(keySet, result);
|
||||
}
|
||||
|
||||
template <>
|
||||
outcome::checked<decrypted_tensor_3_t, StringError>
|
||||
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
|
||||
KeySet &keySet,
|
||||
std::istream &istream) {
|
||||
return lambda.decryptReturnedTensor3(keySet, istream);
|
||||
PublicResult &result) {
|
||||
return lambda.decryptReturnedTensor3(keySet, result);
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
#include "concretelang/ClientLib/PublicArguments.h"
|
||||
|
||||
namespace concretelang {
|
||||
@@ -11,20 +11,23 @@ namespace clientlib {
|
||||
|
||||
using StringError = concretelang::error::StringError;
|
||||
|
||||
EncryptedArgs::~EncryptedArgs() {
|
||||
// There is no explicit allocation
|
||||
// All buffers are owned by ciphertextBuffers
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
// On client side the runtimeContext is hold by the KeySet
|
||||
bool clearContext = false;
|
||||
return std::make_unique<PublicArguments>(
|
||||
clientParameters, runtimeContext, clearContext, std::move(preparedArgs),
|
||||
std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
EncryptedArgs::EncryptedArgs() : currentPos(0) {}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg((uint64_t)arg, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
// TODO: NON ENCRYPTED
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
@@ -65,14 +68,15 @@ EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(std::vector<uint8_t> arg,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::pushArg(size_t width, void *data,
|
||||
llvm::ArrayRef<int64_t> shape,
|
||||
std::shared_ptr<KeySet> keySet) {
|
||||
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
|
||||
auto pos = currentPos;
|
||||
CircuitGate input = keySet->inputGate(pos);
|
||||
@@ -148,7 +152,7 @@ EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
if (currentPos < arity) {
|
||||
return outcome::success();
|
||||
@@ -158,7 +162,7 @@ EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
|
||||
EncryptedArguments::checkAllArgs(std::shared_ptr<KeySet> keySet) {
|
||||
size_t arity = keySet->numInputs();
|
||||
if (currentPos == arity) {
|
||||
return outcome::success();
|
||||
@@ -168,14 +172,5 @@ EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
|
||||
<< " arguments";
|
||||
}
|
||||
|
||||
outcome::checked<PublicArguments, StringError>
|
||||
EncryptedArgs::asPublicArguments(ClientParameters clientParameters,
|
||||
RuntimeContext runtimeContext) {
|
||||
// On client side the runtimeContext is hold by the KeySet
|
||||
bool clearContext = false;
|
||||
return PublicArguments(clientParameters, runtimeContext, clearContext,
|
||||
std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
@@ -230,7 +230,6 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
|
||||
|
||||
if (argPos >= outputs.size()) {
|
||||
return StringError("decrypt_lwe: position of argument is too high");
|
||||
}
|
||||
|
||||
@@ -29,18 +29,6 @@ PublicArguments::PublicArguments(
|
||||
ciphertextBuffers = std::move(ciphertextBuffers_);
|
||||
}
|
||||
|
||||
PublicArguments::PublicArguments(PublicArguments &&other) {
|
||||
clientParameters = other.clientParameters;
|
||||
runtimeContext = other.runtimeContext;
|
||||
runtimeContext.bsk = std::move(other.runtimeContext.bsk);
|
||||
clearRuntimeContext = other.clearRuntimeContext;
|
||||
preparedArgs = std::move(other.preparedArgs);
|
||||
ciphertextBuffers = std::move(other.ciphertextBuffers);
|
||||
// transfer ownership
|
||||
other.clearRuntimeContext = false;
|
||||
other.runtimeContext.ksk = nullptr;
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {
|
||||
if (!clearRuntimeContext) {
|
||||
return;
|
||||
@@ -68,7 +56,7 @@ PublicArguments::serialize(std::ostream &ostream) {
|
||||
size_t rank = gate.shape.dimensions.size();
|
||||
if (!gate.encryption.hasValue()) {
|
||||
return StringError("PublicArguments::serialize: Clear arguments "
|
||||
"are not supported. Argument ")
|
||||
"are not yet supported. Argument ")
|
||||
<< iGate;
|
||||
}
|
||||
/*auto allocated = */ preparedArgs[iPreparedArgs++];
|
||||
@@ -140,14 +128,27 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
}
|
||||
std::vector<void *> empty;
|
||||
std::vector<encrypted_scalars_and_sizes_t> emptyBuffers;
|
||||
// On server side the PublicArguments is responsible for the context
|
||||
auto clearRuntimeContext = true;
|
||||
auto sArguments = std::make_shared<PublicArguments>(
|
||||
clientParameters, runtimeContext, clearRuntimeContext, std::move(empty),
|
||||
clientParameters, runtimeContext, true, std::move(empty),
|
||||
std::move(emptyBuffers));
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
return sArguments;
|
||||
}
|
||||
|
||||
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
|
||||
PublicResult::decryptVector(KeySet &keySet, size_t pos) {
|
||||
auto lweSize =
|
||||
clientParameters.lweSecretKeyParam(clientParameters.outputs[pos])
|
||||
.lweSize();
|
||||
|
||||
auto buffer = buffers[pos];
|
||||
decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize);
|
||||
for (size_t i = 0; i < decryptedValues.size(); i++) {
|
||||
auto ciphertext = &buffer.values[i * lweSize];
|
||||
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
|
||||
}
|
||||
return decryptedValues;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -123,37 +123,20 @@ ServerLambda::load(std::string funcName, std::string outputLib) {
|
||||
|
||||
encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...),
|
||||
std::vector<void *> &preparedArgs,
|
||||
CircuitGate &output,
|
||||
std::ostream &ostream) {
|
||||
CircuitGate &output) {
|
||||
size_t rank = output.shape.dimensions.size();
|
||||
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
ServerLambda::read_call_write(std::istream &istream, std::ostream &ostream) {
|
||||
OUTCOME_TRY(auto argumentsPtr,
|
||||
PublicArguments::unserialize(clientParameters, istream));
|
||||
assert(istream.good());
|
||||
|
||||
PublicArguments &arguments = *argumentsPtr;
|
||||
// The runtime context is always the last argument list
|
||||
arguments.preparedArgs.push_back((void *)&arguments.runtimeContext);
|
||||
auto values_and_sizes = dynamicCall(this->func, arguments.preparedArgs,
|
||||
clientParameters.outputs[0], ostream);
|
||||
auto shape = clientParameters.outputs[0].shape;
|
||||
size_t rank = shape.dimensions.size();
|
||||
for (size_t dim = 0; dim < rank; dim++) {
|
||||
if (values_and_sizes.sizes[dim] != (size_t)shape.dimensions[dim]) {
|
||||
return StringError("Dimension mismatch on dim ")
|
||||
<< dim << " actual: " << values_and_sizes.sizes[dim]
|
||||
<< " vs expected: " << shape.dimensions[dim] << "\n";
|
||||
}
|
||||
}
|
||||
serializeEncryptedValues(values_and_sizes, ostream);
|
||||
if (ostream.fail()) {
|
||||
return StringError("Cannot write result");
|
||||
}
|
||||
return outcome::success();
|
||||
std::unique_ptr<clientlib::PublicResult>
|
||||
ServerLambda::call(PublicArguments &args) {
|
||||
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
|
||||
args.preparedArgs.end());
|
||||
preparedArgs.push_back((void *)&args.runtimeContext);
|
||||
return clientlib::PublicResult::fromBuffers(
|
||||
clientParameters,
|
||||
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});
|
||||
;
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
#include "../unittest/end_to_end_jit_test.h"
|
||||
#include "concretelang/ClientLib/ClientParameters.h"
|
||||
#include "concretelang/ClientLib/EncryptedArgs.h"
|
||||
#include "concretelang/ClientLib/EncryptedArguments.h"
|
||||
|
||||
namespace clientlib = concretelang::clientlib;
|
||||
|
||||
|
||||
@@ -73,10 +73,11 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
ASSERT_TRUE(maybeKeySet.has_value());
|
||||
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
|
||||
auto maybePublicArguments = lambda.publicArguments(1, keySet);
|
||||
|
||||
ASSERT_TRUE(maybePublicArguments.has_value());
|
||||
auto publicArguments = std::move(maybePublicArguments.value());
|
||||
std::ostringstream osstream(std::ios::binary);
|
||||
EXPECT_TRUE(lambda.untypedSerializeCall(publicArguments, osstream));
|
||||
ASSERT_TRUE(publicArguments->serialize(osstream).has_value());
|
||||
EXPECT_TRUE(osstream.good());
|
||||
// Direct call without intermediate
|
||||
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
|
||||
@@ -119,6 +120,7 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
}
|
||||
|
||||
TEST(CompiledModule, call_2s_1s) {
|
||||
|
||||
std::string source = R"(
|
||||
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
|
||||
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)
|
||||
|
||||
Reference in New Issue
Block a user