mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 11:35:02 -05:00
refactor(compiler/clientlib): Remove the building of the calling convention in the EncryptedArguments and test serialization in end_to_end_tests
This commit is contained in:
@@ -179,26 +179,7 @@ public:
|
||||
td.bulkAssign(values);
|
||||
ciphertextBuffers.push_back(std::move(td));
|
||||
}
|
||||
TensorData &td = ciphertextBuffers.back().getTensor();
|
||||
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(td.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : td.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = td.getNumElements();
|
||||
for (size_t size : td.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
currentPos++;
|
||||
return outcome::success();
|
||||
}
|
||||
@@ -232,7 +213,6 @@ private:
|
||||
private:
|
||||
/// Position of the next pushed argument
|
||||
size_t currentPos;
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> ciphertextBuffers;
|
||||
|
||||
@@ -37,7 +37,6 @@ class EncryptedArguments;
|
||||
class PublicArguments {
|
||||
public:
|
||||
PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<void *> &&preparedArgs,
|
||||
std::vector<ScalarOrTensorData> &&ciphertextBuffers);
|
||||
~PublicArguments();
|
||||
PublicArguments(PublicArguments &other) = delete;
|
||||
@@ -48,16 +47,18 @@ public:
|
||||
|
||||
outcome::checked<void, StringError> serialize(std::ostream &ostream);
|
||||
|
||||
private:
|
||||
std::vector<ScalarOrTensorData> &getArguments() { return arguments; }
|
||||
ClientParameters &getClientParameters() { return clientParameters; }
|
||||
|
||||
friend class ::concretelang::serverlib::ServerLambda;
|
||||
friend class ::mlir::concretelang::JITLambda;
|
||||
|
||||
private:
|
||||
outcome::checked<void, StringError> unserializeArgs(std::istream &istream);
|
||||
|
||||
ClientParameters clientParameters;
|
||||
std::vector<void *> preparedArgs;
|
||||
/// Store buffers of ciphertexts
|
||||
std::vector<ScalarOrTensorData> ciphertextBuffers;
|
||||
std::vector<ScalarOrTensorData> arguments;
|
||||
};
|
||||
|
||||
/// PublicResult is a result of a ServerLambda call which contains encrypted
|
||||
|
||||
@@ -87,17 +87,20 @@ std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
|
||||
return ostream;
|
||||
}
|
||||
|
||||
outcome::checked<TensorData, StringError> unserializeTensorData(
|
||||
std::vector<int64_t> &expectedSizes, // includes unsigned to
|
||||
// accomodate non static sizes
|
||||
std::istream &istream);
|
||||
outcome::checked<TensorData, StringError>
|
||||
unserializeTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
std::ostream &ostream);
|
||||
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
std::istream &istream);
|
||||
unserializeScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &
|
||||
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &sotd,
|
||||
std::ostream &ostream);
|
||||
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream);
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream, const LweSecretKey &wrappedKsk);
|
||||
LweSecretKey readLweSecretKey(std::istream &istream);
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define CONCRETELANG_CLIENTLIB_TYPES_H_
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <stddef.h>
|
||||
@@ -818,6 +819,8 @@ public:
|
||||
// Retrieves the value as a generic `uint64_t`
|
||||
uint64_t getValueAsU64() const {
|
||||
size_t width = getElementTypeWidth(type);
|
||||
if (width == 64)
|
||||
return value.u64;
|
||||
uint64_t mask = ((uint64_t)1 << width) - 1;
|
||||
uint64_t val = value.u64 & mask;
|
||||
return val;
|
||||
|
||||
@@ -114,6 +114,42 @@ invokeRawOnLambda(Lambda *lambda, clientlib::ClientParameters clientParameters,
|
||||
std::move(buffers));
|
||||
}
|
||||
|
||||
template <typename Lambda>
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
invokeRawOnLambda(Lambda *lambda, clientlib::PublicArguments &arguments,
|
||||
clientlib::EvaluationKeys &evaluationKeys) {
|
||||
// Prepare arguments with the right calling convention
|
||||
std::vector<void *> preparedArgs;
|
||||
for (auto &arg : arguments.getArguments()) {
|
||||
if (arg.isScalar()) {
|
||||
auto scalar = arg.getScalar().getValueAsU64();
|
||||
preparedArgs.push_back((void *)scalar);
|
||||
} else {
|
||||
clientlib::TensorData &td = arg.getTensor();
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(td.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (size_t size : td.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = td.getNumElements();
|
||||
for (size_t size : td.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
}
|
||||
}
|
||||
return invokeRawOnLambda(lambda, arguments.getClientParameters(),
|
||||
preparedArgs, evaluationKeys);
|
||||
}
|
||||
|
||||
template <typename V, unsigned int N>
|
||||
llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
|
||||
const llvm::SmallVector<V, N> vect) {
|
||||
|
||||
@@ -13,8 +13,8 @@ using StringError = concretelang::error::StringError;
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
EncryptedArguments::exportPublicArguments(ClientParameters clientParameters) {
|
||||
return std::make_unique<PublicArguments>(
|
||||
clientParameters, std::move(preparedArgs), std::move(ciphertextBuffers));
|
||||
return std::make_unique<PublicArguments>(clientParameters,
|
||||
std::move(ciphertextBuffers));
|
||||
}
|
||||
|
||||
/// Split the input integer into `size` chunks of `chunkWidth` bits each
|
||||
@@ -49,7 +49,7 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
}
|
||||
if (!input.encryption.has_value()) {
|
||||
// clear scalar: just push the argument
|
||||
preparedArgs.push_back((void *)arg);
|
||||
ciphertextBuffers.push_back(ScalarData(arg));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -63,24 +63,6 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(
|
||||
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (auto size : values_and_sizes.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
// strides
|
||||
int64_t stride = TensorData::getNumElements(shape);
|
||||
for (size_t size : values_and_sizes.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -15,13 +15,10 @@ namespace clientlib {
|
||||
using concretelang::error::StringError;
|
||||
|
||||
// TODO: optimize the move
|
||||
PublicArguments::PublicArguments(
|
||||
const ClientParameters &clientParameters,
|
||||
std::vector<void *> &&preparedArgs_,
|
||||
std::vector<ScalarOrTensorData> &&ciphertextBuffers_)
|
||||
PublicArguments::PublicArguments(const ClientParameters &clientParameters,
|
||||
std::vector<ScalarOrTensorData> &&arguments_)
|
||||
: clientParameters(clientParameters) {
|
||||
preparedArgs = std::move(preparedArgs_);
|
||||
ciphertextBuffers = std::move(ciphertextBuffers_);
|
||||
arguments = std::move(arguments_);
|
||||
}
|
||||
|
||||
PublicArguments::~PublicArguments() {}
|
||||
@@ -32,146 +29,41 @@ PublicArguments::serialize(std::ostream &ostream) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: ostream should be in binary mode");
|
||||
}
|
||||
size_t iPreparedArgs = 0;
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
size_t rank = gate.shape.dimensions.size();
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("PublicArguments::serialize: Clear arguments "
|
||||
"are not yet supported. Argument ")
|
||||
<< iGate;
|
||||
}
|
||||
|
||||
/*auto allocated = */ iPreparedArgs++;
|
||||
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
|
||||
assert(aligned != nullptr);
|
||||
auto offset = (size_t)preparedArgs[iPreparedArgs++];
|
||||
std::vector<size_t> sizes; // includes lweSize as last dim
|
||||
sizes.resize(rank + (gate.encryption->encoding.crt.empty() ? 1 : 2));
|
||||
for (auto dim = 0u; dim < sizes.size(); dim++) {
|
||||
// sizes are part of the client parameters signature
|
||||
// it's static now but some day it could be dynamic so we serialize
|
||||
// them.
|
||||
sizes[dim] = (size_t)preparedArgs[iPreparedArgs++];
|
||||
}
|
||||
std::vector<size_t> strides(rank + 1);
|
||||
/* strides should be zero here and are not serialized */
|
||||
for (auto dim = 0u; dim < strides.size(); dim++) {
|
||||
strides[dim] = (size_t)preparedArgs[iPreparedArgs++];
|
||||
}
|
||||
// TODO: STRIDES
|
||||
auto values = aligned + offset;
|
||||
|
||||
writeWord<uint8_t>(ostream, 1);
|
||||
serializeTensorDataRaw(sizes,
|
||||
llvm::ArrayRef<clientlib::EncryptedScalarElement>{
|
||||
values, TensorData::getNumElements(sizes)},
|
||||
ostream);
|
||||
serializeVectorOfScalarOrTensorData(arguments, ostream);
|
||||
if (ostream.bad()) {
|
||||
return StringError(
|
||||
"PublicArguments::serialize: cannot serialize public arguments");
|
||||
}
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicArguments::unserializeArgs(std::istream &istream) {
|
||||
int iGate = -1;
|
||||
for (auto gate : clientParameters.inputs) {
|
||||
iGate++;
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
|
||||
std::vector<int64_t> sizes = gate.shape.dimensions;
|
||||
if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) {
|
||||
sizes.push_back(gate.encryption->encoding.crt.size());
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
sizes.push_back(lweSize);
|
||||
|
||||
auto sotdOrErr = unserializeScalarOrTensorData(sizes, istream);
|
||||
|
||||
if (sotdOrErr.has_error())
|
||||
return sotdOrErr.error();
|
||||
|
||||
ciphertextBuffers.push_back(std::move(sotdOrErr.value()));
|
||||
auto &buffer = ciphertextBuffers.back();
|
||||
|
||||
if (istream.fail()) {
|
||||
return StringError(
|
||||
"PublicArguments::unserializeArgs: Failed to read argument ")
|
||||
<< iGate;
|
||||
}
|
||||
|
||||
if (buffer.isTensor()) {
|
||||
TensorData &td = buffer.getTensor();
|
||||
preparedArgs.push_back(/*allocated*/ nullptr);
|
||||
preparedArgs.push_back(td.getValuesAsOpaquePointer());
|
||||
preparedArgs.push_back(/*offset*/ 0);
|
||||
// sizes
|
||||
for (auto size : td.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
// strides has been removed by serialization
|
||||
auto stride = td.length();
|
||||
for (auto size : sizes) {
|
||||
stride /= size;
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
} else {
|
||||
ScalarData &sd = buffer.getScalar();
|
||||
preparedArgs.push_back((void *)sd.getValueAsU64());
|
||||
}
|
||||
}
|
||||
OUTCOME_TRY(arguments, unserializeVectorOfScalarOrTensorData(istream));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<std::unique_ptr<PublicArguments>, StringError>
|
||||
PublicArguments::unserialize(ClientParameters &clientParameters,
|
||||
std::istream &istream) {
|
||||
std::vector<void *> empty;
|
||||
std::vector<ScalarOrTensorData> emptyBuffers;
|
||||
auto sArguments = std::make_unique<PublicArguments>(
|
||||
clientParameters, std::move(empty), std::move(emptyBuffers));
|
||||
auto sArguments = std::make_unique<PublicArguments>(clientParameters,
|
||||
std::move(emptyBuffers));
|
||||
OUTCOME_TRYV(sArguments->unserializeArgs(istream));
|
||||
return std::move(sArguments);
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicResult::unserialize(std::istream &istream) {
|
||||
for (auto gate : clientParameters.outputs) {
|
||||
if (!gate.encryption.has_value()) {
|
||||
return StringError("Clear values are not handled");
|
||||
}
|
||||
|
||||
std::vector<int64_t> sizes = gate.shape.dimensions;
|
||||
if (gate.encryption.has_value() && !gate.encryption->encoding.crt.empty()) {
|
||||
sizes.push_back(gate.encryption->encoding.crt.size());
|
||||
}
|
||||
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
|
||||
sizes.push_back(lweSize);
|
||||
|
||||
auto sotd = unserializeScalarOrTensorData(sizes, istream);
|
||||
|
||||
if (sotd.has_error())
|
||||
return sotd.error();
|
||||
|
||||
buffers.push_back(std::move(sotd.value()));
|
||||
}
|
||||
OUTCOME_TRY(buffers, unserializeVectorOfScalarOrTensorData(istream));
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
outcome::checked<void, StringError>
|
||||
PublicResult::serialize(std::ostream &ostream) {
|
||||
if (incorrectMode(ostream)) {
|
||||
return StringError(
|
||||
"PublicResult::serialize: ostream should be in binary mode");
|
||||
}
|
||||
for (const ScalarOrTensorData &sotd : buffers) {
|
||||
serializeScalarOrTensorData(sotd, ostream);
|
||||
if (ostream.fail()) {
|
||||
return StringError("Cannot write data");
|
||||
}
|
||||
serializeVectorOfScalarOrTensorData(buffers, ostream);
|
||||
if (ostream.bad()) {
|
||||
return StringError("PublicResult::serialize: cannot serialize");
|
||||
}
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
@@ -444,10 +444,8 @@ std::ostream &serializeTensorData(const TensorData &values_and_sizes,
|
||||
assert(false && "Unhandled element type");
|
||||
}
|
||||
|
||||
outcome::checked<TensorData, StringError> unserializeTensorData(
|
||||
const std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
|
||||
// accomodate non static sizes
|
||||
std::istream &istream) {
|
||||
outcome::checked<TensorData, StringError>
|
||||
unserializeTensorData(std::istream &istream) {
|
||||
|
||||
if (incorrectMode(istream)) {
|
||||
return StringError("Stream is in incorrect mode");
|
||||
@@ -461,13 +459,6 @@ outcome::checked<TensorData, StringError> unserializeTensorData(
|
||||
for (uint64_t i = 0; i < numDimensions; i++) {
|
||||
int64_t dimSize;
|
||||
readWord(istream, dimSize);
|
||||
|
||||
if (dimSize != expectedSizes[i]) {
|
||||
istream.setstate(std::ios::badbit);
|
||||
return StringError("Number of dimensions did not match the number of "
|
||||
"expected dimensions");
|
||||
}
|
||||
|
||||
dims.push_back(dimSize);
|
||||
}
|
||||
|
||||
@@ -537,8 +528,7 @@ std::ostream &serializeScalarOrTensorData(const ScalarOrTensorData &sotd,
|
||||
}
|
||||
|
||||
outcome::checked<ScalarOrTensorData, StringError>
|
||||
unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
std::istream &istream) {
|
||||
unserializeScalarOrTensorData(std::istream &istream) {
|
||||
uint8_t isTensor;
|
||||
readWord(istream, isTensor);
|
||||
|
||||
@@ -549,7 +539,7 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
}
|
||||
|
||||
if (isTensor) {
|
||||
auto tdOrErr = unserializeTensorData(expectedSizes, istream);
|
||||
auto tdOrErr = unserializeTensorData(istream);
|
||||
|
||||
if (tdOrErr.has_error())
|
||||
return std::move(tdOrErr.error());
|
||||
@@ -565,5 +555,29 @@ unserializeScalarOrTensorData(const std::vector<int64_t> &expectedSizes,
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream &
|
||||
serializeVectorOfScalarOrTensorData(const std::vector<ScalarOrTensorData> &v,
|
||||
std::ostream &ostream) {
|
||||
writeSize(ostream, v.size());
|
||||
for (auto &sotd : v) {
|
||||
serializeScalarOrTensorData(sotd, ostream);
|
||||
if (!ostream.good()) {
|
||||
return ostream;
|
||||
}
|
||||
}
|
||||
return ostream;
|
||||
}
|
||||
outcome::checked<std::vector<ScalarOrTensorData>, StringError>
|
||||
unserializeVectorOfScalarOrTensorData(std::istream &istream) {
|
||||
uint64_t nbElt;
|
||||
readSize(istream, nbElt);
|
||||
std::vector<ScalarOrTensorData> v;
|
||||
for (uint64_t i = 0; i < nbElt; i++) {
|
||||
OUTCOME_TRY(auto elt, unserializeScalarOrTensorData(istream));
|
||||
v.push_back(std::move(elt));
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
} // namespace clientlib
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -79,8 +79,7 @@ llvm::Error ServerLambda::invokeRaw(llvm::MutableArrayRef<void *> args) {
|
||||
|
||||
llvm::Expected<std::unique_ptr<clientlib::PublicResult>>
|
||||
ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
|
||||
return invokeRawOnLambda(this, args.clientParameters, args.preparedArgs,
|
||||
evaluationKeys);
|
||||
return invokeRawOnLambda(this, args, evaluationKeys);
|
||||
}
|
||||
|
||||
} // namespace serverlib
|
||||
|
||||
@@ -101,8 +101,7 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
}
|
||||
#endif
|
||||
|
||||
return ::concretelang::invokeRawOnLambda(this, args.clientParameters,
|
||||
args.preparedArgs, evaluationKeys);
|
||||
return ::concretelang::invokeRawOnLambda(this, args, evaluationKeys);
|
||||
}
|
||||
|
||||
} // namespace concretelang
|
||||
|
||||
@@ -79,12 +79,34 @@ public:
|
||||
stream << evaluationKeys;
|
||||
stream.seekg(0, std::ios::beg);
|
||||
evaluationKeys = concretelang::clientlib::readEvaluationKeys(stream);
|
||||
stream.str("");
|
||||
stream.clear();
|
||||
|
||||
/* Serialize and unserialize public arguments */
|
||||
auto serializeRes = publicArguments->serialize(stream);
|
||||
ASSERT_FALSE(serializeRes.has_error());
|
||||
stream.seekg(0, std::ios::beg);
|
||||
auto unserializedArgs =
|
||||
concretelang::clientlib::PublicArguments::unserialize(clientParameters,
|
||||
stream);
|
||||
stream.str("");
|
||||
stream.clear();
|
||||
ASSERT_FALSE(unserializedArgs.has_error());
|
||||
|
||||
/* Call the server lambda */
|
||||
auto publicResult =
|
||||
support.serverCall(serverLambda, *publicArguments, evaluationKeys);
|
||||
auto publicResult = support.serverCall(
|
||||
serverLambda, *unserializedArgs.value(), evaluationKeys);
|
||||
ASSERT_EXPECTED_SUCCESS(publicResult);
|
||||
|
||||
/* Serialize and unserialize public result */
|
||||
serializeRes = (*publicResult)->serialize(stream);
|
||||
ASSERT_FALSE(serializeRes.has_error());
|
||||
|
||||
auto unserializedResult =
|
||||
concretelang::clientlib::PublicResult::unserialize(clientParameters,
|
||||
stream);
|
||||
ASSERT_FALSE(unserializedResult.has_error());
|
||||
|
||||
/* Decrypt the public result */
|
||||
auto result = mlir::concretelang::typedResult<
|
||||
std::unique_ptr<mlir::concretelang::LambdaArgument>>(*keySet,
|
||||
|
||||
Reference in New Issue
Block a user