refactor(compiler/clientlib): Remove the building of the calling convention in the EncryptedArguments and test serialization in end_to_end_tests

This commit is contained in:
Quentin Bourgerie
2023-05-17 10:59:45 +02:00
parent b3ec478de9
commit 0bdb85b67d
11 changed files with 124 additions and 193 deletions

View File

@@ -179,26 +179,7 @@ public:
td.bulkAssign(values);
ciphertextBuffers.push_back(std::move(td));
}
TensorData &td = ciphertextBuffers.back().getTensor();
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(td.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = td.getNumElements();
for (size_t size : td.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
currentPos++;
return outcome::success();
}
@@ -232,7 +213,6 @@ private:
private:
/// Position of the next pushed argument
size_t currentPos;
std::vector<void *> preparedArgs;
/// Store buffers of ciphertexts
std::vector<ScalarOrTensorData> ciphertextBuffers;

View File

@@ -37,7 +37,6 @@ class EncryptedArguments;
class PublicArguments {
public:
PublicArguments(const ClientParameters &clientParameters,
std::vector<void *> &&preparedArgs,
std::vector<ScalarOrTensorData> &&ciphertextBuffers);
~PublicArguments();
PublicArguments(PublicArguments &other) = delete;
@@ -48,16 +47,18 @@ public:
outcome::checked<void, StringError> serialize(std::ostream &ostream);
private:
std::vector<ScalarOrTensorData> &getArguments() { return arguments; }
ClientParameters &getClientParameters() { return clientParameters; }
friend class ::concretelang::serverlib::ServerLambda;
friend class ::mlir::concretelang::JITLambda;
private:
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
ClientParameters clientParameters;
std::vector<void *> preparedArgs;
/// Store buffers of ciphertexts
std::vector<ScalarOrTensorData> ciphertextBuffers;
std::vector<ScalarOrTensorData> arguments;
};
/// PublicResult is a result of a ServerLambda call which contains encrypted

View File

@@ -87,17 +87,20 @@ std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
return ostream;
}
outcome::checked<TensorData, StringError> unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes unsigned to
// accomodate non static sizes
std::istream &istream);
outcome::checked<TensorData, StringError>
unserializeTensorData(std::istream &istream);
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
std::ostream &ostream);
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
std::istream &istream);
unserializeScalarOrTensorData(std::istream &istream);
std::ostream &
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &sotd,
std::ostream &ostream);
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream);
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
LweSecretKey readLweSecretKey(std::istream &istream);

View File

