enhance(client/server): Don't decrypt directly from istream use a intermediate container to represent public result

This commit is contained in:
Quentin Bourgerie
2022-02-28 16:54:29 +01:00
parent 69037cd1fa
commit 73da7da81c
14 changed files with 181 additions and 203 deletions

View File

@@ -9,7 +9,7 @@
#include <cassert>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/KeySet.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/PublicArguments.h"
@@ -33,15 +33,10 @@ class ClientLambda {
/// Low-level class to create the client side view of a FHE function.
public:
virtual ~ClientLambda() = default;
static outcome::checked<ClientLambda, StringError>
/// Construct a ClientLambda from a ClientParameter file.
load(std::string funcName, std::string jsonPath);
/// Emit a call to the given ostream, no meta-date are include, so it's the
/// responsability of the the caller/callee to verify the add/verify the
/// function to be called.
outcome::checked<void, StringError>
untypedSerializeCall(PublicArguments &publicArguments, std::ostream &ostream);
/// Construct a ClientLambda from a ClientParameter file.
static outcome::checked<ClientLambda, StringError> load(std::string funcName,
std::string jsonPath);
/// Generate or get from cache a KeySet suitable for this ClientLambda
outcome::checked<std::unique_ptr<KeySet>, StringError>
@@ -49,19 +44,19 @@ public:
uint64_t seed_lsb);
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
decryptReturnedValues(KeySet &keySet, std::istream &istream);
decryptReturnedValues(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_scalar_t, StringError>
decryptReturnedScalar(KeySet &keySet, std::istream &istream);
decryptReturnedScalar(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_1_t, StringError>
decryptReturnedTensor1(KeySet &keySet, std::istream &istream);
decryptReturnedTensor1(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_2_t, StringError>
decryptReturnedTensor2(KeySet &keySet, std::istream &istream);
decryptReturnedTensor2(KeySet &keySet, PublicResult &result);
outcome::checked<decrypted_tensor_3_t, StringError>
decryptReturnedTensor3(KeySet &keySet, std::istream &istream);
decryptReturnedTensor3(KeySet &keySet, PublicResult &result);
public:
ClientParameters clientParameters;
@@ -70,7 +65,7 @@ public:
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
PublicResult &result);
template <typename Result, typename... Args>
class TypedClientLambda : public ClientLambda {
@@ -90,19 +85,21 @@ public:
serializeCall(Args... args, std::shared_ptr<KeySet> keySet,
std::ostream &ostream) {
OUTCOME_TRY(auto publicArguments, publicArguments(args..., keySet));
return ClientLambda::untypedSerializeCall(publicArguments, ostream);
return publicArguments->serialize(ostream);
}
outcome::checked<PublicArguments, StringError>
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
publicArguments(Args... args, std::shared_ptr<KeySet> keySet) {
OUTCOME_TRY(auto clientArguments, EncryptedArgs::create(keySet, args...));
return clientArguments->asPublicArguments(clientParameters,
keySet->runtimeContext());
OUTCOME_TRY(auto clientArguments,
EncryptedArguments::create(keySet, args...));
return clientArguments->exportPublicArguments(clientParameters,
keySet->runtimeContext());
}
outcome::checked<Result, StringError> decryptReturned(KeySet &keySet,
std::istream &istream) {
return topLevelDecryptResult<Result>((*this), keySet, istream);
outcome::checked<Result, StringError> decryptResult(KeySet &keySet,
PublicResult &result) {
return topLevelDecryptResult<Result>((*this), keySet, result);
}
TypedClientLambda(ClientLambda &lambda) : ClientLambda(lambda) {
@@ -115,25 +112,25 @@ protected:
template <typename Result_>
friend outcome::checked<Result_, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
PublicResult &result);
};
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
std::istream &istream);
PublicResult &result);
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream);
PublicResult &result);
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream);
PublicResult &result);
} // namespace clientlib
} // namespace concretelang

View File

