refactor(client/server): Rename encrypted_scalars_and_sizes_t o TensorData as it can be used for any kind of tensor

This commit is contained in:
Quentin Bourgerie
2022-03-02 09:17:31 +01:00
parent 5e8d2e7986
commit 82741868f1
12 changed files with 52 additions and 61 deletions

View File

@@ -130,7 +130,7 @@ private:
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
std::vector<TensorData> ciphertextBuffers;
};
} // namespace clientlib

View File

@@ -31,10 +31,10 @@ class PublicArguments {
/// PublicArguments will be sended to the server. It includes encrypted
/// arguments and public keys.
public:
PublicArguments(
const ClientParameters &clientParameters, RuntimeContext runtimeContext,
bool clearRuntimeContext, std::vector<void *> &&preparedArgs,
std::vector<encrypted_scalars_and_sizes_t> &&ciphertextBuffers);
PublicArguments(const ClientParameters &clientParameters,
RuntimeContext runtimeContext, bool clearRuntimeContext,
std::vector<void *> &&preparedArgs,
std::vector<TensorData> &&ciphertextBuffers);
~PublicArguments();
PublicArguments(PublicArguments &other) = delete;
PublicArguments(PublicArguments &&other) = delete;
@@ -53,7 +53,7 @@ private:
RuntimeContext runtimeContext;
std::vector<void *> preparedArgs;
// Store buffers of ciphertexts
std::vector<encrypted_scalars_and_sizes_t> ciphertextBuffers;
std::vector<TensorData> ciphertextBuffers;
// Indicates if this public argument own the runtime keys.
bool clearRuntimeContext;
@@ -64,7 +64,7 @@ struct PublicResult {
/// results.
PublicResult(const ClientParameters &clientParameters,
std::vector<encrypted_scalars_and_sizes_t> buffers = {})
std::vector<TensorData> buffers = {})
: clientParameters(clientParameters), buffers(buffers){};
PublicResult(PublicResult &) = delete;
@@ -72,7 +72,7 @@ struct PublicResult {
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<encrypted_scalars_and_sizes_t> buffers) {
std::vector<TensorData> buffers) {
return std::make_unique<PublicResult>(clientParameters, buffers);
}
@@ -89,7 +89,7 @@ struct PublicResult {
private:
friend class ::concretelang::serverlib::ServerLambda;
ClientParameters clientParameters;
std::vector<encrypted_scalars_and_sizes_t> buffers;
std::vector<TensorData> buffers;
};
} // namespace clientlib

View File

@@ -52,23 +52,18 @@ template <typename Stream> bool incorrectMode(Stream &stream) {
return !binary;
}
std::ostream &operator<<(std::ostream &ostream, const ClientParameters &params);
std::istream &operator>>(std::istream &istream, ClientParameters &params);
std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
encrypted_scalars_t values,
std::ostream &ostream);
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
std::ostream &ostream);
std::ostream &
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
std::ostream &ostream);
std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream);
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
TensorData unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes unsigned to
// accomodate non static sizes
std::istream &istream);

View File

@@ -33,7 +33,7 @@ template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
using encrypted_scalar_t = uint64_t *;
using encrypted_scalars_t = uint64_t *;
struct encrypted_scalars_and_sizes_t {
struct TensorData {
std::vector<uint64_t> values; // tensor of rank r + 1
std::vector<size_t> sizes; // r sizes

View File

@@ -13,11 +13,10 @@
namespace concretelang {
namespace serverlib {
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
using concretelang::clientlib::TensorData;
encrypted_scalars_and_sizes_t
multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank);
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank);
} // namespace serverlib
} // namespace concretelang

View File