@@ -7,6 +7,7 @@
#define CONCRETELANG_CLIENTLIB_TYPES_H_
#include "llvm/ADT/ArrayRef.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <stddef.h>
@@ -818,6 +819,8 @@ public:
// Retrieves the value as a generic `uint64_t`
uint64_t getValueAsU64() const {
size_t width = getElementTypeWidth(type);
if (width == 64)
return value.u64;
uint64_t mask = ((uint64_t)1 << width) - 1;
uint64_t val = value.u64 & mask;
return val;

View File

@@ -114,6 +114,42 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
std::move(buffers));
}
template <typename Lambda>
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
clientlib::EvaluationKeys &evaluationKeys) {
// Prepare arguments with the right calling convention
std::vector<void *> preparedArgs;
for (auto &arg : arguments.getArguments()) {
if (arg.isScalar()) {
auto scalar = arg.getScalar().getValueAsU64();
preparedArgs.push_back((void *)scalar);
} else {
clientlib::TensorData &td = arg.getTensor();
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(td.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = td.getNumElements();
for (size_t size : td.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
}
}
return invokeRawOnLambda(lambda, arguments.getClientParameters(),
preparedArgs, evaluationKeys);
}
template <typename V, unsigned int N>
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
const llvm::SmallVector<V, N> vect) {

View File

@@ -13,8 +13,8 @@ using StringError = concretelang::error::StringError;
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
return std::make_unique<PublicArguments>(
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
return std::make_unique<PublicArguments>(clientParameters,
std::move(ciphertextBuffers));
}
/// Split the input integer into `size` chunks of `chunkWidth` bits each
@@ -49,7 +49,7 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
}
if (!input.encryption.has_value()) {
// clear scalar: just push the argument
preparedArgs.push_back((void *)arg);
ciphertextBuffers.push_back(ScalarData(arg));
return outcome::success();
}
@@ -63,24 +63,6 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides
int64_t stride = TensorData::getNumElements(shape);
for (size_t size : values_and_sizes.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
return outcome::success();
}

View File

@@ -15,13 +15,10 @@ namespace clientlib {
using concretelang::error::StringError;
// TODO: optimize the move
PublicArguments::PublicArguments(
const ClientParameters &clientParameters,
std::vector<void *> &&preparedArgs_,
std::vector<ScalarOrTensorData> &&ciphertextBuffers_)
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
std::vector<ScalarOrTensorData> &&arguments_)
: clientParameters(clientParameters) {
preparedArgs = std::move(preparedArgs_);
ciphertextBuffers = std::move(ciphertextBuffers_);
arguments = std::move(arguments_);
}
PublicArguments::~PublicArguments() {}
@@ -32,146 +29,41 @@ PublicArguments::serialize(std::ostream &ostream) {
return StringError(
"PublicArguments::serialize: ostream should be in binary mode");
}
size_t iPreparedArgs = 0;
int iGate = -1;
for (auto gate : clientParameters.inputs) {
iGate++;
size_t rank = gate.shape.dimensions.size();
if (!gate.encryption.has_value()) {
return StringError("PublicArguments::serialize: Clear arguments "
"are not yet supported. Argument ")
<< iGate;
}
/*auto allocated = */ iPreparedArgs++;
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
assert(aligned != nullptr);
auto offset = (size_t)preparedArgs[iPreparedArgs++];
std::vector<size_t> sizes; // includes lweSize as last dim
sizes.resize(rank + (gate.encryption->encoding.crt.empty() ? 1 : 2));
for (auto dim = 0u; dim < sizes.size(); dim++) {
// sizes are part of the client parameters signature
// it's static now but some day it could be dynamic so we serialize
// them.
sizes[dim] = (size_t)preparedArgs[iPreparedArgs++];
}
std::vector<size_t> strides(rank + 1);
/* strides should be zero here and are not serialized */
for (auto dim = 0u; dim < strides.size(); dim++) {
strides[dim] = (size_t)preparedArgs[iPreparedArgs++];
}
// TODO: STRIDES
auto values = aligned + offset;
writeWord<uint8_t>(ostream, 1);
serializeTensorDataRaw(sizes,
llvm::ArrayRef<clientlib::EncryptedScalarElement>{
values, TensorData::getNumElements(sizes)},
ostream);
serializeVectorOfScalarOrTensorData(arguments, ostream);
if (ostream.bad()) {
return StringError(
"PublicArguments::serialize: cannot serialize public arguments");
}
return outcome::success();
}
outcome::checked<void, StringError>
PublicArguments::unserializeArgs(std::istream &istream) {
int iGate = -1;
for (auto gate : clientParameters.inputs) {
iGate++;
if (!gate.encryption.has_value()) {
return StringError("Clear values are not handled");
}
std::vector<int64_t> sizes = gate.shape.dimensions;
if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) {
sizes.push_back(gate.encryption->encoding.crt.size());
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
sizes.push_back(lweSize);
auto sotdOrErr = unserializeScalarOrTensorData(sizes, istream);
if (sotdOrErr.has_error())
return sotdOrErr.error();
ciphertextBuffers.push_back(std::move(sotdOrErr.value()));
auto &buffer = ciphertextBuffers.back();
if (istream.fail()) {
return StringError(
"PublicArguments::unserializeArgs: Failed to read argument ")
<< iGate;
}
if (buffer.isTensor()) {
TensorData &td = buffer.getTensor();
preparedArgs.push_back(/*allocated*/ nullptr);
preparedArgs.push_back(td.getValuesAsOpaquePointer());
preparedArgs.push_back(/*offset*/ 0);
// sizes
for (auto size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides has been removed by serialization
auto stride = td.length();
for (auto size : sizes) {
stride /= size;
preparedArgs.push_back((void *)stride);
}
} else {
ScalarData &sd = buffer.getScalar();
preparedArgs.push_back((void *)sd.getValueAsU64());
}
}
OUTCOME_TRY(arguments, unserializeVectorOfScalarOrTensorData(istream));
return outcome::success();
}
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
PublicArguments::unserialize(ClientParameters &clientParameters,
std::istream &istream) {
std::vector<void *> empty;
std::vector<ScalarOrTensorData> emptyBuffers;
auto sArguments = std::make_unique<PublicArguments>(
clientParameters, std::move(empty), std::move(emptyBuffers));
auto sArguments = std::make_unique<PublicArguments>(clientParameters,
std::move(emptyBuffers));
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
return std::move(sArguments);
}
outcome::checked<void, StringError>
PublicResult::unserialize(std::istream &istream) {
for (auto gate : clientParameters.outputs) {
if (!gate.encryption.has_value()) {
return StringError("Clear values are not handled");
}
std::vector<int64_t> sizes = gate.shape.dimensions;
if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) {
sizes.push_back(gate.encryption->encoding.crt.size());
}
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
sizes.push_back(lweSize);
auto sotd = unserializeScalarOrTensorData(sizes, istream);
if (sotd.has_error())
return sotd.error();
buffers.push_back(std::move(sotd.value()));
}
OUTCOME_TRY(buffers, unserializeVectorOfScalarOrTensorData(istream));
return outcome::success();
}
outcome::checked<void, StringError>
PublicResult::serialize(std::ostream &ostream) {
if (incorrectMode(ostream)) {
return StringError(
"PublicResult::serialize: ostream should be in binary mode");
}
for (const ScalarOrTensorData &sotd : buffers) {
serializeScalarOrTensorData(sotd, ostream);
if (ostream.fail()) {
return StringError("Cannot write data");
}
serializeVectorOfScalarOrTensorData(buffers, ostream);
if (ostream.bad()) {
return StringError("PublicResult::serialize: cannot serialize");
}
return outcome::success();
}

View File

@@ -444,10 +444,8 @@ std::ostream &serializeTensorData(const TensorData &values_and_sizes,
assert(false && "Unhandled element type");
}
outcome::checked<TensorData, StringError> unserializeTensorData(
const std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
outcome::checked<TensorData, StringError>
unserializeTensorData(std::istream &istream) {
if (incorrectMode(istream)) {
return StringError("Stream is in incorrect mode");
@@ -461,13 +459,6 @@ outcome::checked<TensorData, StringError> unserializeTensorData(
for (uint64_t i = 0; i < numDimensions; i++) {
int64_t dimSize;
readWord(istream, dimSize);
if (dimSize != expectedSizes[i]) {
istream.setstate(std::ios::badbit);
return StringError("Number of dimensions did not match the number of "
"expected dimensions");
}
dims.push_back(dimSize);
}
@@ -537,8 +528,7 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
}
outcome::checked<ScalarOrTensorData, StringError>
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
std::istream &istream) {
unserializeScalarOrTensorData(std::istream &istream) {
uint8_t isTensor;
readWord(istream, isTensor);
@@ -549,7 +539,7 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
}
if (isTensor) {
auto tdOrErr = unserializeTensorData(expectedSizes, istream);
auto tdOrErr = unserializeTensorData(istream);
if (tdOrErr.has_error())
return std::move(tdOrErr.error());
@@ -565,5 +555,29 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
}
}
std::ostream &
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &v,
std::ostream &ostream) {
writeSize(ostream, v.size());
for (auto &sotd : v) {
serializeScalarOrTensorData(sotd, ostream);
if (!ostream.good()) {
return ostream;
}
}
return ostream;
}
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
uint64_t nbElt;
readSize(istream, nbElt);
std::vector<ScalarOrTensorData> v;
for (uint64_t i = 0; i < nbElt; i++) {
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
v.push_back(std::move(elt));
}
return v;
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -79,8 +79,7 @@ llvm::Error ServerLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
return invokeRawOnLambda(this, args.clientParameters, args.preparedArgs,
evaluationKeys);
return invokeRawOnLambda(this, args, evaluationKeys);
}
} // namespace serverlib

View File

@@ -101,8 +101,7 @@ JITLambda::call(clientlib::PublicArguments &args,
}
#endif
return ::concretelang::invokeRawOnLambda(this, args.clientParameters,
args.preparedArgs, evaluationKeys);
return ::concretelang::invokeRawOnLambda(this, args, evaluationKeys);
}
} // namespace concretelang

