fix(compiler): Add support for clear result tensors with element width != 64 bits

Returning tensors with elements whose width is not equal to 64 results
in garbled data. This commit extends the `TensorData` class used to
represent tensors in JIT compilation with support for signed /
unsigned elements of 8/16/32 and 64 bits, such that all clear text
tensors with up to 64 bits can be represented accurately.
This commit is contained in:
Andi Drebes
2022-09-09 16:04:09 +02:00
committed by rudy-6-4
parent f1833f06f2
commit 8255d3e190
13 changed files with 1048 additions and 197 deletions

View File

@@ -147,9 +147,6 @@ public:
<< pos << "has not the expected number of dimension, got "
<< shape.size() << " expected " << input.shape.dimensions.size();
}
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
TensorData &values_and_sizes = ciphertextBuffers.back();
// Check shape
for (size_t i = 0; i < shape.size(); i++) {
@@ -159,53 +156,49 @@ public:
<< shape[i] << " expected " << input.shape.dimensions[i];
}
}
// Set sizes
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
if (input.encryption.hasValue()) {
// Allocate values
values_and_sizes.values.resize(
keySet.clientParameters().bufferSize(input));
TensorData td(sizes, ElementType::u64);
auto lweSize = keySet.clientParameters().lweBufferSize(input);
auto &values = values_and_sizes.values;
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i]));
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, td.getElementPointer<uint64_t>(offset), data[i]));
}
ciphertextBuffers.push_back(std::move(td));
} else {
// Allocate values take care of gate bitwidth
auto bitsPerValue = bitWidthAsWord(input.shape.width);
auto bytesPerValue = bitsPerValue / 8;
auto nbWordPerValue = 8 / bytesPerValue;
// ceil division
auto size = (input.shape.size / nbWordPerValue) +
(input.shape.size % nbWordPerValue != 0);
size = size == 0 ? 1 : size;
values_and_sizes.values.resize(size);
auto v = (uint8_t *)values_and_sizes.values.data();
for (size_t i = 0; i < input.shape.size; i++) {
auto dst = v + i * bytesPerValue;
auto src = (const uint8_t *)&data[i];
for (size_t j = 0; j < bytesPerValue; j++) {
dst[j] = src[j];
}
}
// FIXME: This always requests a tensor with unsigned elements,
// as the signedness is not captured in the description of the
// circuit
TensorData td(sizes, bitsPerValue, false);
llvm::ArrayRef<T> values(data, TensorData::getNumElements(sizes));
td.bulkAssign(values);
ciphertextBuffers.push_back(std::move(td));
}
TensorData &td = ciphertextBuffers.back();
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back(td.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (size_t size : values_and_sizes.sizes) {
for (size_t size : td.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = values_and_sizes.length();
for (size_t size : values_and_sizes.sizes) {
int64_t stride = td.getNumElements();
for (size_t size : td.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}

View File

@@ -66,16 +66,16 @@ private:
struct PublicResult {
PublicResult(const ClientParameters &clientParameters,
std::vector<TensorData> buffers = {})
: clientParameters(clientParameters), buffers(buffers){};
std::vector<TensorData> &&buffers = {})
: clientParameters(clientParameters), buffers(std::move(buffers)){};
PublicResult(PublicResult &) = delete;
/// Create a public result from buffers.
static std::unique_ptr<PublicResult>
fromBuffers(const ClientParameters &clientParameters,
std::vector<TensorData> buffers) {
return std::make_unique<PublicResult>(clientParameters, buffers);
std::vector<TensorData> &&buffers) {
return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
}
/// Unserialize from an input stream inplace.
@@ -99,21 +99,22 @@ struct PublicResult {
outcome::checked<std::vector<T>, StringError>
asClearTextVector(KeySet &keySet, size_t pos) {
OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
if (!gate.isEncrypted()) {
std::vector<T> result;
result.reserve(buffers[pos].values.size());
std::copy(buffers[pos].values.begin(), buffers[pos].values.end(),
std::back_inserter(result));
return result;
}
if (!gate.isEncrypted())
return buffers[pos].asFlatVector<T>();
auto buffer = buffers[pos];
auto &buffer = buffers[pos];
auto lweSize = clientParameters.lweBufferSize(gate);
std::vector<T> decryptedValues(buffer.length() / lweSize);
for (size_t i = 0; i < decryptedValues.size(); i++) {
auto ciphertext = &buffer.values[i * lweSize];
auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
uint64_t decrypted;
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted));
// Convert to uint64_t* as required by `KeySet::decrypt_lwe`
// FIXME: this may break alignment restrictions on some
// architectures
auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
decryptedValues[i] = decrypted;
}
return decryptedValues;
@@ -137,10 +138,9 @@ TensorData tensorDataFromScalar(uint64_t value);
/// Helper function to convert from MemRefDescriptor to
/// TensorData
TensorData tensorDataFromMemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides);
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides);
} // namespace clientlib
} // namespace concretelang

View File