@@ -23,39 +23,52 @@ using concretelang::error::StringError;
class PublicArguments;
class EncryptedArgs {
class EncryptedArguments {
/// Temporary object used to hold and encrypt parameters before calling a
/// ClientLambda. Use preferably TypeClientLambda and serializeCall(Args...).
/// Otherwise convert it to a PublicArguments and use
/// serializeCall(PublicArguments, KeySet).
public:
// Create EncryptedArgument that use the given KeySet to perform
// encryption/decryption operations.
EncryptedArguments() : currentPos(0) {}
/// Encrypts args thanks the given KeySet and pack the encrypted arguments to
/// an EncryptedArguments
template <typename... Args>
static outcome::checked<std::shared_ptr<EncryptedArgs>, StringError>
static outcome::checked<std::unique_ptr<EncryptedArguments>, StringError>
create(std::shared_ptr<KeySet> keySet, Args... args) {
auto arguments = std::make_shared<EncryptedArgs>();
auto arguments = std::make_unique<EncryptedArguments>();
OUTCOME_TRYV(arguments->pushArgs(keySet, args...));
return arguments;
}
/** Low level interface */
/// Export encrypted arguments as public arguments, reset the encrypted
/// arguments, i.e. move all buffers to the PublicArguments and reset the
/// positional counter.
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
public:
// Add a scalar argument.
/// Add a uint8_t scalar argument.
outcome::checked<void, StringError> pushArg(uint8_t arg,
std::shared_ptr<KeySet> keySet);
// Add a vector-tensor argument.
// Add a uint64_t scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg,
std::shared_ptr<KeySet> keySet);
/// Add a vector-tensor argument.
outcome::checked<void, StringError> pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet);
/// Add a 1D tensor argument.
template <size_t size>
outcome::checked<void, StringError> pushArg(std::array<uint8_t, size> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {size}, keySet);
}
// Add a matrix-tensor argument.
/// Add a 2D tensor argument.
template <size_t size0, size_t size1>
outcome::checked<void, StringError>
pushArg(std::array<std::array<uint8_t, size1>, size0> arg,
@@ -63,7 +76,7 @@ public:
return pushArg(8, (void *)arg.data(), {size0, size1}, keySet);
}
// Add a rank3 tensor.
/// Add a 3D tensor argument.
template <size_t size0, size_t size1, size_t size2>
outcome::checked<void, StringError>
pushArg(std::array<std::array<std::array<uint8_t, size2>, size1>, size0> arg,
@@ -92,6 +105,7 @@ public:
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet);
/// Push a variadic list of arguments.
template <typename Arg0, typename... OtherArgs>
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet,
Arg0 arg0, OtherArgs... others) {
@@ -99,26 +113,18 @@ public:
return pushArgs(keySet, others...);
}
// Terminal case of pushArgs
outcome::checked<void, StringError> pushArgs(std::shared_ptr<KeySet> keySet) {
return checkAllArgs(keySet);
}
outcome::checked<PublicArguments, StringError>
asPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext);
EncryptedArgs();
~EncryptedArgs();
private:
outcome::checked<void, StringError>
checkPushTooManyArgs(std::shared_ptr<KeySet> keySetPtr);
checkPushTooManyArgs(std::shared_ptr<KeySet> keySet);
outcome::checked<void, StringError>
checkAllArgs(std::shared_ptr<KeySet> keySet);
// Add a scalar argument.
outcome::checked<void, StringError> pushArg(uint64_t arg,
std::shared_ptr<KeySet> keySet);
private:
// Position of the next pushed argument
size_t currentPos;
std::vector<void *> preparedArgs;

View File

