Mirror of https://github.com/zama-ai/concrete.git (synced 2026-02-09 03:55:04 -05:00)
fix(compiler): Add support for clear result tensors with element width != 64 bits
Returning tensors with elements whose width is not equal to 64 bits results in garbled data. This commit extends the `TensorData` class used to represent tensors in JIT compilation with support for signed / unsigned elements of 8, 16, 32 and 64 bits, so that all clear-text tensors with element widths of up to 64 bits can be represented accurately.
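As an illustration only (not part of this commit), a minimal sketch of how the widened `TensorData` API introduced in the diff below might be exercised for a clear tensor with 8-bit elements; the constructor, `bulkAssign` and `asFlatVector` names follow the declarations in the diff, and the include path is an assumption:

```cpp
#include <cstdint>
#include <vector>

#include "llvm/ADT/ArrayRef.h"
#include "concretelang/ClientLib/Types.h" // assumed header location

using concretelang::clientlib::TensorData;

int main() {
  // A 2x3 clear tensor whose elements are unsigned 8-bit values.
  TensorData td(llvm::ArrayRef<size_t>{2, 3}, /*elementWidth=*/8, /*sign=*/false);

  // Overwrite the zero-initialized elements from a flat buffer.
  std::vector<uint8_t> src{1, 2, 3, 4, 5, 6};
  td.bulkAssign(llvm::ArrayRef<uint8_t>(src));

  // Read the tensor back, widening each element to 64 bits on the fly.
  std::vector<uint64_t> flat = td.asFlatVector<uint64_t>();
  return flat[5] == 6 ? 0 : 1;
}
```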
@@ -147,9 +147,6 @@ public:
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
TensorData &values_and_sizes = ciphertextBuffers.back();

// Check shape
for (size_t i = 0; i < shape.size(); i++) {
@@ -159,53 +156,49 @@ public:
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}

// Set sizes
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);

if (input.encryption.hasValue()) {
// Allocate values
values_and_sizes.values.resize(
keySet.clientParameters().bufferSize(input));
TensorData td(sizes, ElementType::u64);

auto lweSize = keySet.clientParameters().lweBufferSize(input);
auto &values = values_and_sizes.values;

for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i]));
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, td.getElementPointer<uint64_t>(offset), data[i]));
}
ciphertextBuffers.push_back(std::move(td));
} else {
// Allocate values take care of gate bitwidth
auto bitsPerValue = bitWidthAsWord(input.shape.width);
auto bytesPerValue = bitsPerValue / 8;
auto nbWordPerValue = 8 / bytesPerValue;
// ceil division
auto size = (input.shape.size / nbWordPerValue) +
(input.shape.size % nbWordPerValue != 0);
size = size == 0 ? 1 : size;
values_and_sizes.values.resize(size);
auto v = (uint8_t *)values_and_sizes.values.data();
for (size_t i = 0; i < input.shape.size; i++) {
auto dst = v + i * bytesPerValue;
auto src = (const uint8_t *)&data[i];
for (size_t j = 0; j < bytesPerValue; j++) {
dst[j] = src[j];
}
}

// FIXME: This always requests a tensor with unsigned elements,
// as the signedness is not captured in the description of the
// circuit
TensorData td(sizes, bitsPerValue, false);
llvm::ArrayRef<T> values(data, TensorData::getNumElements(sizes));
td.bulkAssign(values);
ciphertextBuffers.push_back(std::move(td));
}
TensorData &td = ciphertextBuffers.back();

// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back(td.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : values_and_sizes.sizes) {
for (size_t size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}

// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = values_and_sizes.length();
for (size_t size : values_and_sizes.sizes) {
int64_t stride = td.getNumElements();
for (size_t size : td.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
@@ -66,16 +66,16 @@ private:
struct PublicResult {

PublicResult(const ClientParameters &clientParameters,
std::vector<TensorData> buffers = {})
: clientParameters(clientParameters), buffers(buffers){};
std::vector<TensorData> &&buffers = {})
: clientParameters(clientParameters), buffers(std::move(buffers)){};

PublicResult(PublicResult &) = delete;

/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<TensorData> buffers) {
return std::make_unique<PublicResult>(clientParameters, buffers);
std::vector<TensorData> &&buffers) {
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
}

/// Unserialize from an input stream inplace.
@@ -99,21 +99,22 @@ struct PublicResult {
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted()) {
std::vector<T> result;
result.reserve(buffers[pos].values.size());
std::copy(buffers[pos].values.begin(), buffers[pos].values.end(),
std::back_inserter(result));
return result;
}
if (!gate.isEncrypted())
return buffers[pos].asFlatVector<T>();

auto buffer = buffers[pos];
auto &buffer = buffers[pos];
auto lweSize = clientParameters.lweBufferSize(gate);

std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
uint64_t decrypted;
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted));

// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
decryptedValues[i] = decrypted;
}
return decryptedValues;
@@ -137,10 +138,9 @@ TensorData tensorDataFromScalar(uint64_t value);

/// Helper function to convert from MemRefDescriptor to
/// TensorData
TensorData tensorDataFromMemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides);
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides);

} // namespace clientlib
} // namespace concretelang
@@ -7,6 +7,7 @@
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H

#include <iostream>
#include <limits>

#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
@@ -40,6 +41,14 @@ std::istream &readWord(std::istream &istream, Word &word) {
return istream;
}

template <typename Word>
std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
assert(istream.good());
return istream;
}

template <typename Size>
std::istream &readSize(std::istream &istream, Size &size) {
return readWord(istream, size);
@@ -57,13 +66,29 @@ std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);

std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream);

std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream);
template <typename T>
std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
const llvm::ArrayRef<T> &values,
std::ostream &ostream) {

TensorData unserializeTensorData(
writeWord<uint64_t>(ostream, dimensions.size());

for (size_t dim : dimensions)
writeWord<int64_t>(ostream, dim);

writeWord<uint64_t>(ostream, sizeof(T) * 8);
writeWord<uint8_t>(ostream, std::is_signed<T>());

for (T val : values)
writeWord(ostream, val);

return ostream;
}

outcome::checked<TensorData, StringError> unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes unsigned to
// accomodate non static sizes
std::istream &istream);
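For orientation (again not from the commit), a rough round-trip sketch through the (de)serialization entry points declared above. It assumes both functions live in `concretelang::clientlib`, that a binary `std::stringstream` satisfies the serializers' mode check, and that the `TensorData` value comes from somewhere like the earlier sketch; `serializeTensorDataRaw` writes the rank, one size per dimension, the element width in bits, a signedness flag, then the elements themselves:

```cpp
#include <sstream>
#include <vector>

#include "concretelang/ClientLib/Serializers.h" // assumed header location
#include "concretelang/ClientLib/Types.h"

using namespace concretelang::clientlib;

// Serialize a tensor and read it back, checking only the element count.
bool roundTrip(const TensorData &td, std::vector<int64_t> expectedSizes) {
  std::stringstream buf(std::ios::in | std::ios::out | std::ios::binary);
  serializeTensorData(td, buf);
  auto back = unserializeTensorData(expectedSizes, buf);
  return !back.has_error() &&
         back.value().getNumElements() == td.getNumElements();
}
```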
@@ -6,6 +6,8 @@
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
#define CONCRETELANG_CLIENTLIB_TYPES_H_

#include "llvm/ADT/ArrayRef.h"

#include <cstdint>
#include <stddef.h>
#include <vector>
@@ -30,24 +32,630 @@ template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
using encrypted_scalar_t = uint64_t *;
using encrypted_scalars_t = uint64_t *;

struct TensorData {
std::vector<uint64_t> values; // tensor of rank r + 1
std::vector<int64_t> sizes; // r sizes
// Element types for `TensorData`
enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };

inline size_t length() {
if (sizes.empty()) {
return 0;
}
size_t len = 1;
for (auto size : sizes) {
len *= size;
}
return len;
namespace {
// Returns the number of bits for an element type
static constexpr size_t getElementTypeWidth(ElementType t) {
switch (t) {
case ElementType::u64:
case ElementType::i64:
return 64;
case ElementType::u32:
case ElementType::i32:
return 32;
case ElementType::u16:
case ElementType::i16:
return 16;
case ElementType::u8:
case ElementType::i8:
return 8;
}
}
} // namespace

// Constants for the element types used for tensors representing
// encrypted data and data after decryption
constexpr ElementType EncryptedScalarElementType = ElementType::u64;
constexpr size_t EncryptedScalarElementWidth =
getElementTypeWidth(ElementType::u64);

using EncryptedScalarElement = uint64_t;

namespace detail {
namespace TensorData {

// Union used to store the pointer to the actual data of an instance
// of `TensorData`. Values are stored contiguously in memory in a
// `std::vector` whose element type corresponds to the element type of
// the tensor.
union value_vector_union {
std::vector<uint64_t> *u64;
std::vector<int64_t> *i64;
std::vector<uint32_t> *u32;
std::vector<int32_t> *i32;
std::vector<uint16_t> *u16;
std::vector<int16_t> *i16;
std::vector<uint8_t> *u8;
std::vector<int8_t> *i8;
};

// Function templates that would go into the class `TensorData`, but
// which need to declared in namespace scope, since specializations of
// templates on the return type cannot be done for member functions as
// per the C++ standard
template <typename T> T begin(union value_vector_union &vec);
template <typename T> T end(union value_vector_union &vec);
template <typename T> T cbegin(union value_vector_union &vec);
template <typename T> T cend(union value_vector_union &vec);
template <typename T> T getElements(union value_vector_union &vec);
template <typename T> T getConstElements(const union value_vector_union &vec);

template <typename T>
T getElementValue(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T &getElementReference(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T *getElementPointer(union value_vector_union &vec, size_t idx,
ElementType elementType);

// Specializations for the above templates
#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX) \
template <> \
inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) { \
return vec.SUFFIX->begin(); \
} \
\
template <> \
inline std::vector<ELTY>::iterator end(union value_vector_union &vec) { \
return vec.SUFFIX->end(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cbegin( \
union value_vector_union &vec) { \
return vec.SUFFIX->cbegin(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cend( \
union value_vector_union &vec) { \
return vec.SUFFIX->cend(); \
} \
\
template <> \
inline std::vector<ELTY> &getElements(union value_vector_union &vec) { \
return *vec.SUFFIX; \
} \
\
template <> \
inline const std::vector<ELTY> &getConstElements( \
const union value_vector_union &vec) { \
return *vec.SUFFIX; \
}

inline size_t lweSize() { return sizes.back(); }
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)

#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
template <> \
inline ELTY getElementValue(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return &(*vec.SUFFIX)[idx]; \
}

TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)

} // namespace TensorData
} // namespace detail

// Representation of a tensor with an arbitrary number of dimensions
class TensorData {
protected:
detail::TensorData::value_vector_union values;
ElementType elementType;
std::vector<size_t> dimensions;

/* Multi-dimensional, uninitialized, but preallocated tensor */
void initPreallocated(llvm::ArrayRef<size_t> dimensions,
ElementType elementType) {
assert(dimensions.size() != 0);
this->dimensions.resize(dimensions.size());

size_t n = getNumElements(dimensions);

switch (elementType) {
case ElementType::u64:
this->values.u64 = new std::vector<uint64_t>(n);
break;
case ElementType::i64:
this->values.i64 = new std::vector<int64_t>(n);
break;
case ElementType::u32:
this->values.u32 = new std::vector<uint32_t>(n);
break;
case ElementType::i32:
this->values.i32 = new std::vector<int32_t>(n);
break;
case ElementType::u16:
this->values.u16 = new std::vector<uint16_t>(n);
break;
case ElementType::i16:
this->values.i16 = new std::vector<int16_t>(n);
break;
case ElementType::u8:
this->values.u8 = new std::vector<uint8_t>(n);
break;
case ElementType::i8:
this->values.i8 = new std::vector<int8_t>(n);
break;
}
this->elementType = elementType;
std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
}

// Creates a vector<size_t> from an ArrayRef<T>
template <typename T>
static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
return std::vector<size_t>(dims.begin(), dims.end());
}

public:
// Returns the total number of elements of a tensor with the
// specified dimensions
template <typename T> static size_t getNumElements(T dimensions) {
size_t n = 1;
for (auto dim : dimensions)
n *= dim;

return n;
}

// Returns the number of bits of an integer capable of storing
// values with up to `elementWidth` bits.
static size_t storageWidth(size_t elementWidth) {
if (elementWidth > 64) {
assert(false && "Maximum supported element width is 64");
} else if (elementWidth > 32) {
return 64;
} else if (elementWidth > 16) {
return 32;
} else if (elementWidth > 8) {
return 16;
} else {
return 8;
}
}

// Move constructor. Leaves `that` uninitialized.
TensorData(TensorData &&that)
: elementType(that.elementType), dimensions(std::move(that.dimensions)) {
switch (that.elementType) {
case ElementType::u64:
this->values.u64 = that.values.u64;
that.values.u64 = nullptr;
break;
case ElementType::i64:
this->values.i64 = that.values.i64;
that.values.i64 = nullptr;
break;
case ElementType::u32:
this->values.u32 = that.values.u32;
that.values.u32 = nullptr;
break;
case ElementType::i32:
this->values.i32 = that.values.i32;
that.values.i32 = nullptr;
break;
case ElementType::u16:
this->values.u16 = that.values.u16;
that.values.u16 = nullptr;
break;
case ElementType::i16:
this->values.i16 = that.values.i16;
that.values.i16 = nullptr;
break;
case ElementType::u8:
this->values.u8 = that.values.u8;
that.values.u8 = nullptr;
break;
case ElementType::i8:
this->values.i8 = that.values.i8;
that.values.i8 = nullptr;
break;
}
}

// Constructor to build a multi-dimensional tensor with the
// corresponding element type. All elements are initialized with the
// default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType) {
initPreallocated(dimensions, elementType);
}

TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType)
: TensorData(toDimSpec(dimensions), elementType) {}

// Constructor to build a multi-dimensional tensor with the element
// type corresponding to `elementWidth` and `sign`. The value for
// `elementWidth` must be a power of 2 of up to 64. All elements are
// initialized with the default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth,
bool sign) {
switch (elementWidth) {
case 64:
initPreallocated(dimensions,
(sign) ? ElementType::i64 : ElementType::u64);
break;
case 32:
initPreallocated(dimensions,
(sign) ? ElementType::i32 : ElementType::u32);
break;
case 16:
initPreallocated(dimensions,
(sign) ? ElementType::i16 : ElementType::u16);
break;
case 8:
initPreallocated(dimensions, (sign) ? ElementType::i8 : ElementType::u8);
break;
default:
assert(false && "Element width must be 64, 32, 16 or 8 bits");
}
}

TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth, bool sign)
: TensorData(toDimSpec(dimensions), elementWidth, sign) {}

#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX) \
/* Multi-dimensional, initialized tensor, values copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions) \
: dimensions(dimensions.begin(), dimensions.end()) { \
assert(dimensions.size() != 0); \
size_t n = getNumElements(dimensions); \
this->values.SUFFIX = new std::vector<ELTY>(n); \
this->elementType = ElementType::SUFFIX; \
this->bulkAssign(values); \
} \
\
/* One-dimensional, initialized tensor. Values are copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values) \
: TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}) {}

DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)

~TensorData() {
switch (this->elementType) {
case ElementType::u64:
delete values.u64;
break;
case ElementType::i64:
delete values.i64;
break;
case ElementType::u32:
delete values.u32;
break;
case ElementType::i32:
delete values.i32;
break;
case ElementType::u16:
delete values.u16;
break;
case ElementType::i16:
delete values.i16;
break;
case ElementType::u8:
delete values.u8;
break;
case ElementType::i8:
delete values.i8;
break;
}
}

// Returns the total number of elements of the tensor
size_t length() const { return getNumElements(this->dimensions); }

// Returns a vector with the size for each dimension of the tensor
const std::vector<size_t> &getDimensions() const { return this->dimensions; }

// Returns the number of dimensions
size_t getRank() const { return this->dimensions.size(); }

// Multi-dimensional access to a tensor element
template <typename T> T &operator[](llvm::ArrayRef<int64_t> index) {
// Number of dimensions must match
assert(index.size() == dimensions.size());

int64_t offset = 0;
int64_t multiplier = 1;
for (int64_t i = index.size() - 1; i > 0; i--) {
offset += index[i] * multiplier;
multiplier *= this->dimensions[i];
}

return detail::TensorData::getElementReference<T>(values, offset,
elementType);
}

// Iterator pointing to the first element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator begin() {
return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
}

// Iterator pointing past the last element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator end() {
return detail::TensorData::end<typename std::vector<T>::iterator>(values);
}

// Const iterator pointing to the first element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cbegin() {
return detail::TensorData::cbegin<typename std::vector<T>::iterator>(
values);
}

// Const iterator pointing past the last element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cend() {
return detail::TensorData::cend<typename std::vector<T>::iterator>(values);
}

// Flat representation of the const tensor
template <typename T> const std::vector<T> &getElements() const {
return detail::TensorData::getConstElements<const std::vector<T> &>(values);
}

// Flat representation of the tensor
template <typename T> const std::vector<T> &getElements() {
return detail::TensorData::getElements<std::vector<T> &>(values);
}

// Returns the `index`-th value of a flat representation of the tensor
template <typename T> T getElementValue(size_t index) {
return detail::TensorData::getElementValue<T>(values, index, elementType);
}

// Returns a reference to the `index`-th value of a flat
// representation of the tensor
template <typename T> T &getElementReference(size_t index) {
return detail::TensorData::getElementReference<T>(values, index,
elementType);
}

// Returns a pointer to the `index`-th value of a flat
// representation of the tensor
template <typename T> T *getElementPointer(size_t index) {
return detail::TensorData::getElementPointer<T>(values, index, elementType);
}

// Returns a void pointer to the `index`-th value of a flat
// representation of the tensor
void *getOpaqueElementPointer(size_t index) {
switch (this->elementType) {
case ElementType::u64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint64_t>(values, index,
elementType));
case ElementType::i64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int64_t>(values, index,
elementType));
case ElementType::u32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint32_t>(values, index,
elementType));
case ElementType::i32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int32_t>(values, index,
elementType));
case ElementType::u16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint16_t>(values, index,
elementType));
case ElementType::i16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int16_t>(values, index,
elementType));
case ElementType::u8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint8_t>(values, index,
elementType));
case ElementType::i8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int8_t>(values, index,
elementType));
}

assert(false && "Unknown element type");
}

// Returns the element type of the tensor
ElementType getElementType() const { return this->elementType; }

// Returns the size of a tensor element in bytes
size_t getElementSize() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::i64:
return 8;
case ElementType::u32:
case ElementType::i32:
return 4;
case ElementType::u16:
case ElementType::i16:
return 2;
case ElementType::u8:
case ElementType::i8:
return 1;
}
}

// Returns `true` if elements are signed, otherwise `false`
bool getElementSignedness() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::u32:
case ElementType::u16:
case ElementType::u8:
return false;
case ElementType::i64:
case ElementType::i32:
case ElementType::i16:
case ElementType::i8:
return true;
}
}

// Returns the width of an element in bits
size_t getElementWidth() const {
return getElementTypeWidth(this->elementType);
}

// Returns the total number of elements of the tensor
size_t getNumElements() const { return getNumElements(this->dimensions); }

// Copy all elements from `values` to the tensor. Note that this
// does not append values to the tensor, but overwrites existing
// values.
template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
assert(values.size() <= this->getNumElements());

switch (this->elementType) {
case ElementType::u64:
std::copy(values.begin(), values.end(), this->values.u64->begin());
break;
case ElementType::i64:
std::copy(values.begin(), values.end(), this->values.i64->begin());
break;
case ElementType::u32:
std::copy(values.begin(), values.end(), this->values.u32->begin());
break;
case ElementType::i32:
std::copy(values.begin(), values.end(), this->values.i32->begin());
break;
case ElementType::u16:
std::copy(values.begin(), values.end(), this->values.u16->begin());
break;
case ElementType::i16:
std::copy(values.begin(), values.end(), this->values.i16->begin());
break;
case ElementType::u8:
std::copy(values.begin(), values.end(), this->values.u8->begin());
break;
case ElementType::i8:
std::copy(values.begin(), values.end(), this->values.i8->begin());
break;
}
}

// Copies all elements of a flat representation of the tensor to the
// positions starting with the iterator `start`.
template <typename IT> void copy(IT start) {
switch (this->elementType) {
case ElementType::u64:
std::copy(this->values.u64->begin(), this->values.u64->end(), start);
break;
case ElementType::i64:
std::copy(this->values.i64->begin(), this->values.i64->end(), start);
break;
case ElementType::u32:
std::copy(this->values.u32->begin(), this->values.u32->end(), start);
break;
case ElementType::i32:
std::copy(this->values.i32->begin(), this->values.i32->end(), start);
break;
case ElementType::u16:
std::copy(this->values.u16->begin(), this->values.u16->end(), start);
break;
case ElementType::i16:
std::copy(this->values.i16->begin(), this->values.i16->end(), start);
break;
case ElementType::u8:
std::copy(this->values.u8->begin(), this->values.u8->end(), start);
break;
case ElementType::i8:
std::copy(this->values.i8->begin(), this->values.i8->end(), start);
break;
}
}

// Returns a flat representation of the tensor with elements
// converted to the type `T`
template <typename T> std::vector<T> asFlatVector() {
std::vector<T> ret(getNumElements());
this->copy(ret.begin());
return ret;
}

// Returns a void pointer to the first element of a flat
// representation of the tensor
void *getValuesAsOpaquePointer() {
switch (this->elementType) {
case ElementType::u64:
return static_cast<void *>(values.u64->data());
case ElementType::i64:
return static_cast<void *>(values.i64->data());
case ElementType::u32:
return static_cast<void *>(values.u32->data());
case ElementType::i32:
return static_cast<void *>(values.i32->data());
case ElementType::u16:
return static_cast<void *>(values.u16->data());
case ElementType::i16:
return static_cast<void *>(values.i16->data());
case ElementType::u8:
return static_cast<void *>(values.u8->data());
case ElementType::i8:
return static_cast<void *>(values.i8->data());
}

assert(false && "Unhandled element type");
}
};

} // namespace clientlib
} // namespace concretelang

#endif
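A small sketch (not part of the commit) of how a caller might inspect a `TensorData` whose element type is only known at run time, e.g. a clear result produced by the JIT lambda; only members declared in the class above are used, and the include path is assumed:

```cpp
#include <cstdint>
#include <iostream>

#include "concretelang/ClientLib/Types.h" // assumed header location

using concretelang::clientlib::TensorData;

// Widening every element to 64 bits is lossless here, since 64 bits is the
// maximum element width TensorData supports.
void dump(TensorData &td) {
  if (td.getElementSignedness()) {
    for (int64_t v : td.asFlatVector<int64_t>())
      std::cout << v << ' ';
  } else {
    for (uint64_t v : td.asFlatVector<uint64_t>())
      std::cout << v << ' ';
  }
  std::cout << "(rank " << td.getRank() << ", " << td.getElementWidth()
            << "-bit elements)\n";
}
```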
@@ -16,7 +16,8 @@ namespace serverlib {
using concretelang::clientlib::TensorData;

TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank);
std::vector<void *> args, size_t rank,
size_t element_width, bool is_signed);

} // namespace serverlib
} // namespace concretelang
@@ -31,32 +31,34 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
preparedArgs.push_back((void *)arg);
return outcome::success();
}

std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);

// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
ciphertextBuffers.emplace_back(shape, clientlib::EncryptedScalarElementType);
TensorData &values_and_sizes = ciphertextBuffers.back();
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));

OUTCOME_TRYV(keySet.encrypt_lwe(
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (auto size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides
int64_t stride = values_and_sizes.length();
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
auto size = values_and_sizes.sizes[i];
int64_t stride = TensorData::getNumElements(shape);
for (size_t size : values_and_sizes.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
preparedArgs.push_back((void *)1);

return outcome::success();
}
@@ -41,11 +41,12 @@ PublicArguments::serialize(std::ostream &ostream) {
"are not yet supported. Argument ")
<< iGate;
}

/*auto allocated = */ preparedArgs[iPreparedArgs++];
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
assert(aligned != nullptr);
auto offset = (size_t)preparedArgs[iPreparedArgs++];
std::vector<int64_t> sizes; // includes lweSize as last dim
std::vector<size_t> sizes; // includes lweSize as last dim
sizes.resize(rank + 1);
for (auto dim = 0u; dim < sizes.size(); dim++) {
// sizes are part of the client parameters signature
@@ -60,8 +61,13 @@ PublicArguments::serialize(std::ostream &ostream) {
}
// TODO: STRIDES
auto values = aligned + offset;
serializeTensorData(sizes, values, ostream);

serializeTensorDataRaw(sizes,
llvm::ArrayRef<clientlib::EncryptedScalarElement>{
values, TensorData::getNumElements(sizes)},
ostream);
}

return outcome::success();
}
@@ -76,18 +82,24 @@ PublicArguments::unserializeArgs(std::istream &istream) {
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
auto tdOrErr = unserializeTensorData(sizes, istream);

if (tdOrErr.has_error())
return tdOrErr.error();

ciphertextBuffers.push_back(std::move(tdOrErr.value()));
auto &values_and_sizes = ciphertextBuffers.back();

if (istream.fail()) {
return StringError(
"PublicArguments::unserializeArgs: Failed to read argument ")
<< iGate;
}
preparedArgs.push_back(/*allocated*/ nullptr);
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back(values_and_sizes.getValuesAsOpaquePointer());
preparedArgs.push_back(/*offset*/ 0);
// sizes
for (auto size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides has been removed by serialization
@@ -120,10 +132,12 @@ PublicResult::unserialize(std::istream &istream) {
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
buffers.push_back(unserializeTensorData(sizes, istream));
if (istream.fail()) {
return StringError("Cannot read tensor data");
}
auto tdOrErr = unserializeTensorData(sizes, istream);

if (tdOrErr.has_error())
return tdOrErr.error();

buffers.push_back(std::move(tdOrErr.value()));
}
return outcome::success();
}
@@ -134,7 +148,7 @@ PublicResult::serialize(std::ostream &ostream) {
return StringError(
"PublicResult::serialize: ostream should be in binary mode");
}
for (auto tensorData : buffers) {
for (const TensorData &tensorData : buffers) {
serializeTensorData(tensorData, ostream);
if (ostream.fail()) {
return StringError("Cannot write tensor data");
@@ -166,40 +180,85 @@ size_t global_index(size_t index[], size_t sizes[], size_t strides[],
return g_index;
}

TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; }
TensorData tensorDataFromScalar(uint64_t value) {
return TensorData{llvm::ArrayRef<uint64_t>{value}, {1}};
}

TensorData tensorDataFromMemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides) {
TensorData result;
static inline bool isReferenceToMLIRGlobalMemory(void *ptr) {
return reinterpret_cast<uintptr_t>(ptr) == 0xdeadbeef;
}

template <typename T>
TensorData tensorDataFromMemRefTyped(size_t memref_rank, void *allocatedVoid,
void *alignedVoid, size_t offset,
size_t *sizes, size_t *strides) {
T *allocated = reinterpret_cast<T *>(allocatedVoid);
T *aligned = reinterpret_cast<T *>(alignedVoid);

// FIXME: handle sign correctly
TensorData result(llvm::ArrayRef<size_t>{sizes, memref_rank}, sizeof(T) * 8,
false);
assert(aligned != nullptr);
result.sizes.resize(memref_rank);
for (size_t r = 0; r < memref_rank; r++) {
result.sizes[r] = sizes[r];
}

// ephemeral multi dim index to compute global strides
size_t *index = new size_t[memref_rank];
for (size_t r = 0; r < memref_rank; r++) {
index[r] = 0;
}
auto len = result.length();
result.values.resize(len);

// TODO: add a fast path for dense result (no real strides)
for (size_t i = 0; i < len; i++) {
int g_index = offset + global_index(index, sizes, strides, memref_rank);
result.values[i] = aligned[g_index];
result.getElementReference<T>(i) = aligned[g_index];
next_coord_index(index, sizes, memref_rank);
}
delete[] index;
// TEMPORARY: That quick and dirty but as this function is used only to
// convert a result of the mlir program and as data are copied here, we
// release the alocated pointer if it set.
if (allocated != nullptr) {

if (allocated != nullptr && !isReferenceToMLIRGlobalMemory(allocated)) {
free(allocated);
}

return result;
}

TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides) {
size_t storage_width = TensorData::storageWidth(element_width);

switch (storage_width) {
case 64:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int64_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint64_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 32:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int32_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint32_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 16:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int16_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint16_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 8:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int8_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint8_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
default:
assert(false);
}
}

} // namespace clientlib
} // namespace concretelang
@@ -149,70 +149,135 @@ std::ostream &operator<<(std::ostream &ostream,
return ostream;
}

std::ostream &serializeTensorData(uint64_t *values, size_t length,
std::ostream &ostream) {
if (incorrectMode(ostream)) {
return ostream;
}
writeSize(ostream, length);
for (size_t i = 0; i < length; i++) {
writeWord(ostream, values[i]);
}
return ostream;
template <typename T>
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
std::istream &istream) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());

return istream;
}

std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream) {
size_t length = 1;
for (auto size : sizes) {
length *= size;
writeSize(ostream, size);
switch (values_and_sizes.getElementType()) {
case ElementType::u64:
return serializeTensorDataRaw<uint64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint64_t>(), ostream);
case ElementType::i64:
return serializeTensorDataRaw<int64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int64_t>(), ostream);
case ElementType::u32:
return serializeTensorDataRaw<uint32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint32_t>(), ostream);
case ElementType::i32:
return serializeTensorDataRaw<int32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int32_t>(), ostream);
case ElementType::u16:
return serializeTensorDataRaw<uint16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint16_t>(), ostream);
case ElementType::i16:
return serializeTensorDataRaw<int16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int16_t>(), ostream);
case ElementType::u8:
return serializeTensorDataRaw<uint8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint8_t>(), ostream);
case ElementType::i8:
return serializeTensorDataRaw<int8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int8_t>(), ostream);
}
serializeTensorData(values, length, ostream);
assert(ostream.good());
return ostream;

assert(false && "Unhandled element type");
}

std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream) {
std::vector<int64_t> &sizes = values_and_sizes.sizes;
encrypted_scalars_t values = values_and_sizes.values.data();
return serializeTensorData(sizes, values, ostream);
}

TensorData unserializeTensorData(
outcome::checked<TensorData, StringError> unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
TensorData result;

if (incorrectMode(istream)) {
return result;
return StringError("Stream is in incorrect mode");
}
for (auto expectedSize : expectedSizes) {
size_t actualSize;
readSize(istream, actualSize);
if ((size_t)expectedSize != actualSize) {

uint64_t numDimensions;
readWord(istream, numDimensions);

std::vector<size_t> dims;

for (uint64_t i = 0; i < numDimensions; i++) {
int64_t dimSize;
readWord(istream, dimSize);

if (dimSize != expectedSizes[i]) {
istream.setstate(std::ios::badbit);
return StringError("Number of dimensions did not match the number of "
"expected dimensions");
}
assert(actualSize > 0);
result.sizes.push_back(actualSize);
assert(result.sizes.back() > 0);

dims.push_back(dimSize);
}
size_t expectedLen = result.length();
assert(expectedLen > 0);
// TODO: full read in one step
size_t actualLen;
readSize(istream, actualLen);
if (expectedLen != actualLen) {
istream.setstate(std::ios::badbit);

uint64_t elementWidth;
readWord(istream, elementWidth);

switch (elementWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
<< elementWidth;
}
assert(actualLen == expectedLen);
result.values.resize(actualLen);
for (uint64_t &value : result.values) {
value = 0;
readWord(istream, value);

uint8_t elementSignedness;
readWord(istream, elementSignedness);

if (elementSignedness != 0 && elementSignedness != 1) {
return StringError("Numerical value for element signedness must be either "
"0 or 1, but got ")
<< elementSignedness;
}
return result;

TensorData result(dims, elementWidth, elementSignedness == 1);

switch (result.getElementType()) {
case ElementType::u64:
unserializeTensorDataElements<uint64_t>(result, istream);
break;
case ElementType::i64:
unserializeTensorDataElements<int64_t>(result, istream);
break;
case ElementType::u32:
unserializeTensorDataElements<uint32_t>(result, istream);
break;
case ElementType::i32:
unserializeTensorDataElements<int32_t>(result, istream);
break;
case ElementType::u16:
unserializeTensorDataElements<uint16_t>(result, istream);
break;
case ElementType::i16:
unserializeTensorDataElements<int16_t>(result, istream);
break;
case ElementType::u8:
unserializeTensorDataElements<uint8_t>(result, istream);
break;
case ElementType::i8:
unserializeTensorDataElements<int8_t>(result, istream);
break;
}

return std::move(result);
}

std::ostream &operator<<(std::ostream &ostream,
@@ -33,170 +33,202 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
|
||||
}
|
||||
|
||||
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
|
||||
std::vector<void *> args,
|
||||
size_t rank) {
|
||||
std::vector<void *> args, size_t rank,
|
||||
size_t element_width, bool is_signed) {
|
||||
using concretelang::clientlib::MemRefDescriptor;
|
||||
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
|
||||
switch (rank) {
|
||||
case 1: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
|
||||
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(1, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 2: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
|
||||
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(2, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 3: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
|
||||
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(3, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 4: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
|
||||
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(4, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 5: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
|
||||
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(5, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 6: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
|
||||
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(6, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 7: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
|
||||
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(7, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 8: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
|
||||
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(8, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 9: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
|
||||
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(9, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 10: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
|
||||
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(10, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 11: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
|
||||
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
return convert(11, element_width, is_signed, m.allocated, m.aligned,
|
||||
m.offset, m.sizes, m.strides);
|
||||
}
|
||||
case 12: {
|
||||
auto m = multi_arity_call(
|
||||
convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
|
||||
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
|
||||
    return convert(12, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 13: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
    return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(13, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 14: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
    return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(14, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 15: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
    return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(15, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 16: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
    return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(16, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 17: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
    return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(17, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 18: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
    return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(18, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 19: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
    return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(19, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 20: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
    return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(20, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 21: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
    return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(21, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 22: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
    return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(22, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 23: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
    return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(23, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 24: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
    return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(24, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 25: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
    return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(25, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 26: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
    return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(26, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 27: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
    return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(27, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 28: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
    return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(28, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 29: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
    return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(29, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 30: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
    return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(30, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 31: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
    return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(31, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }
  case 32: {
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
    return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert(32, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }

  default:
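Note: the MemRefDescriptor<N> type dispatched on above is not defined in this diff; its exact declaration lives in the clientlib headers. A minimal sketch of the assumed layout, inferred from the member accesses (m.allocated, m.aligned, m.offset, m.sizes, m.strides) and from MLIR's standard strided-memref calling convention:

// Sketch only: assumed shape of clientlib's MemRefDescriptor<N>; names and
// pointer types are illustrative, not the authoritative definition.
#include <cstddef>
#include <cstdint>

template <size_t N> struct MemRefDescriptorSketch {
  uint64_t *allocated; // pointer returned by the allocator
  uint64_t *aligned;   // aligned pointer actually used for element access
  size_t offset;       // offset (in elements) from the aligned pointer
  size_t sizes[N];     // extent of each of the N dimensions
  size_t strides[N];   // stride (in elements) of each dimension
};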
@@ -82,9 +82,17 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
                 "ServerLambda::call is implemented for only one output");
  auto output = args.clientParameters.outputs[0];
  auto rank = args.clientParameters.bufferShape(output).size();
  auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank);
  return clientlib::PublicResult::fromBuffers(clientParameters, {result});

  // FIXME: Handle sign correctly
  size_t element_width = (output.isEncrypted()) ? 64 : output.shape.width;
  auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank,
                                              element_width, false);

  std::vector<TensorData> results;
  results.push_back(std::move(result));

  return clientlib::PublicResult::fromBuffers(clientParameters,
                                              std::move(results));
}

} // namespace serverlib
@@ -37,8 +37,8 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
}

TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
                                         std::vector<void *> args,
                                         size_t rank) {
                                         std::vector<void *> args, size_t rank,
                                         size_t element_width, bool is_signed) {
  using concretelang::clientlib::MemRefDescriptor;
  constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
  switch (rank) {""")
@@ -48,7 +48,8 @@ for tensor_rank in range(1, 33):
    print(f"""  case {tensor_rank}: {{
    auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);
    return convert({memref_rank}, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
    return convert({memref_rank}, element_width, is_signed, m.allocated, m.aligned,
                   m.offset, m.sizes, m.strides);
  }}""")

print("""
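The generator change above assumes that clientlib's tensorDataFromMemRef now takes the element width and signedness ahead of the memref fields. Its exact declaration is not shown in this diff; a plausible sketch, inferred purely from the generated call sites (argument names and pointer types are guesses, not the authoritative API):

#include <cstddef>

namespace concretelang {
namespace clientlib {
class TensorData; // defined in the clientlib headers

// Assumed declaration, for illustration only.
TensorData tensorDataFromMemRef(size_t memref_rank,   // runtime rank
                                size_t element_width, // 8/16/32/64 bits
                                bool is_signed,       // element signedness
                                void *allocated,      // allocated pointer
                                void *aligned,        // aligned pointer
                                size_t offset,        // offset in elements
                                size_t *sizes,        // memref_rank entries
                                size_t *strides);     // memref_rank entries
} // namespace clientlib
} // namespace concretelang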
@@ -101,7 +101,8 @@ JITLambda::call(clientlib::PublicArguments &args,
    return std::move(err);
  }
  std::vector<clientlib::TensorData> buffers;
  return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
  return clientlib::PublicResult::fromBuffers(args.clientParameters,
                                              std::move(buffers));
}
#endif

@@ -119,11 +120,13 @@ JITLambda::call(clientlib::PublicArguments &args,
      numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
    }
  }
  std::vector<void *> outputs(numOutputs);
  std::vector<uint64_t> outputs(numOutputs);

  // Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
  // inputs and outputs.
  std::vector<void *> rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ +
                              outputs.size());
  std::vector<void *> rawArgs(
      args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
  );
  size_t i = 0;
  // Pointers on inputs
  for (auto &arg : args.preparedArgs) {
@@ -136,10 +139,10 @@ JITLambda::call(clientlib::PublicArguments &args,
  // is passed to the compiled function.
  auto rtCtxPtr = &runtimeContext;
  rawArgs[i++] = &rtCtxPtr;
  // Pointers on outputs
  for (auto &out : outputs) {
    rawArgs[i++] = &out;
  }

  // Outputs
  rawArgs[i++] = reinterpret_cast<void *>(outputs.data());

  // Invoke
  if (auto err = invokeRaw(rawArgs)) {
    return std::move(err);
@@ -165,12 +168,20 @@ JITLambda::call(clientlib::PublicArguments &args,
        outputOffset += rank;
        size_t *strides = (size_t *)&outputs[outputOffset];
        outputOffset += rank;

        size_t elementWidth = (output.isEncrypted())
                                  ? clientlib::EncryptedScalarElementWidth
                                  : output.shape.width;

        // FIXME: Handle sign correctly
        buffers.push_back(clientlib::tensorDataFromMemRef(
            rank, allocated, aligned, offset, sizes, strides));
            rank, elementWidth, false, allocated, aligned, offset, sizes,
            strides));
      }
    }
  }
  return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
  return clientlib::PublicResult::fromBuffers(args.clientParameters,
                                              std::move(buffers));
}

} // namespace concretelang
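To make the new output convention concrete: instead of one raw argument per descriptor field, the compiled function now writes every returned memref descriptor into a single flat buffer of 64-bit words, and JITLambda::call walks that buffer field by field. A minimal sketch of that walk, assuming each rank-R descriptor is stored as allocated pointer, aligned pointer, offset, R sizes and R strides (the struct and function names below are illustrative, not part of the library):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative view of one rank-R descriptor recovered from the flat buffer.
struct FlatMemRefView {
  uint64_t *allocated; // allocated pointer written by the compiled function
  uint64_t *aligned;   // aligned pointer used for element access
  size_t offset;       // element offset from the aligned pointer
  size_t *sizes;       // rank entries, read in place inside the buffer
  size_t *strides;     // rank entries, read in place inside the buffer
};

// Reads one descriptor starting at word index `outputOffset`, advancing the
// index past it; mirrors the pointer arithmetic in the hunk above.
inline FlatMemRefView readDescriptor(std::vector<uint64_t> &outputs,
                                     size_t &outputOffset, size_t rank) {
  FlatMemRefView v;
  v.allocated = reinterpret_cast<uint64_t *>(outputs[outputOffset++]);
  v.aligned = reinterpret_cast<uint64_t *>(outputs[outputOffset++]);
  v.offset = static_cast<size_t>(outputs[outputOffset++]);
  v.sizes = reinterpret_cast<size_t *>(&outputs[outputOffset]);
  outputOffset += rank;
  v.strides = reinterpret_cast<size_t *>(&outputs[outputOffset]);
  outputOffset += rank;
  return v;
}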
@@ -5,7 +5,30 @@
// 1D tensor //////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////

TEST(End2EndJit_ClearTensor_1D, DISABLED_identity) {
TEST(End2EndJit_ClearTensor_2D, constant_i8) {
  checkedJit(lambda,
             R"XXX(
func.func @main() -> tensor<2x2xi8> {
  %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi8>
  return %cst : tensor<2x2xi8>
}
)XXX",
             "main", true);

  llvm::Expected<std::vector<uint8_t>> res =
      lambda.operator()<std::vector<uint8_t>>();

  ASSERT_EXPECTED_SUCCESS(res);

  ASSERT_EQ(res->size(), (size_t)4);

  EXPECT_EQ((*res)[0], 0);
  EXPECT_EQ((*res)[1], 1);
  EXPECT_EQ((*res)[2], 2);
  EXPECT_EQ((*res)[3], 3);
}

TEST(End2EndJit_ClearTensor_1D, identity) {
  checkedJit(lambda,
             R"XXX(
func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
@@ -37,6 +60,29 @@ func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
  }
}

TEST(End2EndJit_ClearTensor_1D, identity_i8) {
  checkedJit(lambda,
             R"XXX(
func.func @main(%t: tensor<10xi8>) -> tensor<10xi8> {
  return %t : tensor<10xi8>
}
)XXX",
             "main", true);

  uint8_t arg[]{16, 21, 3, 127, 9, 17, 32, 18, 29, 104};

  llvm::Expected<std::vector<uint8_t>> res =
      lambda.operator()<std::vector<uint8_t>>(arg, ARRAY_SIZE(arg));

  ASSERT_EXPECTED_SUCCESS(res);

  ASSERT_EQ(res->size(), (size_t)10);

  for (size_t i = 0; i < res->size(); i++) {
    EXPECT_EQ(arg[i], res->operator[](i)) << "result differ at index " << i;
  }
}

TEST(End2EndJit_ClearTensor_1D, extract_64) {
  checkedJit(lambda, R"XXX(
func.func @main(%t: tensor<10xi64>, %i: index) -> i64{