@@ -7,6 +7,7 @@
#define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
#include <iostream>
#include <limits>
#include "concretelang/ClientLib/ClientParameters.h"
#include "concretelang/ClientLib/EvaluationKeys.h"
@@ -40,6 +41,14 @@ std::istream &readWord(std::istream &istream, Word &word) {
return istream;
}
template <typename Word>
std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
assert(istream.good());
return istream;
}
template <typename Size>
std::istream &readSize(std::istream &istream, Size &size) {
return readWord(istream, size);
@@ -57,13 +66,29 @@ std::ostream &operator<<(std::ostream &ostream,
const RuntimeContext &runtimeContext);
std::istream &operator>>(std::istream &istream, RuntimeContext &runtimeContext);
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream);
std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream);
template <typename T>
std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
const llvm::ArrayRef<T> &values,
std::ostream &ostream) {
TensorData unserializeTensorData(
writeWord<uint64_t>(ostream, dimensions.size());
for (size_t dim : dimensions)
writeWord<int64_t>(ostream, dim);
writeWord<uint64_t>(ostream, sizeof(T) * 8);
writeWord<uint8_t>(ostream, std::is_signed<T>());
for (T val : values)
writeWord(ostream, val);
return ostream;
}
outcome::checked<TensorData, StringError> unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes unsigned to
// accomodate non static sizes
std::istream &istream);

View File