@@ -29,6 +29,7 @@ class KeySet {
public:
KeySet();
~KeySet();
KeySet(KeySet &other) = delete;
// allocate a KeySet according the ClientParameters.
static outcome::checked<std::unique_ptr<KeySet>, StringError>

View File

@@ -11,7 +11,7 @@
#include "boost/outcome.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/Types.h"
#include "concretelang/Common/Error.h"
#include "concretelang/Runtime/context.h"
@@ -26,7 +26,7 @@ namespace clientlib {
using concretelang::error::StringError;
class EncryptedArgs;
class EncryptedArguments;
class PublicArguments {
/// PublicArguments will be sended to the server. It includes encrypted
/// arguments and public keys.
@@ -35,12 +35,9 @@ public:
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
bool clearRuntimeContext, std::vector<void *> &&preparedArgs,
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers);
PublicArguments(PublicArguments &other) = delete;
// to have proper owership transfer (outcome and local object)
PublicArguments(PublicArguments &&other);
~PublicArguments();
void freeIfNotOwned(std::vector<encrypted_scalar_t> res);
PublicArguments(PublicArguments &other) = delete;
PublicArguments(PublicArguments &&other) = delete;
static outcome::checked<std::shared_ptr<PublicArguments>, StringError>
unserialize(ClientParameters &expectedParams, std::istream &istream);
@@ -57,9 +54,44 @@ private:
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
// Indicates if this public argument own the runtime keys.
bool clearRuntimeContext;
};
struct PublicResult {
/// PublicResult is a result of a ServerLambda call which contains encrypted
/// results.
PublicResult(const ClientParameters &clientParameters,
std::vector<encrypted_scalars_and_sizes_t> buffers = {})
: clientParameters(clientParameters), buffers(buffers){};
PublicResult(PublicResult &) = delete;
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<encrypted_scalars_and_sizes_t> buffers) {
return std::make_unique<PublicResult>(clientParameters, buffers);
}
/// Unserialize from a input stream.
outcome::checked<void, StringError> unserialize(std::istream &istream);
/// Serialize into an output stream.
outcome::checked<void, StringError> serialize(std::ostream &ostream);
/// Decrypt the result at `pos` as a vector.
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
decryptVector(KeySet &keySet, size_t pos);
private:
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;
std::vector<encrypted_scalars_and_sizes_t> buffers;
};
} // namespace clientlib
} // namespace concretelang

View File