@@ -20,12 +20,12 @@ namespace concretelang {
namespace serverlib {
using concretelang::clientlib::encrypted_scalar_t;
using concretelang::clientlib::encrypted_scalars_and_sizes_t;
using concretelang::clientlib::encrypted_scalars_t;
using concretelang::clientlib::TensorData;
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
size_t rank, encrypted_scalars_t allocated, encrypted_scalars_t aligned,
size_t offset, size_t *sizes, size_t *strides);
TensorData TensorData_from_MemRef(size_t rank, encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides);
/// ServerLambda is a utility class that allows to call a function of a
/// compilation result.

View File

@@ -45,7 +45,7 @@ EncryptedArguments::pushArg(uint64_t arg, std::shared_ptr<KeySet> keySet) {
return outcome::success();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
TensorData &values_and_sizes = ciphertextBuffers.back();
auto lweSize = keySet->getInputLweSecretKeyParam(pos).lweSize();
values_and_sizes.sizes.push_back(lweSize);
values_and_sizes.values.resize(lweSize);
@@ -100,7 +100,7 @@ EncryptedArguments::pushArg(size_t width, void *data,
<< shape.size() << " expected " << input.shape.dimensions.size();
}
ciphertextBuffers.resize(ciphertextBuffers.size() + 1); // Allocate empty
encrypted_scalars_and_sizes_t &values_and_sizes = ciphertextBuffers.back();
TensorData &values_and_sizes = ciphertextBuffers.back();
for (size_t i = 0; i < shape.size(); i++) {
values_and_sizes.sizes.push_back(shape[i]);
if (shape[i] != input.shape.dimensions[i]) {

View File

@@ -78,7 +78,7 @@ PublicArguments::serialize(std::ostream &ostream) {
}
// TODO: STRIDES
auto values = aligned + offset;
serializeEncryptedValues(sizes, values, ostream);
serializeTensorData(sizes, values, ostream);
}
return outcome::success();
}
@@ -94,7 +94,7 @@ PublicArguments::unserializeArgs(std::istream &istream) {
auto lweSize = clientParameters.lweSecretKeyParam(gate).lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
ciphertextBuffers.push_back(unserializeEncryptedValues(sizes, istream));
ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
auto &values_and_sizes = ciphertextBuffers.back();
if (istream.fail()) {
return StringError(
@@ -127,7 +127,7 @@ PublicArguments::unserialize(ClientParameters &clientParameters,
return StringError("Cannot read runtime context");
}
std::vector<void *> empty;
std::vector<encrypted_scalars_and_sizes_t> emptyBuffers;
std::vector<TensorData> emptyBuffers;
auto sArguments = std::make_shared<PublicArguments>(
clientParameters, runtimeContext, true, std::move(empty),
std::move(emptyBuffers));

View File

@@ -81,8 +81,8 @@ std::ostream &operator<<(std::ostream &ostream,
return ostream;
}
std::ostream &serializeEncryptedValues(encrypted_scalars_t values,
size_t length, std::ostream &ostream) {
std::ostream &serializeTensorData(uint64_t *values, size_t length,
std::ostream &ostream) {
if (incorrectMode(ostream)) {
return ostream;
}
@@ -93,32 +93,30 @@ std::ostream &serializeEncryptedValues(encrypted_scalars_t values,
return ostream;
}
std::ostream &serializeEncryptedValues(std::vector<size_t> &sizes,
encrypted_scalars_t values,
std::ostream &ostream) {
std::ostream &serializeTensorData(std::vector<size_t> &sizes, uint64_t *values,
std::ostream &ostream) {
size_t length = 1;
for (auto size : sizes) {
length *= size;
writeSize(ostream, size);
}
serializeEncryptedValues(values, length, ostream);
serializeTensorData(values, length, ostream);
assert(ostream.good());
return ostream;
}
std::ostream &
serializeEncryptedValues(encrypted_scalars_and_sizes_t &values_and_sizes,
std::ostream &ostream) {
std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream) {
std::vector<size_t> &sizes = values_and_sizes.sizes;
encrypted_scalars_t values = values_and_sizes.values.data();
return serializeEncryptedValues(sizes, values, ostream);
return serializeTensorData(sizes, values, ostream);
}
encrypted_scalars_and_sizes_t unserializeEncryptedValues(
TensorData unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
encrypted_scalars_and_sizes_t result;
TensorData result;
if (incorrectMode(istream)) {
return result;
}

View File

@@ -15,11 +15,11 @@
namespace concretelang {
namespace serverlib {
encrypted_scalars_and_sizes_t
multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank) {
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args,
size_t rank) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
constexpr auto convert = &TensorData_from_MemRef;
switch (rank) {
case 0: {
auto m = multi_arity_call((MemRefDescriptor<1>(*)(void *...))func, args);

View File

@@ -48,12 +48,12 @@ size_t global_index(size_t index[], size_t sizes[], size_t strides[],
}
/** Helper function to convert from MemRefDescriptor to
* encrypted_scalars_and_sizes_t assuming MemRefDescriptor are bufferized */
encrypted_scalars_and_sizes_t encrypted_scalars_and_sizes_t_from_MemRef(
size_t memref_rank, encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset, size_t *sizes,
size_t *strides) {
encrypted_scalars_and_sizes_t result;
* TensorData assuming MemRefDescriptor are bufferized */
TensorData TensorData_from_MemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides) {
TensorData result;
assert(aligned != nullptr);
result.sizes.resize(memref_rank);
for (size_t r = 0; r < memref_rank; r++) {
@@ -125,9 +125,8 @@ ServerLambda::load(std::string funcName, std::string outputLib) {
return ServerLambda::loadFromModule(module, funcName);
}
encrypted_scalars_and_sizes_t dynamicCall(void *(*func)(void *...),
std::vector<void *> &preparedArgs,
CircuitGate &output) {
TensorData dynamicCall(void *(*func)(void *...),
std::vector<void *> &preparedArgs, CircuitGate &output) {
size_t rank = output.shape.dimensions.size();
return multi_arity_call_dynamic_rank(func, preparedArgs, rank);
}

View File

@@ -19,10 +19,10 @@ print(
namespace concretelang {
namespace serverlib {
encrypted_scalars_and_sizes_t
TensorData
multi_arity_call_dynamic_rank(void* (*func)(void *...), std::vector<void *> args, size_t rank) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = &encrypted_scalars_and_sizes_t_from_MemRef;
constexpr auto convert = &TensorData_from_MemRef;
switch (rank) {""")
for tensor_rank in range(0, 33):