@@ -6,6 +6,8 @@
#ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
#define CONCRETELANG_CLIENTLIB_TYPES_H_
#include "llvm/ADT/ArrayRef.h"
#include <cstdint>
#include <stddef.h>
#include <vector>
@@ -30,24 +32,630 @@ template <size_t Rank> using encrypted_tensor_t = MemRefDescriptor<Rank>;
using encrypted_scalar_t = uint64_t *;
using encrypted_scalars_t = uint64_t *;
struct TensorData {
std::vector<uint64_t> values; // tensor of rank r + 1
std::vector<int64_t> sizes; // r sizes
// Element types for `TensorData`
enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };
inline size_t length() {
if (sizes.empty()) {
return 0;
}
size_t len = 1;
for (auto size : sizes) {
len *= size;
}
return len;
namespace {
// Returns the number of bits for an element type
static constexpr size_t getElementTypeWidth(ElementType t) {
switch (t) {
case ElementType::u64:
case ElementType::i64:
return 64;
case ElementType::u32:
case ElementType::i32:
return 32;
case ElementType::u16:
case ElementType::i16:
return 16;
case ElementType::u8:
case ElementType::i8:
return 8;
}
}
} // namespace
// Constants for the element types used for tensors representing
// encrypted data and data after decryption
constexpr ElementType EncryptedScalarElementType = ElementType::u64;
constexpr size_t EncryptedScalarElementWidth =
getElementTypeWidth(ElementType::u64);
using EncryptedScalarElement = uint64_t;
namespace detail {
namespace TensorData {
// Union used to store the pointer to the actual data of an instance
// of `TensorData`. Values are stored contiguously in memory in a
// `std::vector` whose element type corresponds to the element type of
// the tensor.
union value_vector_union {
std::vector<uint64_t> *u64;
std::vector<int64_t> *i64;
std::vector<uint32_t> *u32;
std::vector<int32_t> *i32;
std::vector<uint16_t> *u16;
std::vector<int16_t> *i16;
std::vector<uint8_t> *u8;
std::vector<int8_t> *i8;
};
// Function templates that would go into the class `TensorData`, but
// which need to declared in namespace scope, since specializations of
// templates on the return type cannot be done for member functions as
// per the C++ standard
template <typename T> T begin(union value_vector_union &vec);
template <typename T> T end(union value_vector_union &vec);
template <typename T> T cbegin(union value_vector_union &vec);
template <typename T> T cend(union value_vector_union &vec);
template <typename T> T getElements(union value_vector_union &vec);
template <typename T> T getConstElements(const union value_vector_union &vec);
template <typename T>
T getElementValue(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T &getElementReference(union value_vector_union &vec, size_t idx,
ElementType elementType);
template <typename T>
T *getElementPointer(union value_vector_union &vec, size_t idx,
ElementType elementType);
// Specializations for the above templates
#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX) \
template <> \
inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) { \
return vec.SUFFIX->begin(); \
} \
\
template <> \
inline std::vector<ELTY>::iterator end(union value_vector_union &vec) { \
return vec.SUFFIX->end(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cbegin( \
union value_vector_union &vec) { \
return vec.SUFFIX->cbegin(); \
} \
\
template <> \
inline std::vector<ELTY>::const_iterator cend( \
union value_vector_union &vec) { \
return vec.SUFFIX->cend(); \
} \
\
template <> \
inline std::vector<ELTY> &getElements(union value_vector_union &vec) { \
return *vec.SUFFIX; \
} \
\
template <> \
inline const std::vector<ELTY> &getConstElements( \
const union value_vector_union &vec) { \
return *vec.SUFFIX; \
}
inline size_t lweSize() { return sizes.back(); }
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)
#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX) \
template <> \
inline ELTY getElementValue(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return (*vec.SUFFIX)[idx]; \
} \
\
template <> \
inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx, \
ElementType elementType) { \
assert(elementType == ElementType::SUFFIX); \
return &(*vec.SUFFIX)[idx]; \
}
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
} // namespace TensorData
} // namespace detail
// Representation of a tensor with an arbitrary number of dimensions
class TensorData {
protected:
detail::TensorData::value_vector_union values;
ElementType elementType;
std::vector<size_t> dimensions;
/* Multi-dimensional, uninitialized, but preallocated tensor */
void initPreallocated(llvm::ArrayRef<size_t> dimensions,
ElementType elementType) {
assert(dimensions.size() != 0);
this->dimensions.resize(dimensions.size());
size_t n = getNumElements(dimensions);
switch (elementType) {
case ElementType::u64:
this->values.u64 = new std::vector<uint64_t>(n);
break;
case ElementType::i64:
this->values.i64 = new std::vector<int64_t>(n);
break;
case ElementType::u32:
this->values.u32 = new std::vector<uint32_t>(n);
break;
case ElementType::i32:
this->values.i32 = new std::vector<int32_t>(n);
break;
case ElementType::u16:
this->values.u16 = new std::vector<uint16_t>(n);
break;
case ElementType::i16:
this->values.i16 = new std::vector<int16_t>(n);
break;
case ElementType::u8:
this->values.u8 = new std::vector<uint8_t>(n);
break;
case ElementType::i8:
this->values.i8 = new std::vector<int8_t>(n);
break;
}
this->elementType = elementType;
std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
}
// Creates a vector<size_t> from an ArrayRef<T>
template <typename T>
static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
return std::vector<size_t>(dims.begin(), dims.end());
}
public:
// Returns the total number of elements of a tensor with the
// specified dimensions
template <typename T> static size_t getNumElements(T dimensions) {
size_t n = 1;
for (auto dim : dimensions)
n *= dim;
return n;
}
// Returns the number of bits of an integer capable of storing
// values with up to `elementWidth` bits.
static size_t storageWidth(size_t elementWidth) {
if (elementWidth > 64) {
assert(false && "Maximum supported element width is 64");
} else if (elementWidth > 32) {
return 64;
} else if (elementWidth > 16) {
return 32;
} else if (elementWidth > 8) {
return 16;
} else {
return 8;
}
}
// Move constructor. Leaves `that` uninitialized.
TensorData(TensorData &&that)
: elementType(that.elementType), dimensions(std::move(that.dimensions)) {
switch (that.elementType) {
case ElementType::u64:
this->values.u64 = that.values.u64;
that.values.u64 = nullptr;
break;
case ElementType::i64:
this->values.i64 = that.values.i64;
that.values.i64 = nullptr;
break;
case ElementType::u32:
this->values.u32 = that.values.u32;
that.values.u32 = nullptr;
break;
case ElementType::i32:
this->values.i32 = that.values.i32;
that.values.i32 = nullptr;
break;
case ElementType::u16:
this->values.u16 = that.values.u16;
that.values.u16 = nullptr;
break;
case ElementType::i16:
this->values.i16 = that.values.i16;
that.values.i16 = nullptr;
break;
case ElementType::u8:
this->values.u8 = that.values.u8;
that.values.u8 = nullptr;
break;
case ElementType::i8:
this->values.i8 = that.values.i8;
that.values.i8 = nullptr;
break;
}
}
// Constructor to build a multi-dimensional tensor with the
// corresponding element type. All elements are initialized with the
// default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType) {
initPreallocated(dimensions, elementType);
}
TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType)
: TensorData(toDimSpec(dimensions), elementType) {}
// Constructor to build a multi-dimensional tensor with the element
// type corresponding to `elementWidth` and `sign`. The value for
// `elementWidth` must be a power of 2 of up to 64. All elements are
// initialized with the default value of `0`.
TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth,
bool sign) {
switch (elementWidth) {
case 64:
initPreallocated(dimensions,
(sign) ? ElementType::i64 : ElementType::u64);
break;
case 32:
initPreallocated(dimensions,
(sign) ? ElementType::i32 : ElementType::u32);
break;
case 16:
initPreallocated(dimensions,
(sign) ? ElementType::i16 : ElementType::u16);
break;
case 8:
initPreallocated(dimensions, (sign) ? ElementType::i8 : ElementType::u8);
break;
default:
assert(false && "Element width must be 64, 32, 16 or 8 bits");
}
}
TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth, bool sign)
: TensorData(toDimSpec(dimensions), elementWidth, sign) {}
#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX) \
/* Multi-dimensional, initialized tensor, values copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions) \
: dimensions(dimensions.begin(), dimensions.end()) { \
assert(dimensions.size() != 0); \
size_t n = getNumElements(dimensions); \
this->values.SUFFIX = new std::vector<ELTY>(n); \
this->elementType = ElementType::SUFFIX; \
this->bulkAssign(values); \
} \
\
/* One-dimensional, initialized tensor. Values are copied from */ \
/* `values` */ \
TensorData(llvm::ArrayRef<ELTY> values) \
: TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}) {}
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)
~TensorData() {
switch (this->elementType) {
case ElementType::u64:
delete values.u64;
break;
case ElementType::i64:
delete values.i64;
break;
case ElementType::u32:
delete values.u32;
break;
case ElementType::i32:
delete values.i32;
break;
case ElementType::u16:
delete values.u16;
break;
case ElementType::i16:
delete values.i16;
break;
case ElementType::u8:
delete values.u8;
break;
case ElementType::i8:
delete values.i8;
break;
}
}
// Returns the total number of elements of the tensor
size_t length() const { return getNumElements(this->dimensions); }
// Returns a vector with the size for each dimension of the tensor
const std::vector<size_t> &getDimensions() const { return this->dimensions; }
// Returns the number of dimensions
size_t getRank() const { return this->dimensions.size(); }
// Multi-dimensional access to a tensor element
template <typename T> T &operator[](llvm::ArrayRef<int64_t> index) {
// Number of dimensions must match
assert(index.size() == dimensions.size());
int64_t offset = 0;
int64_t multiplier = 1;
for (int64_t i = index.size() - 1; i > 0; i--) {
offset += index[i] * multiplier;
multiplier *= this->dimensions[i];
}
return detail::TensorData::getElementReference<T>(values, offset,
elementType);
}
// Iterator pointing to the first element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator begin() {
return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
}
// Iterator pointing past the last element of a flat representation
// of the tensor.
template <typename T> typename std::vector<T>::iterator end() {
return detail::TensorData::end<typename std::vector<T>::iterator>(values);
}
// Const iterator pointing to the first element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cbegin() {
return detail::TensorData::cbegin<typename std::vector<T>::iterator>(
values);
}
// Const iterator pointing past the last element of a flat
// representation of the tensor.
template <typename T> typename std::vector<T>::iterator cend() {
return detail::TensorData::cend<typename std::vector<T>::iterator>(values);
}
// Flat representation of the const tensor
template <typename T> const std::vector<T> &getElements() const {
return detail::TensorData::getConstElements<const std::vector<T> &>(values);
}
// Flat representation of the tensor
template <typename T> const std::vector<T> &getElements() {
return detail::TensorData::getElements<std::vector<T> &>(values);
}
// Returns the `index`-th value of a flat representation of the tensor
template <typename T> T getElementValue(size_t index) {
return detail::TensorData::getElementValue<T>(values, index, elementType);
}
// Returns a reference to the `index`-th value of a flat
// representation of the tensor
template <typename T> T &getElementReference(size_t index) {
return detail::TensorData::getElementReference<T>(values, index,
elementType);
}
// Returns a pointer to the `index`-th value of a flat
// representation of the tensor
template <typename T> T *getElementPointer(size_t index) {
return detail::TensorData::getElementPointer<T>(values, index, elementType);
}
// Returns a void pointer to the `index`-th value of a flat
// representation of the tensor
void *getOpaqueElementPointer(size_t index) {
switch (this->elementType) {
case ElementType::u64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint64_t>(values, index,
elementType));
case ElementType::i64:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int64_t>(values, index,
elementType));
case ElementType::u32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint32_t>(values, index,
elementType));
case ElementType::i32:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int32_t>(values, index,
elementType));
case ElementType::u16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint16_t>(values, index,
elementType));
case ElementType::i16:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int16_t>(values, index,
elementType));
case ElementType::u8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<uint8_t>(values, index,
elementType));
case ElementType::i8:
return reinterpret_cast<void *>(
detail::TensorData::getElementPointer<int8_t>(values, index,
elementType));
}
assert(false && "Unknown element type");
}
// Returns the element type of the tensor
ElementType getElementType() const { return this->elementType; }
// Returns the size of a tensor element in bytes
size_t getElementSize() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::i64:
return 8;
case ElementType::u32:
case ElementType::i32:
return 4;
case ElementType::u16:
case ElementType::i16:
return 2;
case ElementType::u8:
case ElementType::i8:
return 1;
}
}
// Returns `true` if elements are signed, otherwise `false`
bool getElementSignedness() const {
switch (this->elementType) {
case ElementType::u64:
case ElementType::u32:
case ElementType::u16:
case ElementType::u8:
return false;
case ElementType::i64:
case ElementType::i32:
case ElementType::i16:
case ElementType::i8:
return true;
}
}
// Returns the width of an element in bits
size_t getElementWidth() const {
return getElementTypeWidth(this->elementType);
}
// Returns the total number of elements of the tensor
size_t getNumElements() const { return getNumElements(this->dimensions); }
// Copy all elements from `values` to the tensor. Note that this
// does not append values to the tensor, but overwrites existing
// values.
template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
assert(values.size() <= this->getNumElements());
switch (this->elementType) {
case ElementType::u64:
std::copy(values.begin(), values.end(), this->values.u64->begin());
break;
case ElementType::i64:
std::copy(values.begin(), values.end(), this->values.i64->begin());
break;
case ElementType::u32:
std::copy(values.begin(), values.end(), this->values.u32->begin());
break;
case ElementType::i32:
std::copy(values.begin(), values.end(), this->values.i32->begin());
break;
case ElementType::u16:
std::copy(values.begin(), values.end(), this->values.u16->begin());
break;
case ElementType::i16:
std::copy(values.begin(), values.end(), this->values.i16->begin());
break;
case ElementType::u8:
std::copy(values.begin(), values.end(), this->values.u8->begin());
break;
case ElementType::i8:
std::copy(values.begin(), values.end(), this->values.i8->begin());
break;
}
}
// Copies all elements of a flat representation of the tensor to the
// positions starting with the iterator `start`.
template <typename IT> void copy(IT start) {
switch (this->elementType) {
case ElementType::u64:
std::copy(this->values.u64->begin(), this->values.u64->end(), start);
break;
case ElementType::i64:
std::copy(this->values.i64->begin(), this->values.i64->end(), start);
break;
case ElementType::u32:
std::copy(this->values.u32->begin(), this->values.u32->end(), start);
break;
case ElementType::i32:
std::copy(this->values.i32->begin(), this->values.i32->end(), start);
break;
case ElementType::u16:
std::copy(this->values.u16->begin(), this->values.u16->end(), start);
break;
case ElementType::i16:
std::copy(this->values.i16->begin(), this->values.i16->end(), start);
break;
case ElementType::u8:
std::copy(this->values.u8->begin(), this->values.u8->end(), start);
break;
case ElementType::i8:
std::copy(this->values.i8->begin(), this->values.i8->end(), start);
break;
}
}
// Returns a flat representation of the tensor with elements
// converted to the type `T`
template <typename T> std::vector<T> asFlatVector() {
std::vector<T> ret(getNumElements());
this->copy(ret.begin());
return ret;
}
// Returns a void pointer to the first element of a flat
// representation of the tensor
void *getValuesAsOpaquePointer() {
switch (this->elementType) {
case ElementType::u64:
return static_cast<void *>(values.u64->data());
case ElementType::i64:
return static_cast<void *>(values.i64->data());
case ElementType::u32:
return static_cast<void *>(values.u32->data());
case ElementType::i32:
return static_cast<void *>(values.i32->data());
case ElementType::u16:
return static_cast<void *>(values.u16->data());
case ElementType::i16:
return static_cast<void *>(values.i16->data());
case ElementType::u8:
return static_cast<void *>(values.u8->data());
case ElementType::i8:
return static_cast<void *>(values.i8->data());
}
assert(false && "Unhandled element type");
}
};
} // namespace clientlib
} // namespace concretelang
#endif