View File

@@ -79,12 +79,34 @@ public:
stream << evaluationKeys;
stream.seekg(0, std::ios::beg);
evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream);
stream.str("");
stream.clear();
/* Serialize and unserialize public arguments */
auto serializeRes = publicArguments->serialize(stream);
ASSERT_FALSE(serializeRes.has_error());
stream.seekg(0, std::ios::beg);
auto unserializedArgs =
concretelang::clientlib::PublicArguments::unserialize(clientParameters,
stream);
stream.str("");
stream.clear();
ASSERT_FALSE(unserializedArgs.has_error());
/* Call the server lambda */
auto publicResult =
support.serverCall(serverLambda, *publicArguments, evaluationKeys);
auto publicResult = support.serverCall(
serverLambda, *unserializedArgs.value(), evaluationKeys);
ASSERT_EXPECTED_SUCCESS(publicResult);
/* Serialize and unserialize public result */
serializeRes = (*publicResult)->serialize(stream);
ASSERT_FALSE(serializeRes.has_error());
auto unserializedResult =
concretelang::clientlib::PublicResult::unserialize(clientParameters,
stream);
ASSERT_FALSE(unserializedResult.has_error());
/* Decrypt the public result */
auto result = mlir::concretelang::typedResult<
std::unique_ptr<mlir::concretelang::LambdaArgument>>(*keySet,