@@ -27,17 +27,23 @@ encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned,
size_t offset, size_t *sizes, size_t *strides);
/// ServerLambda is a utility class that allows to call a function of a
/// compilation result.
class ServerLambda {
public:
/// Load the symbol `funcName` of the compilation result located at the path
/// `outputLib`.
static outcome::checked<ServerLambda, concretelang::error::StringError>
load(std::string funcName, std::string outputLib);
/// Load the symbol `funcName` of the dynamic loaded library
static outcome::checked<ServerLambda, concretelang::error::StringError>
loadFromModule(std::shared_ptr<DynamicModule> module, std::string funcName);
outcome::checked<void, concretelang::error::StringError>
read_call_write(std::istream &istream, std::ostream &ostream);
/// Call the ServerLambda with public arguments.
std::unique_ptr<clientlib::PublicResult>
call(clientlib::PublicArguments &args);
protected:
ClientParameters clientParameters;

View File

@@ -63,41 +63,17 @@ public:
keySet(keySet) {}
outcome::checked<Result, StringError> call(Args... args) {
// client
auto BINARY = std::ios::binary;
std::string message;
{
// client
std::ostringstream clientOuput(BINARY);
OUTCOME_TRYV(this->serializeCall(args..., keySet, clientOuput));
if (clientOuput.fail()) {
return StringError("Error on clientOuput");
}
message = clientOuput.str();
}
{
// server
std::istringstream serverInput(message, BINARY);
freeStringMemory(message);
assert(serverInput.tellg() == 0);
std::ostringstream serverOutput(BINARY);
OUTCOME_TRYV(serverLambda.read_call_write(serverInput, serverOutput));
if (serverInput.fail()) {
return StringError("Error on serverInput");
}
if (serverOutput.fail()) {
return StringError("Error on serverOutput");
}
message = serverOutput.str();
}
{
// client
std::istringstream clientInput(message, BINARY);
freeStringMemory(message);
OUTCOME_TRY(auto result, this->decryptReturned(*keySet, clientInput));
assert(clientInput.good());
return result;
}
// client argument encryption
OUTCOME_TRY(auto encryptedArgs,
clientlib::EncryptedArguments::create(keySet, args...));
OUTCOME_TRY(auto publicArgument,
encryptedArgs->exportPublicArguments(this->clientParameters,
keySet->runtimeContext()));
// server function call
auto publicResult = serverLambda.call(*publicArgument);
// client result decryption
return this->decryptResult(*keySet, *publicResult);
}
private:

View File

@@ -14,7 +14,7 @@ add_mlir_library(
ConcretelangClientLib
ClientLambda.cpp
ClientParameters.cpp
EncryptedArgs.cpp
EncryptedArguments.cpp
KeySet.cpp
KeySetCache.cpp
PublicArguments.cpp

View File

@@ -46,35 +46,15 @@ ClientLambda::keySet(std::shared_ptr<KeySetCache> optionalCache,
seed_lsb);
}
outcome::checked<void, StringError>
ClientLambda::untypedSerializeCall(PublicArguments &serverArguments,
std::ostream &ostream) {
return serverArguments.serialize(ostream);
}
outcome::checked<decrypted_scalar_t, StringError>
ClientLambda::decryptReturnedScalar(KeySet &keySet, std::istream &istream) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, istream));
ClientLambda::decryptReturnedScalar(KeySet &keySet, PublicResult &result) {
OUTCOME_TRY(auto v, decryptReturnedValues(keySet, result));
return v[0];
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
ClientLambda::decryptReturnedValues(KeySet &keySet, std::istream &istream) {
auto lweSize =
clientParameters.lweSecretKeyParam(clientParameters.outputs[0]).lweSize();
std::vector<int64_t> sizes = clientParameters.outputs[0].shape.dimensions;
sizes.push_back(lweSize);
auto encryptedValues = unserializeEncryptedValues(sizes, istream);
if (istream.fail()) {
return StringError("Encrypted scalars has not the right size");
}
auto len = encryptedValues.length();
decrypted_tensor_1_t decryptedValues(len / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto buffer = (uint64_t *)(&encryptedValues.values[i * lweSize]);
OUTCOME_TRYV(keySet.decrypt_lwe(0, buffer, decryptedValues[i]));
}
return decryptedValues;
ClientLambda::decryptReturnedValues(KeySet &keySet, PublicResult &result) {
return result.decryptVector(keySet, 0);
}
outcome::checked<void, StringError> errorResultRank(size_t expected,
@@ -128,7 +108,7 @@ decrypted_tensor_3_t flatToTensor(decrypted_tensor_1_t &values, size_t *sizes) {
template <typename DecryptedTensor>
outcome::checked<DecryptedTensor, StringError>
decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
decryptReturnedTensor(PublicResult &result, ClientLambda &lambda,
ClientParameters &params, size_t expectedRank,
KeySet &keySet) {
auto shape = params.outputs[0].shape;
@@ -137,7 +117,7 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
return StringError("Function returns a tensor of rank ")
<< expectedRank << " which cannot be decrypted to rank " << rank;
}
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, istream));
OUTCOME_TRY(auto values, lambda.decryptReturnedValues(keySet, result));
llvm::SmallVector<size_t, 6> sizes;
for (size_t dim = 0; dim < rank; dim++) {
sizes.push_back(shape.dimensions[dim]);
@@ -146,27 +126,27 @@ decryptReturnedTensor(std::istream &istream, ClientLambda &lambda,
}
outcome::checked<decrypted_tensor_1_t, StringError>
ClientLambda::decryptReturnedTensor1(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor1(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_1_t>(
istream, *this, this->clientParameters, 1, keySet);
result, *this, this->clientParameters, 1, keySet);
}
outcome::checked<decrypted_tensor_2_t, StringError>
ClientLambda::decryptReturnedTensor2(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor2(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_2_t>(
istream, *this, this->clientParameters, 2, keySet);
result, *this, this->clientParameters, 2, keySet);
}
outcome::checked<decrypted_tensor_3_t, StringError>
ClientLambda::decryptReturnedTensor3(KeySet &keySet, std::istream &istream) {
ClientLambda::decryptReturnedTensor3(KeySet &keySet, PublicResult &result) {
return decryptReturnedTensor<decrypted_tensor_3_t>(
istream, *this, this->clientParameters, 3, keySet);
result, *this, this->clientParameters, 3, keySet);
}
template <typename Result>
outcome::checked<Result, StringError>
topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
PublicResult &result) {
// compile time error if used
using COMPATIBLE_RESULT_TYPE = void;
return (Result)(COMPATIBLE_RESULT_TYPE)0;
@@ -175,32 +155,32 @@ topLevelDecryptResult(ClientLambda &lambda, KeySet &keySet,
template <>
outcome::checked<decrypted_scalar_t, StringError>
topLevelDecryptResult<decrypted_scalar_t>(ClientLambda &lambda, KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedScalar(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedScalar(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_1_t, StringError>
topLevelDecryptResult<decrypted_tensor_1_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor1(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor1(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_2_t, StringError>
topLevelDecryptResult<decrypted_tensor_2_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor2(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor2(keySet, result);
}
template <>
outcome::checked<decrypted_tensor_3_t, StringError>
topLevelDecryptResult<decrypted_tensor_3_t>(ClientLambda &lambda,
KeySet &keySet,
std::istream &istream) {
return lambda.decryptReturnedTensor3(keySet, istream);
PublicResult &result) {
return lambda.decryptReturnedTensor3(keySet, result);
}
} // namespace clientlib

View File

@@ -3,7 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
#include "concretelang/ClientLib/PublicArguments.h"
namespace concretelang {
@@ -11,20 +11,23 @@ namespace clientlib {
using StringError = concretelang::error::StringError;
EncryptedArgs::~EncryptedArgs() {
// There is no explicit allocation
// All buffers are owned by ciphertextBuffers
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
// On client side the runtimeContext is hold by the KeySet
bool clearContext = false;
return std::make_unique<PublicArguments>(
clientParameters, runtimeContext, clearContext, std::move(preparedArgs),
std::move(ciphertextBuffers));
}
EncryptedArgs::EncryptedArgs() : currentPos(0) {}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(uint8_t arg, std::shared_ptr<KeySet> keySet) {
return pushArg((uint64_t)arg, keySet);
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
// TODO: NON ENCRYPTED
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
@@ -65,14 +68,15 @@ EncryptedArgs::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(std::vector<uint8_t> arg,
std::shared_ptr<KeySet> keySet) {
return pushArg(8, (void *)arg.data(), {(int64_t)arg.size()}, keySet);
}
outcome::checked<void, StringError>
EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
EncryptedArguments::pushArg(size_t width, void *data,
llvm::ArrayRef<int64_t> shape,
std::shared_ptr<KeySet> keySet) {
OUTCOME_TRYV(checkPushTooManyArgs(keySet));
auto pos = currentPos;
CircuitGate input = keySet->inputGate(pos);
@@ -148,7 +152,7 @@ EncryptedArgs::pushArg(size_t width, void *data, llvm::ArrayRef<int64_t> shape,
}
outcome::checked<void, StringError>
EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
EncryptedArguments::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
if (currentPos < arity) {
return outcome::success();
@@ -158,7 +162,7 @@ EncryptedArgs::checkPushTooManyArgs(std::shared_ptr<KeySet> keySet) {
}
outcome::checked<void, StringError>
EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
EncryptedArguments::checkAllArgs(std::shared_ptr<KeySet> keySet) {
size_t arity = keySet->numInputs();
if (currentPos == arity) {
return outcome::success();
@@ -168,14 +172,5 @@ EncryptedArgs::checkAllArgs(std::shared_ptr<KeySet> keySet) {
<< " arguments";
}
outcome::checked<PublicArguments, StringError>
EncryptedArgs::asPublicArguments(ClientParameters clientParameters,
RuntimeContext runtimeContext) {
// On client side the runtimeContext is hold by the KeySet
bool clearContext = false;
return PublicArguments(clientParameters, runtimeContext, clearContext,
std::move(preparedArgs), std::move(ciphertextBuffers));
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -230,7 +230,6 @@ KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input) {
outcome::checked<void, StringError>
KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t &output) {
if (argPos >= outputs.size()) {
return StringError("decrypt_lwe: position of argument is too high");
}

View File

@@ -29,18 +29,6 @@ PublicArguments::PublicArguments(
ciphertextBuffers = std::move(ciphertextBuffers_);
}
PublicArguments::PublicArguments(PublicArguments &&other) {
clientParameters = other.clientParameters;
runtimeContext = other.runtimeContext;
runtimeContext.bsk = std::move(other.runtimeContext.bsk);
clearRuntimeContext = other.clearRuntimeContext;
preparedArgs = std::move(other.preparedArgs);
ciphertextBuffers = std::move(other.ciphertextBuffers);
// transfer ownership
other.clearRuntimeContext = false;
other.runtimeContext.ksk = nullptr;
}
PublicArguments::~PublicArguments() {
if (!clearRuntimeContext) {
return;
@@ -68,7 +56,7 @@ PublicArguments::serialize(std::ostream &ostream) {
size_t rank = gate.shape.dimensions.size();
if (!gate.encryption.hasValue()) {
return StringError("PublicArguments::serialize: Clear arguments "
"are not supported. Argument ")
"are not yet supported. Argument ")
<< iGate;
}
/*auto allocated = */ preparedArgs[iPreparedArgs++];
@@ -140,14 +128,27 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
}
std::vector<void *> empty;
std::vector<encrypted_scalars_and_sizes_t> emptyBuffers;
// On server side the PublicArguments is responsible for the context
auto clearRuntimeContext = true;
auto sArguments = std::make_shared<PublicArguments>(
clientParameters, runtimeContext, clearRuntimeContext, std::move(empty),
clientParameters, runtimeContext, true, std::move(empty),
std::move(emptyBuffers));
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
return sArguments;
}
outcome::checked<std::vector<decrypted_scalar_t>, StringError>
PublicResult::decryptVector(KeySet &keySet, size_t pos) {
auto lweSize =
clientParameters.lweSecretKeyParam(clientParameters.outputs[pos])
.lweSize();
auto buffer = buffers[pos];
decrypted_tensor_1_t decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decryptedValues[i]));
}
return decryptedValues;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -123,37 +123,20 @@ ServerLambda::load(std::string funcName, std::string outputLib) {
encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...),
std::vector<void *> &preparedArgs,
CircuitGate &output,
std::ostream &ostream) {
CircuitGate &output) {
size_t rank = output.shape.dimensions.size();
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
}
outcome::checked<void, StringError>
ServerLambda::read_call_write(std::istream &istream, std::ostream &ostream) {
OUTCOME_TRY(auto argumentsPtr,
PublicArguments::unserialize(clientParameters, istream));
assert(istream.good());
PublicArguments &arguments = *argumentsPtr;
// The runtime context is always the last argument list
arguments.preparedArgs.push_back((void *)&arguments.runtimeContext);
auto values_and_sizes = dynamicCall(this->func, arguments.preparedArgs,
clientParameters.outputs[0], ostream);
auto shape = clientParameters.outputs[0].shape;
size_t rank = shape.dimensions.size();
for (size_t dim = 0; dim < rank; dim++) {
if (values_and_sizes.sizes[dim] != (size_t)shape.dimensions[dim]) {
return StringError("Dimension mismatch on dim ")
<< dim << " actual: " << values_and_sizes.sizes[dim]
<< " vs expected: " << shape.dimensions[dim] << "\n";
}
}
serializeEncryptedValues(values_and_sizes, ostream);
if (ostream.fail()) {
return StringError("Cannot write result");
}
return outcome::success();
std::unique_ptr<clientlib::PublicResult>
ServerLambda::call(PublicArguments &args) {
std::vector<void *> preparedArgs(args.preparedArgs.begin(),
args.preparedArgs.end());
preparedArgs.push_back((void *)&args.runtimeContext);
return clientlib::PublicResult::fromBuffers(
clientParameters,
{dynamicCall(this->func, preparedArgs, clientParameters.outputs[0])});
;
}
} // namespace serverlib

View File

@@ -2,7 +2,7 @@
#include "../unittest/end_to_end_jit_test.h"
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EncryptedArgs.h"
#include "concretelang/ClientLib/EncryptedArguments.h"
namespace clientlib = concretelang::clientlib;

View File

@@ -73,10 +73,11 @@ func @main(%arg0: !FHE.eint<7>) -> !FHE.eint<7> {
ASSERT_TRUE(maybeKeySet.has_value());
std::shared_ptr<KeySet> keySet = std::move(maybeKeySet.value());
auto maybePublicArguments = lambda.publicArguments(1, keySet);
ASSERT_TRUE(maybePublicArguments.has_value());
auto publicArguments = std::move(maybePublicArguments.value());
std::ostringstream osstream(std::ios::binary);
EXPECT_TRUE(lambda.untypedSerializeCall(publicArguments, osstream));
ASSERT_TRUE(publicArguments->serialize(osstream).has_value());
EXPECT_TRUE(osstream.good());
// Direct call without intermediate
EXPECT_TRUE(lambda.serializeCall(1, keySet, osstream));
@@ -119,6 +120,7 @@ func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
}
TEST(CompiledModule, call_2s_1s) {
std::string source = R"(
func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>) -> !FHE.eint<7> {
%1 = "FHE.add_eint"(%arg0, %arg1): (!FHE.eint<7>, !FHE.eint<7>) -> (!FHE.eint<7>)