View File

@@ -16,7 +16,8 @@ namespace serverlib {
using concretelang::clientlib::TensorData;
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args, size_t rank);
std::vector<void *> args, size_t rank,
size_t element_width, bool is_signed);
} // namespace serverlib
} // namespace concretelang

View File

@@ -31,32 +31,34 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
preparedArgs.push_back((void *)arg);
return outcome::success();
}
std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
ciphertextBuffers.emplace_back(shape, clientlib::EncryptedScalarElementType);
TensorData &values_and_sizes = ciphertextBuffers.back();
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (auto size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides
int64_t stride = values_and_sizes.length();
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
auto size = values_and_sizes.sizes[i];
int64_t stride = TensorData::getNumElements(shape);
for (size_t size : values_and_sizes.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
preparedArgs.push_back((void *)1);
return outcome::success();
}

View File

@@ -41,11 +41,12 @@ PublicArguments::serialize(std::ostream &ostream) {
"are not yet supported. Argument ")
<< iGate;
}
/*auto allocated = */ preparedArgs[iPreparedArgs++];
auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
assert(aligned != nullptr);
auto offset = (size_t)preparedArgs[iPreparedArgs++];
std::vector<int64_t> sizes; // includes lweSize as last dim
std::vector<size_t> sizes; // includes lweSize as last dim
sizes.resize(rank + 1);
for (auto dim = 0u; dim < sizes.size(); dim++) {
// sizes are part of the client parameters signature
@@ -60,8 +61,13 @@ PublicArguments::serialize(std::ostream &ostream) {
}
// TODO: STRIDES
auto values = aligned + offset;
serializeTensorData(sizes, values, ostream);
serializeTensorDataRaw(sizes,
llvm::ArrayRef<clientlib::EncryptedScalarElement>{
values, TensorData::getNumElements(sizes)},
ostream);
}
return outcome::success();
}
@@ -76,18 +82,24 @@ PublicArguments::unserializeArgs(std::istream &istream) {
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
auto tdOrErr = unserializeTensorData(sizes, istream);
if (tdOrErr.has_error())
return tdOrErr.error();
ciphertextBuffers.push_back(std::move(tdOrErr.value()));
auto &values_and_sizes = ciphertextBuffers.back();
if (istream.fail()) {
return StringError(
"PublicArguments::unserializeArgs: Failed to read argument ")
<< iGate;
}
preparedArgs.push_back(/*allocated*/ nullptr);
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back(values_and_sizes.getValuesAsOpaquePointer());
preparedArgs.push_back(/*offset*/ 0);
// sizes
for (auto size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides has been removed by serialization
@@ -120,10 +132,12 @@ PublicResult::unserialize(std::istream &istream) {
auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
std::vector<int64_t> sizes = gate.shape.dimensions;
sizes.push_back(lweSize);
buffers.push_back(unserializeTensorData(sizes, istream));
if (istream.fail()) {
return StringError("Cannot read tensor data");
}
auto tdOrErr = unserializeTensorData(sizes, istream);
if (tdOrErr.has_error())
return tdOrErr.error();
buffers.push_back(std::move(tdOrErr.value()));
}
return outcome::success();
}
@@ -134,7 +148,7 @@ PublicResult::serialize(std::ostream &ostream) {
return StringError(
"PublicResult::serialize: ostream should be in binary mode");
}
for (auto tensorData : buffers) {
for (const TensorData &tensorData : buffers) {
serializeTensorData(tensorData, ostream);
if (ostream.fail()) {
return StringError("Cannot write tensor data");
@@ -166,40 +180,85 @@ size_t global_index(size_t index[], size_t sizes[], size_t strides[],
return g_index;
}
TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; }
TensorData tensorDataFromScalar(uint64_t value) {
return TensorData{llvm::ArrayRef<uint64_t>{value}, {1}};
}
TensorData tensorDataFromMemRef(size_t memref_rank,
encrypted_scalars_t allocated,
encrypted_scalars_t aligned, size_t offset,
size_t *sizes, size_t *strides) {
TensorData result;
static inline bool isReferenceToMLIRGlobalMemory(void *ptr) {
return reinterpret_cast<uintptr_t>(ptr) == 0xdeadbeef;
}
template <typename T>
TensorData tensorDataFromMemRefTyped(size_t memref_rank, void *allocatedVoid,
void *alignedVoid, size_t offset,
size_t *sizes, size_t *strides) {
T *allocated = reinterpret_cast<T *>(allocatedVoid);
T *aligned = reinterpret_cast<T *>(alignedVoid);
// FIXME: handle sign correctly
TensorData result(llvm::ArrayRef<size_t>{sizes, memref_rank}, sizeof(T) * 8,
false);
assert(aligned != nullptr);
result.sizes.resize(memref_rank);
for (size_t r = 0; r < memref_rank; r++) {
result.sizes[r] = sizes[r];
}
// ephemeral multi dim index to compute global strides
size_t *index = new size_t[memref_rank];
for (size_t r = 0; r < memref_rank; r++) {
index[r] = 0;
}
auto len = result.length();
result.values.resize(len);
// TODO: add a fast path for dense result (no real strides)
for (size_t i = 0; i < len; i++) {
int g_index = offset + global_index(index, sizes, strides, memref_rank);
result.values[i] = aligned[g_index];
result.getElementReference<T>(i) = aligned[g_index];
next_coord_index(index, sizes, memref_rank);
}
delete[] index;
// TEMPORARY: That quick and dirty but as this function is used only to
// convert a result of the mlir program and as data are copied here, we
// release the alocated pointer if it set.
if (allocated != nullptr) {
if (allocated != nullptr && !isReferenceToMLIRGlobalMemory(allocated)) {
free(allocated);
}
return result;
}
TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
bool is_signed, void *allocated, void *aligned,
size_t offset, size_t *sizes, size_t *strides) {
size_t storage_width = TensorData::storageWidth(element_width);
switch (storage_width) {
case 64:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int64_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint64_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 32:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int32_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint32_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 16:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int16_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint16_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
case 8:
return (is_signed)
? std::move(tensorDataFromMemRefTyped<int8_t>(
memref_rank, allocated, aligned, offset, sizes, strides))
: std::move(tensorDataFromMemRefTyped<uint8_t>(
memref_rank, allocated, aligned, offset, sizes, strides));
default:
assert(false);
}
}
} // namespace clientlib
} // namespace concretelang

View File

@@ -149,70 +149,135 @@ std::ostream &operator<<(std::ostream &ostream,
return ostream;
}
std::ostream &serializeTensorData(uint64_t *values, size_t length,
std::ostream &ostream) {
if (incorrectMode(ostream)) {
return ostream;
}
writeSize(ostream, length);
for (size_t i = 0; i < length; i++) {
writeWord(ostream, values[i]);
}
return ostream;
template <typename T>
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
std::istream &istream) {
readWords(istream, values_and_sizes.getElementPointer<T>(0),
values_and_sizes.getNumElements());
return istream;
}
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
std::ostream &ostream) {
size_t length = 1;
for (auto size : sizes) {
length *= size;
writeSize(ostream, size);
switch (values_and_sizes.getElementType()) {
case ElementType::u64:
return serializeTensorDataRaw<uint64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint64_t>(), ostream);
case ElementType::i64:
return serializeTensorDataRaw<int64_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int64_t>(), ostream);
case ElementType::u32:
return serializeTensorDataRaw<uint32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint32_t>(), ostream);
case ElementType::i32:
return serializeTensorDataRaw<int32_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int32_t>(), ostream);
case ElementType::u16:
return serializeTensorDataRaw<uint16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint16_t>(), ostream);
case ElementType::i16:
return serializeTensorDataRaw<int16_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int16_t>(), ostream);
case ElementType::u8:
return serializeTensorDataRaw<uint8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<uint8_t>(), ostream);
case ElementType::i8:
return serializeTensorDataRaw<int8_t>(
values_and_sizes.getDimensions(),
values_and_sizes.getElements<int8_t>(), ostream);
}
serializeTensorData(values, length, ostream);
assert(ostream.good());
return ostream;
assert(false && "Unhandled element type");
}
std::ostream &serializeTensorData(TensorData &values_and_sizes,
std::ostream &ostream) {
std::vector<int64_t> &sizes = values_and_sizes.sizes;
encrypted_scalars_t values = values_and_sizes.values.data();
return serializeTensorData(sizes, values, ostream);
}
TensorData unserializeTensorData(
outcome::checked<TensorData, StringError> unserializeTensorData(
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
// accomodate non static sizes
std::istream &istream) {
TensorData result;
if (incorrectMode(istream)) {
return result;
return StringError("Stream is in incorrect mode");
}
for (auto expectedSize : expectedSizes) {
size_t actualSize;
readSize(istream, actualSize);
if ((size_t)expectedSize != actualSize) {
uint64_t numDimensions;
readWord(istream, numDimensions);
std::vector<size_t> dims;
for (uint64_t i = 0; i < numDimensions; i++) {
int64_t dimSize;
readWord(istream, dimSize);
if (dimSize != expectedSizes[i]) {
istream.setstate(std::ios::badbit);
return StringError("Number of dimensions did not match the number of "
"expected dimensions");
}
assert(actualSize > 0);
result.sizes.push_back(actualSize);
assert(result.sizes.back() > 0);
dims.push_back(dimSize);
}
size_t expectedLen = result.length();
assert(expectedLen > 0);
// TODO: full read in one step
size_t actualLen;
readSize(istream, actualLen);
if (expectedLen != actualLen) {
istream.setstate(std::ios::badbit);
uint64_t elementWidth;
readWord(istream, elementWidth);
switch (elementWidth) {
case 64:
case 32:
case 16:
case 8:
break;
default:
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
<< elementWidth;
}
assert(actualLen == expectedLen);
result.values.resize(actualLen);
for (uint64_t &value : result.values) {
value = 0;
readWord(istream, value);
uint8_t elementSignedness;
readWord(istream, elementSignedness);
if (elementSignedness != 0 && elementSignedness != 1) {
return StringError("Numerical value for element signedness must be either "
"0 or 1, but got ")
<< elementSignedness;
}
return result;
TensorData result(dims, elementWidth, elementSignedness == 1);
switch (result.getElementType()) {
case ElementType::u64:
unserializeTensorDataElements<uint64_t>(result, istream);
break;
case ElementType::i64:
unserializeTensorDataElements<int64_t>(result, istream);
break;
case ElementType::u32:
unserializeTensorDataElements<uint32_t>(result, istream);
break;
case ElementType::i32:
unserializeTensorDataElements<int32_t>(result, istream);
break;
case ElementType::u16:
unserializeTensorDataElements<uint16_t>(result, istream);
break;
case ElementType::i16:
unserializeTensorDataElements<int16_t>(result, istream);
break;
case ElementType::u8:
unserializeTensorDataElements<uint8_t>(result, istream);
break;
case ElementType::i8:
unserializeTensorDataElements<int8_t>(result, istream);
break;
}
return std::move(result);
}
std::ostream &operator<<(std::ostream &ostream,

View File

@@ -33,170 +33,202 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
}
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args,
size_t rank) {
std::vector<void *> args, size_t rank,
size_t element_width, bool is_signed) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
switch (rank) {
case 1: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(1, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 2: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(2, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 3: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(3, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 4: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(4, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 5: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(5, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 6: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(6, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 7: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(7, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 8: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(8, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 9: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(9, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 10: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(10, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 11: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(11, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 12: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(12, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 13: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(13, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 14: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(14, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 15: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(15, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 16: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(16, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 17: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(17, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 18: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(18, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 19: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(19, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 20: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(20, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 21: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(21, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 22: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(22, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 23: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(23, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 24: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(24, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 25: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(25, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 26: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(26, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 27: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(27, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 28: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(28, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 29: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(29, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 30: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(30, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 31: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(31, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
case 32: {
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert(32, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}
default:

View File

@@ -82,9 +82,17 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
"ServerLambda::call is implemented for only one output");
auto output = args.clientParameters.outputs[0];
auto rank = args.clientParameters.bufferShape(output).size();
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank);
return clientlib::PublicResult::fromBuffers(clientParameters, {result});
;
// FIXME: Handle sign correctly
size_t element_width = (output.isEncrypted()) ? 64 : output.shape.width;
auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank,
element_width, false);
std::vector<TensorData> results;
results.push_back(std::move(result));
return clientlib::PublicResult::fromBuffers(clientParameters,
std::move(results));
}
} // namespace serverlib

View File

@@ -37,8 +37,8 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
}
TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
std::vector<void *> args,
size_t rank) {
std::vector<void *> args, size_t rank,
size_t element_width, bool is_signed) {
using concretelang::clientlib::MemRefDescriptor;
constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
switch (rank) {""")
@@ -48,7 +48,8 @@ for tensor_rank in range(1, 33):
print(f""" case {tensor_rank}: {{
auto m = multi_arity_call(
convert_fnptr<MemRefDescriptor<{memref_rank}> (*)(void *...)>(func), args);
return convert({memref_rank}, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
return convert({memref_rank}, element_width, is_signed, m.allocated, m.aligned,
m.offset, m.sizes, m.strides);
}}""")
print("""

View File

@@ -101,7 +101,8 @@ JITLambda::call(clientlib::PublicArguments &args,
return std::move(err);
}
std::vector<clientlib::TensorData> buffers;
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
}
#endif
@@ -119,11 +120,13 @@ JITLambda::call(clientlib::PublicArguments &args,
numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
}
}
std::vector<void *> outputs(numOutputs);
std::vector<uint64_t> outputs(numOutputs);
// Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
// inputs and outputs.
std::vector<void *> rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ +
outputs.size());
std::vector<void *> rawArgs(
args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
);
size_t i = 0;
// Pointers on inputs
for (auto &arg : args.preparedArgs) {
@@ -136,10 +139,10 @@ JITLambda::call(clientlib::PublicArguments &args,
// is passed to the compiled function.
auto rtCtxPtr = &runtimeContext;
rawArgs[i++] = &rtCtxPtr;
// Pointers on outputs
for (auto &out : outputs) {
rawArgs[i++] = &out;
}
// Outputs
rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
// Invoke
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
@@ -165,12 +168,20 @@ JITLambda::call(clientlib::PublicArguments &args,
outputOffset += rank;
size_t *strides = (size_t *)&outputs[outputOffset];
outputOffset += rank;
size_t elementWidth = (output.isEncrypted())
? clientlib::EncryptedScalarElementWidth
: output.shape.width;
// FIXME: Handle sign correctly
buffers.push_back(clientlib::tensorDataFromMemRef(
rank, allocated, aligned, offset, sizes, strides));
rank, elementWidth, false, allocated, aligned, offset, sizes,
strides));
}
}
}
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
return clientlib::PublicResult::fromBuffers(args.clientParameters,
std::move(buffers));
}
} // namespace concretelang

View File

@@ -5,7 +5,30 @@
// 1D tensor //////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////
TEST(End2EndJit_ClearTensor_1D, DISABLED_identity) {
TEST(End2EndJit_ClearTensor_2D, constant_i8) {
checkedJit(lambda,
R"XXX(
func.func @main() -> tensor<2x2xi8> {
%cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi8>
return %cst : tensor<2x2xi8>
}
)XXX",
"main", true);
llvm::Expected<std::vector<uint8_t>> res =
lambda.operator()<std::vector<uint8_t>>();
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)4);
EXPECT_EQ((*res)[0], 0);
EXPECT_EQ((*res)[1], 1);
EXPECT_EQ((*res)[2], 2);
EXPECT_EQ((*res)[3], 3);
}
TEST(End2EndJit_ClearTensor_1D, identity) {
checkedJit(lambda,
R"XXX(
func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
@@ -37,6 +60,29 @@ func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
}
}
TEST(End2EndJit_ClearTensor_1D, identity_i8) {
checkedJit(lambda,
R"XXX(
func.func @main(%t: tensor<10xi8>) -> tensor<10xi8> {
return %t : tensor<10xi8>
}
)XXX",
"main", true);
uint8_t arg[]{16, 21, 3, 127, 9, 17, 32, 18, 29, 104};
llvm::Expected<std::vector<uint8_t>> res =
lambda.operator()<std::vector<uint8_t>>(arg, ARRAY_SIZE(arg));
ASSERT_EXPECTED_SUCCESS(res);
ASSERT_EQ(res->size(), (size_t)10);
for (size_t i = 0; i < res->size(); i++) {
EXPECT_EQ(arg[i], res->operator[](i)) << "result differ at index " << i;
}
}
TEST(End2EndJit_ClearTensor_1D, extract_64) {
checkedJit(lambda, R"XXX(
func.func @main(%t: tensor<10xi64>, %i: index) -> i64{