diff --git a/compiler/include/concretelang/ClientLib/EncryptedArguments.h b/compiler/include/concretelang/ClientLib/EncryptedArguments.h
index 264c4ec78..87fb4224b 100644
--- a/compiler/include/concretelang/ClientLib/EncryptedArguments.h
+++ b/compiler/include/concretelang/ClientLib/EncryptedArguments.h
@@ -147,9 +147,6 @@ public:
              << pos << "has not the expected number of dimension, got "
              << shape.size() << " expected " << input.shape.dimensions.size();
     }
-    // Allocate empty
-    ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
-    TensorData &values_and_sizes = ciphertextBuffers.back();
     // Check shape
     for (size_t i = 0; i < shape.size(); i++) {
@@ -159,53 +156,49 @@ public:
             << shape[i] << " expected " << input.shape.dimensions[i];
       }
     }
+
     // Set sizes
-    values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
+    std::vector<int64_t> sizes = keySet.clientParameters().bufferShape(input);
     if (input.encryption.hasValue()) {
-      // Allocate values
-      values_and_sizes.values.resize(
-          keySet.clientParameters().bufferSize(input));
+      TensorData td(sizes, ElementType::u64);
       auto lweSize = keySet.clientParameters().lweBufferSize(input);
-      auto &values = values_and_sizes.values;
+
       for (size_t i = 0, offset = 0; i < input.shape.size;
            i++, offset += lweSize) {
-        OUTCOME_TRYV(keySet.encrypt_lwe(pos, values.data() + offset, data[i]));
+        OUTCOME_TRYV(keySet.encrypt_lwe(
+            pos, td.getElementPointer<uint64_t>(offset), data[i]));
       }
+      ciphertextBuffers.push_back(std::move(td));
     } else {
-      // Allocate values take care of gate bitwidth
       auto bitsPerValue = bitWidthAsWord(input.shape.width);
-      auto bytesPerValue = bitsPerValue / 8;
-      auto nbWordPerValue = 8 / bytesPerValue;
-      // ceil division
-      auto size = (input.shape.size / nbWordPerValue) +
-                  (input.shape.size % nbWordPerValue != 0);
-      size = size == 0 ? 1 : size;
-      values_and_sizes.values.resize(size);
-      auto v = (uint8_t *)values_and_sizes.values.data();
-      for (size_t i = 0; i < input.shape.size; i++) {
-        auto dst = v + i * bytesPerValue;
-        auto src = (const uint8_t *)&data[i];
-        for (size_t j = 0; j < bytesPerValue; j++) {
-          dst[j] = src[j];
-        }
-      }
+
+      // FIXME: This always requests a tensor with unsigned elements,
+      // as the signedness is not captured in the description of the
+      // circuit
+      TensorData td(sizes, bitsPerValue, false);
+      llvm::ArrayRef<uint64_t> values(data, TensorData::getNumElements(sizes));
+      td.bulkAssign(values);
+      ciphertextBuffers.push_back(std::move(td));
     }
+    TensorData &td = ciphertextBuffers.back();
+
     // allocated
     preparedArgs.push_back(nullptr);
     // aligned
-    preparedArgs.push_back((void *)values_and_sizes.values.data());
+    preparedArgs.push_back(td.getValuesAsOpaquePointer());
     // offset
     preparedArgs.push_back((void *)0);
     // sizes
-    for (size_t size : values_and_sizes.sizes) {
+    for (size_t size : td.getDimensions()) {
       preparedArgs.push_back((void *)size);
     }
     // Set the stride for each dimension, equal to the product of the
     // following dimensions.
-    int64_t stride = values_and_sizes.length();
-    for (size_t size : values_and_sizes.sizes) {
+    int64_t stride = td.getNumElements();
+    for (size_t size : td.getDimensions()) {
       stride = (size == 0 ? 0 : (stride / size));
       preparedArgs.push_back((void *)stride);
     }
diff --git a/compiler/include/concretelang/ClientLib/PublicArguments.h b/compiler/include/concretelang/ClientLib/PublicArguments.h
index 3707fa310..59882e092 100644
--- a/compiler/include/concretelang/ClientLib/PublicArguments.h
+++ b/compiler/include/concretelang/ClientLib/PublicArguments.h
@@ -66,16 +66,16 @@ private:
 struct PublicResult {
 
   PublicResult(const ClientParameters &clientParameters,
-               std::vector<TensorData> buffers = {})
-      : clientParameters(clientParameters), buffers(buffers){};
+               std::vector<TensorData> &&buffers = {})
+      : clientParameters(clientParameters), buffers(std::move(buffers)){};
 
   PublicResult(PublicResult &) = delete;
 
   /// Create a public result from buffers.
   static std::unique_ptr<PublicResult>
   fromBuffers(const ClientParameters &clientParameters,
-              std::vector<TensorData> buffers) {
-    return std::make_unique<PublicResult>(clientParameters, buffers);
+              std::vector<TensorData> &&buffers) {
+    return std::make_unique<PublicResult>(clientParameters, std::move(buffers));
   }
 
   /// Unserialize from an input stream inplace.
@@ -99,21 +99,22 @@ struct PublicResult {
   outcome::checked<std::vector<uint64_t>, StringError>
   asClearTextVector(KeySet &keySet, size_t pos) {
     OUTCOME_TRY(auto gate, clientParameters.ouput(pos));
-    if (!gate.isEncrypted()) {
-      std::vector<uint64_t> result;
-      result.reserve(buffers[pos].values.size());
-      std::copy(buffers[pos].values.begin(), buffers[pos].values.end(),
-                std::back_inserter(result));
-      return result;
-    }
+    if (!gate.isEncrypted())
+      return buffers[pos].asFlatVector<uint64_t>();
 
-    auto buffer = buffers[pos];
+    auto &buffer = buffers[pos];
     auto lweSize = clientParameters.lweBufferSize(gate);
+
     std::vector<uint64_t> decryptedValues(buffer.length() / lweSize);
     for (size_t i = 0; i < decryptedValues.size(); i++) {
-      auto ciphertext = &buffer.values[i * lweSize];
+      auto ciphertext = buffer.getOpaqueElementPointer(i * lweSize);
       uint64_t decrypted;
-      OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertext, decrypted));
+
+      // Convert to uint64_t* as required by `KeySet::decrypt_lwe`
+      // FIXME: this may break alignment restrictions on some
+      // architectures
+      auto ciphertextu64 = reinterpret_cast<uint64_t *>(ciphertext);
+      OUTCOME_TRYV(keySet.decrypt_lwe(0, ciphertextu64, decrypted));
       decryptedValues[i] = decrypted;
     }
     return decryptedValues;
@@ -137,10 +138,9 @@ TensorData tensorDataFromScalar(uint64_t value);
 
 /// Helper function to convert from MemRefDescriptor to
 /// TensorData
-TensorData tensorDataFromMemRef(size_t memref_rank,
-                                encrypted_scalars_t allocated,
-                                encrypted_scalars_t aligned, size_t offset,
-                                size_t *sizes, size_t *strides);
+TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
+                                bool is_signed, void *allocated, void *aligned,
+                                size_t offset, size_t *sizes, size_t *strides);
 
 } // namespace clientlib
 } // namespace concretelang
diff --git a/compiler/include/concretelang/ClientLib/Serializers.h b/compiler/include/concretelang/ClientLib/Serializers.h
index d597d9512..0986ec773 100644
--- a/compiler/include/concretelang/ClientLib/Serializers.h
+++ b/compiler/include/concretelang/ClientLib/Serializers.h
@@ -7,6 +7,7 @@
 #define CONCRETELANG_CLIENTLIB_SERIALIZERS_ARGUMENTS_H
 
 #include <iostream>
+#include <limits>
 
 #include "concretelang/ClientLib/ClientParameters.h"
 #include "concretelang/ClientLib/EvaluationKeys.h"
@@ -40,6 +41,14 @@ std::istream &readWord(std::istream &istream, Word &word) {
   return istream;
 }
 
+template <typename Word>
+std::istream &readWords(std::istream &istream, Word *words, size_t numWords) {
+  assert(std::numeric_limits<size_t>::max() / sizeof(*words) > numWords);
+  istream.read(reinterpret_cast<char *>(words), sizeof(*words) * numWords);
+
+  assert(istream.good());
+  return istream;
+}
+
 template <typename Size>
 std::istream &readSize(std::istream &istream, Size &size) {
   return readWord(istream, size);
 }
@@ -57,13 +66,29 @@
 std::ostream &operator<<(std::ostream &ostream,
                          const RuntimeContext &runtimeContext);
 std::istream &operator>>(std::istream &istream,
                          RuntimeContext &runtimeContext);
 
-std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
+std::ostream &serializeTensorData(const TensorData &values_and_sizes,
                                   std::ostream &ostream);
-std::ostream &serializeTensorData(TensorData &values_and_sizes,
-                                  std::ostream &ostream);
+template <typename T>
+std::ostream &serializeTensorDataRaw(const llvm::ArrayRef<size_t> &dimensions,
+                                     const llvm::ArrayRef<T> &values,
+                                     std::ostream &ostream) {
 
-TensorData unserializeTensorData(
+  writeWord<uint64_t>(ostream, dimensions.size());
+
+  for (size_t dim : dimensions)
+    writeWord<int64_t>(ostream, dim);
+
+  writeWord<uint64_t>(ostream, sizeof(T) * 8);
+  writeWord<uint8_t>(ostream, std::is_signed<T>());
+
+  for (T val : values)
+    writeWord(ostream, val);
+
+  return ostream;
+}
+
+outcome::checked<TensorData, StringError> unserializeTensorData(
     std::vector<int64_t> &expectedSizes, // includes unsigned to
                                          // accomodate non static sizes
     std::istream &istream);
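For clarity, the wire layout `serializeTensorDataRaw` produces is: the number of dimensions, each dimension size, the element width in bits, a signedness flag, then the raw element values. The following standalone sketch writes that layout by hand; the `put` helper and the exact integer widths chosen for each field are our reconstruction from the reader in Serializers.cpp, not verified library code:

```cpp
// Standalone illustration of the tensor serialization layout: rank,
// dimensions, element width, signedness flag, then raw values.
#include <cstdint>
#include <iostream>
#include <sstream>
#include <vector>

template <typename Word> static void put(std::ostream &os, Word w) {
  // Write the in-memory representation of `w`, like `writeWord` does.
  os.write(reinterpret_cast<const char *>(&w), sizeof(w));
}

int main() {
  std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);

  std::vector<uint8_t> values{0, 1, 2, 3}; // a 2x2 u8 tensor
  std::vector<int64_t> dims{2, 2};

  put<uint64_t>(ss, dims.size());      // number of dimensions
  for (int64_t d : dims)
    put<int64_t>(ss, d);               // each dimension size
  put<uint64_t>(ss, sizeof(uint8_t) * 8); // element width: 8 bits
  put<uint8_t>(ss, 0);                 // signedness flag: unsigned
  for (uint8_t v : values)
    put(ss, v);                        // raw element data

  std::cout << "serialized " << ss.str().size() << " bytes\n"; // 8+16+8+1+4
}
```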
diff --git a/compiler/include/concretelang/ClientLib/Types.h b/compiler/include/concretelang/ClientLib/Types.h
index e2b5de0fc..1bde34638 100644
--- a/compiler/include/concretelang/ClientLib/Types.h
+++ b/compiler/include/concretelang/ClientLib/Types.h
@@ -6,6 +6,8 @@
 #ifndef CONCRETELANG_CLIENTLIB_TYPES_H_
 #define CONCRETELANG_CLIENTLIB_TYPES_H_
 
+#include "llvm/ADT/ArrayRef.h"
+
 #include <cassert>
 #include <cstdint>
 #include <vector>
@@ -30,24 +32,630 @@ template <size_t N> using encrypted_tensor_t = MemRefDescriptor<N>;
 
 using encrypted_scalar_t = uint64_t *;
 using encrypted_scalars_t = uint64_t *;
 
-struct TensorData {
-  std::vector<uint64_t> values; // tensor of rank r + 1
-  std::vector<size_t> sizes;    // r sizes
+// Element types for `TensorData`
+enum class ElementType { u64, i64, u32, i32, u16, i16, u8, i8 };
 
-  inline size_t length() {
-    if (sizes.empty()) {
-      return 0;
-    }
-    size_t len = 1;
-    for (auto size : sizes) {
-      len *= size;
-    }
-    return len;
-  }
+namespace {
+// Returns the number of bits for an element type
+static constexpr size_t getElementTypeWidth(ElementType t) {
+  switch (t) {
+  case ElementType::u64:
+  case ElementType::i64:
+    return 64;
+  case ElementType::u32:
+  case ElementType::i32:
+    return 32;
+  case ElementType::u16:
+  case ElementType::i16:
+    return 16;
+  case ElementType::u8:
+  case ElementType::i8:
+    return 8;
+  }
+}
+} // namespace
 
-  inline size_t lweSize() { return sizes.back(); }
-};
+// Constants for the element types used for tensors representing
+// encrypted data and data after decryption
+constexpr ElementType EncryptedScalarElementType = ElementType::u64;
+constexpr size_t EncryptedScalarElementWidth =
+    getElementTypeWidth(ElementType::u64);
+
+using EncryptedScalarElement = uint64_t;
+
+namespace detail {
+namespace TensorData {
+
+// Union used to store the pointer to the actual data of an instance
+// of `TensorData`. Values are stored contiguously in memory in a
+// `std::vector` whose element type corresponds to the element type of
+// the tensor.
+union value_vector_union {
+  std::vector<uint64_t> *u64;
+  std::vector<int64_t> *i64;
+  std::vector<uint32_t> *u32;
+  std::vector<int32_t> *i32;
+  std::vector<uint16_t> *u16;
+  std::vector<int16_t> *i16;
+  std::vector<uint8_t> *u8;
+  std::vector<int8_t> *i8;
+};
+
+// Function templates that would go into the class `TensorData`, but
+// which need to be declared in namespace scope, since specializations
+// of templates on the return type cannot be done for member functions
+// as per the C++ standard
+template <typename T> T begin(union value_vector_union &vec);
+template <typename T> T end(union value_vector_union &vec);
+template <typename T> T cbegin(union value_vector_union &vec);
+template <typename T> T cend(union value_vector_union &vec);
+template <typename T> T getElements(union value_vector_union &vec);
+template <typename T> T getConstElements(const union value_vector_union &vec);
+
+template <typename T>
+T getElementValue(union value_vector_union &vec, size_t idx,
+                  ElementType elementType);
+template <typename T>
+T &getElementReference(union value_vector_union &vec, size_t idx,
+                       ElementType elementType);
+template <typename T>
+T *getElementPointer(union value_vector_union &vec, size_t idx,
+                     ElementType elementType);
+
+// Specializations for the above templates
+#define TENSORDATA_SPECIALIZE_FOR_ITERATOR(ELTY, SUFFIX)                       \
+  template <>                                                                  \
+  inline std::vector<ELTY>::iterator begin(union value_vector_union &vec) {    \
+    return vec.SUFFIX->begin();                                                \
+  }                                                                            \
+                                                                               \
+  template <>                                                                  \
+  inline std::vector<ELTY>::iterator end(union value_vector_union &vec) {      \
+    return vec.SUFFIX->end();                                                  \
+  }                                                                            \
+                                                                               \
+  template <>                                                                  \
+  inline std::vector<ELTY>::const_iterator cbegin(                             \
+      union value_vector_union &vec) {                                         \
+    return vec.SUFFIX->cbegin();                                               \
+  }                                                                            \
+                                                                               \
+  template <>                                                                  \
+  inline std::vector<ELTY>::const_iterator cend(                               \
+      union value_vector_union &vec) {                                         \
+    return vec.SUFFIX->cend();                                                 \
+  }                                                                            \
+                                                                               \
+  template <>                                                                  \
+  inline std::vector<ELTY> &getElements(union value_vector_union &vec) {       \
+    return *vec.SUFFIX;                                                        \
+  }                                                                            \
+                                                                               \
+  template <>                                                                  \
+  inline const std::vector<ELTY> &getConstElements(                            \
+      const union value_vector_union &vec) {                                   \
+    return *vec.SUFFIX;                                                        \
+  }
+
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint64_t, u64)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(int64_t, i64)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint32_t, u32)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(int32_t, i32)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint16_t, u16)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(int16_t, i16)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(uint8_t, u8)
+TENSORDATA_SPECIALIZE_FOR_ITERATOR(int8_t, i8)
+
+#define TENSORDATA_SPECIALIZE_VALUE_GETTER(ELTY, SUFFIX)                       \
+  template <>                                                                  \
+  inline ELTY getElementValue(union value_vector_union &vec, size_t idx,      \
+                              ElementType elementType) {                      \
+    assert(elementType == ElementType::SUFFIX);                               \
+    return (*vec.SUFFIX)[idx];                                                \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  inline ELTY &getElementReference(union value_vector_union &vec, size_t idx, \
+                                   ElementType elementType) {                 \
+    assert(elementType == ElementType::SUFFIX);                               \
+    return (*vec.SUFFIX)[idx];                                                \
+  }                                                                           \
+                                                                              \
+  template <>                                                                 \
+  inline ELTY *getElementPointer(union value_vector_union &vec, size_t idx,   \
+                                 ElementType elementType) {                   \
+    assert(elementType == ElementType::SUFFIX);                               \
+    return &(*vec.SUFFIX)[idx];                                               \
+  }
+
+TENSORDATA_SPECIALIZE_VALUE_GETTER(uint64_t, u64)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(int64_t, i64)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(uint32_t, u32)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(int32_t, i32)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(uint16_t, u16)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(int16_t, i16)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(uint8_t, u8)
+TENSORDATA_SPECIALIZE_VALUE_GETTER(int8_t, i8)
+
+} // namespace TensorData
+} // namespace detail
+
+// Representation of a tensor with an arbitrary number of dimensions
+class TensorData {
+protected:
+  detail::TensorData::value_vector_union values;
+  ElementType elementType;
+  std::vector<size_t> dimensions;
+
+  /* Multi-dimensional, uninitialized, but preallocated tensor */
+  void initPreallocated(llvm::ArrayRef<size_t> dimensions,
+                        ElementType elementType) {
+    assert(dimensions.size() != 0);
+    this->dimensions.resize(dimensions.size());
+
+    size_t n = getNumElements(dimensions);
+
+    switch (elementType) {
+    case ElementType::u64:
+      this->values.u64 = new std::vector<uint64_t>(n);
+      break;
+    case ElementType::i64:
+      this->values.i64 = new std::vector<int64_t>(n);
+      break;
+    case ElementType::u32:
+      this->values.u32 = new std::vector<uint32_t>(n);
+      break;
+    case ElementType::i32:
+      this->values.i32 = new std::vector<int32_t>(n);
+      break;
+    case ElementType::u16:
+      this->values.u16 = new std::vector<uint16_t>(n);
+      break;
+    case ElementType::i16:
+      this->values.i16 = new std::vector<int16_t>(n);
+      break;
+    case ElementType::u8:
+      this->values.u8 = new std::vector<uint8_t>(n);
+      break;
+    case ElementType::i8:
+      this->values.i8 = new std::vector<int8_t>(n);
+      break;
+    }
+    this->elementType = elementType;
+    std::copy(dimensions.begin(), dimensions.end(), this->dimensions.begin());
+  }
+
+  // Creates a vector from an ArrayRef
+  template <typename T>
+  static std::vector<size_t> toDimSpec(llvm::ArrayRef<T> dims) {
+    return std::vector<size_t>(dims.begin(), dims.end());
+  }
+
+public:
+  // Returns the total number of elements of a tensor with the
+  // specified dimensions
+  template <typename T> static size_t getNumElements(T dimensions) {
+    size_t n = 1;
+    for (auto dim : dimensions)
+      n *= dim;
+
+    return n;
+  }
+
+  // Returns the number of bits of an integer capable of storing
+  // values with up to `elementWidth` bits.
+  static size_t storageWidth(size_t elementWidth) {
+    if (elementWidth > 64) {
+      assert(false && "Maximum supported element width is 64");
+    } else if (elementWidth > 32) {
+      return 64;
+    } else if (elementWidth > 16) {
+      return 32;
+    } else if (elementWidth > 8) {
+      return 16;
+    } else {
+      return 8;
+    }
+  }
+
+  // Move constructor. Leaves `that` uninitialized.
+  TensorData(TensorData &&that)
+      : elementType(that.elementType), dimensions(std::move(that.dimensions)) {
+    switch (that.elementType) {
+    case ElementType::u64:
+      this->values.u64 = that.values.u64;
+      that.values.u64 = nullptr;
+      break;
+    case ElementType::i64:
+      this->values.i64 = that.values.i64;
+      that.values.i64 = nullptr;
+      break;
+    case ElementType::u32:
+      this->values.u32 = that.values.u32;
+      that.values.u32 = nullptr;
+      break;
+    case ElementType::i32:
+      this->values.i32 = that.values.i32;
+      that.values.i32 = nullptr;
+      break;
+    case ElementType::u16:
+      this->values.u16 = that.values.u16;
+      that.values.u16 = nullptr;
+      break;
+    case ElementType::i16:
+      this->values.i16 = that.values.i16;
+      that.values.i16 = nullptr;
+      break;
+    case ElementType::u8:
+      this->values.u8 = that.values.u8;
+      that.values.u8 = nullptr;
+      break;
+    case ElementType::i8:
+      this->values.i8 = that.values.i8;
+      that.values.i8 = nullptr;
+      break;
+    }
+  }
+
+  // Constructor to build a multi-dimensional tensor with the
+  // corresponding element type. All elements are initialized with the
+  // default value of `0`.
+  TensorData(llvm::ArrayRef<size_t> dimensions, ElementType elementType) {
+    initPreallocated(dimensions, elementType);
+  }
+
+  TensorData(llvm::ArrayRef<int64_t> dimensions, ElementType elementType)
+      : TensorData(toDimSpec(dimensions), elementType) {}
+
+  // Constructor to build a multi-dimensional tensor with the element
+  // type corresponding to `elementWidth` and `sign`. The value for
+  // `elementWidth` must be a power of 2 of up to 64. All elements are
+  // initialized with the default value of `0`.
+  TensorData(llvm::ArrayRef<size_t> dimensions, size_t elementWidth,
+             bool sign) {
+    switch (elementWidth) {
+    case 64:
+      initPreallocated(dimensions,
+                       (sign) ? ElementType::i64 : ElementType::u64);
+      break;
+    case 32:
+      initPreallocated(dimensions,
+                       (sign) ? ElementType::i32 : ElementType::u32);
+      break;
+    case 16:
+      initPreallocated(dimensions,
+                       (sign) ? ElementType::i16 : ElementType::u16);
+      break;
+    case 8:
+      initPreallocated(dimensions, (sign) ? ElementType::i8 : ElementType::u8);
+      break;
+    default:
+      assert(false && "Element width must be 64, 32, 16 or 8 bits");
+    }
+  }
+
+  TensorData(llvm::ArrayRef<int64_t> dimensions, size_t elementWidth,
+             bool sign)
+      : TensorData(toDimSpec(dimensions), elementWidth, sign) {}
+
+#define DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(ELTY, SUFFIX)                       \
+  /* Multi-dimensional, initialized tensor, values copied from */              \
+  /* `values` */                                                               \
+  TensorData(llvm::ArrayRef<ELTY> values, llvm::ArrayRef<size_t> dimensions)   \
+      : dimensions(dimensions.begin(), dimensions.end()) {                     \
+    assert(dimensions.size() != 0);                                            \
+    size_t n = getNumElements(dimensions);                                     \
+    this->values.SUFFIX = new std::vector<ELTY>(n);                            \
+    this->elementType = ElementType::SUFFIX;                                   \
+    this->bulkAssign(values);                                                  \
+  }                                                                            \
+                                                                               \
+  /* One-dimensional, initialized tensor. Values are copied from */            \
+  /* `values` */                                                               \
+  TensorData(llvm::ArrayRef<ELTY> values)                                      \
+      : TensorData(values, llvm::SmallVector<size_t, 1>{values.size()}) {}
+
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint64_t, u64)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int64_t, i64)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint32_t, u32)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int32_t, i32)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint16_t, u16)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int16_t, i16)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(uint8_t, u8)
+  DEF_TENSOR_DATA_TENSOR_COSTRUCTORS(int8_t, i8)
+
+  ~TensorData() {
+    switch (this->elementType) {
+    case ElementType::u64:
+      delete values.u64;
+      break;
+    case ElementType::i64:
+      delete values.i64;
+      break;
+    case ElementType::u32:
+      delete values.u32;
+      break;
+    case ElementType::i32:
+      delete values.i32;
+      break;
+    case ElementType::u16:
+      delete values.u16;
+      break;
+    case ElementType::i16:
+      delete values.i16;
+      break;
+    case ElementType::u8:
+      delete values.u8;
+      break;
+    case ElementType::i8:
+      delete values.i8;
+      break;
+    }
+  }
+
+  // Returns the total number of elements of the tensor
+  size_t length() const { return getNumElements(this->dimensions); }
+
+  // Returns a vector with the size for each dimension of the tensor
+  const std::vector<size_t> &getDimensions() const { return this->dimensions; }
+
+  // Returns the number of dimensions
+  size_t getRank() const { return this->dimensions.size(); }
+
+  // Multi-dimensional access to a tensor element
+  template <typename T> T &operator[](llvm::ArrayRef<size_t> index) {
+    // Number of dimensions must match
+    assert(index.size() == dimensions.size());
+
+    int64_t offset = 0;
+    int64_t multiplier = 1;
+    for (int64_t i = index.size() - 1; i >= 0; i--) {
+      offset += index[i] * multiplier;
+      multiplier *= this->dimensions[i];
+    }
+
+    return detail::TensorData::getElementReference<T>(values, offset,
+                                                      elementType);
+  }
+
+  // Iterator pointing to the first element of a flat representation
+  // of the tensor.
+  template <typename T> typename std::vector<T>::iterator begin() {
+    return detail::TensorData::begin<typename std::vector<T>::iterator>(values);
+  }
+
+  // Iterator pointing past the last element of a flat representation
+  // of the tensor.
+  template <typename T> typename std::vector<T>::iterator end() {
+    return detail::TensorData::end<typename std::vector<T>::iterator>(values);
+  }
+
+  // Const iterator pointing to the first element of a flat
+  // representation of the tensor.
+  template <typename T> typename std::vector<T>::const_iterator cbegin() {
+    return detail::TensorData::cbegin<typename std::vector<T>::const_iterator>(
+        values);
+  }
+
+  // Const iterator pointing past the last element of a flat
+  // representation of the tensor.
+  template <typename T> typename std::vector<T>::const_iterator cend() {
+    return detail::TensorData::cend<typename std::vector<T>::const_iterator>(
+        values);
+  }
+
+  // Flat representation of the const tensor
+  template <typename T> const std::vector<T> &getElements() const {
+    return detail::TensorData::getConstElements<const std::vector<T> &>(values);
+  }
+
+  // Flat representation of the tensor
+  template <typename T> std::vector<T> &getElements() {
+    return detail::TensorData::getElements<std::vector<T> &>(values);
+  }
+
+  // Returns the `index`-th value of a flat representation of the tensor
+  template <typename T> T getElementValue(size_t index) {
+    return detail::TensorData::getElementValue<T>(values, index, elementType);
+  }
+
+  // Returns a reference to the `index`-th value of a flat
+  // representation of the tensor
+  template <typename T> T &getElementReference(size_t index) {
+    return detail::TensorData::getElementReference<T>(values, index,
+                                                      elementType);
+  }
+
+  // Returns a pointer to the `index`-th value of a flat
+  // representation of the tensor
+  template <typename T> T *getElementPointer(size_t index) {
+    return detail::TensorData::getElementPointer<T>(values, index, elementType);
+  }
+
+  // Returns a void pointer to the `index`-th value of a flat
+  // representation of the tensor
+  void *getOpaqueElementPointer(size_t index) {
+    switch (this->elementType) {
+    case ElementType::u64:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<uint64_t>(values, index,
+                                                          elementType));
+    case ElementType::i64:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<int64_t>(values, index,
+                                                         elementType));
+    case ElementType::u32:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<uint32_t>(values, index,
+                                                          elementType));
+    case ElementType::i32:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<int32_t>(values, index,
+                                                         elementType));
+    case ElementType::u16:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<uint16_t>(values, index,
+                                                          elementType));
+    case ElementType::i16:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<int16_t>(values, index,
+                                                         elementType));
+    case ElementType::u8:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<uint8_t>(values, index,
+                                                         elementType));
+    case ElementType::i8:
+      return reinterpret_cast<void *>(
+          detail::TensorData::getElementPointer<int8_t>(values, index,
+                                                        elementType));
+    }
+
+    assert(false && "Unknown element type");
+  }
+
+  // Returns the element type of the tensor
+  ElementType getElementType() const { return this->elementType; }
+
+  // Returns the size of a tensor element in bytes
+  size_t getElementSize() const {
+    switch (this->elementType) {
+    case ElementType::u64:
+    case ElementType::i64:
+      return 8;
+    case ElementType::u32:
+    case ElementType::i32:
+      return 4;
+    case ElementType::u16:
+    case ElementType::i16:
+      return 2;
+    case ElementType::u8:
+    case ElementType::i8:
+      return 1;
+    }
+  }
+
+  // Returns `true` if elements are signed, otherwise `false`
+  bool getElementSignedness() const {
+    switch (this->elementType) {
+    case ElementType::u64:
+    case ElementType::u32:
+    case ElementType::u16:
+    case ElementType::u8:
+      return false;
+    case ElementType::i64:
+    case ElementType::i32:
+    case ElementType::i16:
+    case ElementType::i8:
+      return true;
+    }
+  }
+
+  // Returns the width of an element in bits
+  size_t getElementWidth() const {
+    return getElementTypeWidth(this->elementType);
+  }
+
+  // Returns the total number of elements of the tensor
+  size_t getNumElements() const { return getNumElements(this->dimensions); }
+
+  // Copy all elements from `values` to the tensor. Note that this
+  // does not append values to the tensor, but overwrites existing
+  // values.
+  template <typename T> void bulkAssign(llvm::ArrayRef<T> values) {
+    assert(values.size() <= this->getNumElements());
+
+    switch (this->elementType) {
+    case ElementType::u64:
+      std::copy(values.begin(), values.end(), this->values.u64->begin());
+      break;
+    case ElementType::i64:
+      std::copy(values.begin(), values.end(), this->values.i64->begin());
+      break;
+    case ElementType::u32:
+      std::copy(values.begin(), values.end(), this->values.u32->begin());
+      break;
+    case ElementType::i32:
+      std::copy(values.begin(), values.end(), this->values.i32->begin());
+      break;
+    case ElementType::u16:
+      std::copy(values.begin(), values.end(), this->values.u16->begin());
+      break;
+    case ElementType::i16:
+      std::copy(values.begin(), values.end(), this->values.i16->begin());
+      break;
+    case ElementType::u8:
+      std::copy(values.begin(), values.end(), this->values.u8->begin());
+      break;
+    case ElementType::i8:
+      std::copy(values.begin(), values.end(), this->values.i8->begin());
+      break;
+    }
+  }
+
+  // Copies all elements of a flat representation of the tensor to the
+  // positions starting with the iterator `start`.
+  template <typename IT> void copy(IT start) {
+    switch (this->elementType) {
+    case ElementType::u64:
+      std::copy(this->values.u64->begin(), this->values.u64->end(), start);
+      break;
+    case ElementType::i64:
+      std::copy(this->values.i64->begin(), this->values.i64->end(), start);
+      break;
+    case ElementType::u32:
+      std::copy(this->values.u32->begin(), this->values.u32->end(), start);
+      break;
+    case ElementType::i32:
+      std::copy(this->values.i32->begin(), this->values.i32->end(), start);
+      break;
+    case ElementType::u16:
+      std::copy(this->values.u16->begin(), this->values.u16->end(), start);
+      break;
+    case ElementType::i16:
+      std::copy(this->values.i16->begin(), this->values.i16->end(), start);
+      break;
+    case ElementType::u8:
+      std::copy(this->values.u8->begin(), this->values.u8->end(), start);
+      break;
+    case ElementType::i8:
+      std::copy(this->values.i8->begin(), this->values.i8->end(), start);
+      break;
+    }
+  }
+
+  // Returns a flat representation of the tensor with elements
+  // converted to the type `T`
+  template <typename T> std::vector<T> asFlatVector() {
+    std::vector<T> ret(getNumElements());
+    this->copy(ret.begin());
+    return ret;
+  }
+
+  // Returns a void pointer to the first element of a flat
+  // representation of the tensor
+  void *getValuesAsOpaquePointer() {
+    switch (this->elementType) {
+    case ElementType::u64:
+      return static_cast<void *>(values.u64->data());
+    case ElementType::i64:
+      return static_cast<void *>(values.i64->data());
+    case ElementType::u32:
+      return static_cast<void *>(values.u32->data());
+    case ElementType::i32:
+      return static_cast<void *>(values.i32->data());
+    case ElementType::u16:
+      return static_cast<void *>(values.u16->data());
+    case ElementType::i16:
+      return static_cast<void *>(values.i16->data());
+    case ElementType::u8:
+      return static_cast<void *>(values.u8->data());
+    case ElementType::i8:
+      return static_cast<void *>(values.i8->data());
+    }
+
+    assert(false && "Unhandled element type");
+  }
+};
 
 } // namespace clientlib
 } // namespace concretelang
+
 #endif
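The class above is move-only and type-erases its element storage behind a union of typed vectors. A short usage sketch, inferred from the declarations in this diff (include path and namespaces as they appear in the headers above; this is not an excerpt from the library's tests):

```cpp
// Usage sketch for the new TensorData class.
#include <cassert>
#include <cstdint>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "concretelang/ClientLib/Types.h"

using concretelang::clientlib::TensorData;

int main() {
  // 2x3 tensor of unsigned 8-bit elements, zero-initialized
  TensorData td(llvm::ArrayRef<size_t>{2, 3}, /*elementWidth=*/8,
                /*sign=*/false);

  // Fill from a flat buffer; bulkAssign overwrites, it does not append
  uint8_t raw[] = {1, 2, 3, 4, 5, 6};
  td.bulkAssign(llvm::ArrayRef<uint8_t>(raw, 6));

  assert(td.getRank() == 2);
  assert(td.getNumElements() == 6);
  assert(td.getElementValue<uint8_t>(4) == 5); // flat index 4 -> row 1, col 1

  // Widening copy out: each u8 element is converted to uint64_t
  std::vector<uint64_t> flat = td.asFlatVector<uint64_t>();
  assert(flat[5] == 6);
}
```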
diff --git a/compiler/include/concretelang/ServerLib/DynamicRankCall.h b/compiler/include/concretelang/ServerLib/DynamicRankCall.h
index bd72309e8..08c8fb8a0 100644
--- a/compiler/include/concretelang/ServerLib/DynamicRankCall.h
+++ b/compiler/include/concretelang/ServerLib/DynamicRankCall.h
@@ -16,7 +16,8 @@ namespace serverlib {
 
 using concretelang::clientlib::TensorData;
 
 TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
-                                         std::vector<void *> args, size_t rank);
+                                         std::vector<void *> args, size_t rank,
+                                         size_t element_width, bool is_signed);
 
 } // namespace serverlib
 } // namespace concretelang
diff --git a/compiler/lib/ClientLib/EncryptedArguments.cpp b/compiler/lib/ClientLib/EncryptedArguments.cpp
index c193c7085..f1d49d51d 100644
--- a/compiler/lib/ClientLib/EncryptedArguments.cpp
+++ b/compiler/lib/ClientLib/EncryptedArguments.cpp
@@ -31,32 +31,34 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
     preparedArgs.push_back((void *)arg);
     return outcome::success();
   }
+
+  std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);
+
   // Allocate empty
-  ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
+  ciphertextBuffers.emplace_back(shape, clientlib::EncryptedScalarElementType);
   TensorData &values_and_sizes = ciphertextBuffers.back();
-  values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
-  values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
-  OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
+
+  OUTCOME_TRYV(keySet.encrypt_lwe(
+      pos, values_and_sizes.getElementPointer<uint64_t>(0), arg));
+
   // Note: Since we bufferized lwe ciphertext take care of memref calling
   // convention
   // allocated
   preparedArgs.push_back(nullptr);
   // aligned
-  preparedArgs.push_back((void *)values_and_sizes.values.data());
+  preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
   // offset
   preparedArgs.push_back((void *)0);
   // sizes
-  for (auto size : values_and_sizes.sizes) {
+  for (auto size : values_and_sizes.getDimensions()) {
     preparedArgs.push_back((void *)size);
   }
   // strides
-  int64_t stride = values_and_sizes.length();
-  for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
-    auto size = values_and_sizes.sizes[i];
+  int64_t stride = TensorData::getNumElements(shape);
+  for (size_t size : values_and_sizes.getDimensions()) {
     stride = (size == 0 ? 0 : (stride / size));
     preparedArgs.push_back((void *)stride);
   }
-  preparedArgs.push_back((void *)1);
+
   return outcome::success();
 }
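The allocated/aligned/offset/sizes/strides pushes above follow MLIR's ranked-memref calling convention: the callee conceptually receives the fields of a descriptor like the illustrative one below, passed here as individual `void *` entries rather than as a struct. The struct is our sketch, not a definition from this codebase:

```cpp
// Illustration of the ranked memref descriptor that preparedArgs
// encodes field by field.
#include <cstddef>
#include <cstdint>

template <size_t N> struct RankedMemRef {
  uint64_t *allocated; // owning pointer (nullptr above: callee must not free)
  uint64_t *aligned;   // base pointer the callee actually reads from
  size_t offset;       // element offset applied to `aligned`
  size_t sizes[N];     // extent of each dimension
  size_t strides[N];   // row-major: strides[i] = product of sizes[i+1..N-1]
};
```

The stride loop above computes exactly that row-major product, by starting from the total element count and repeatedly dividing by each dimension's size.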
diff --git a/compiler/lib/ClientLib/PublicArguments.cpp b/compiler/lib/ClientLib/PublicArguments.cpp
index 9bab5e433..8777f0ce6 100644
--- a/compiler/lib/ClientLib/PublicArguments.cpp
+++ b/compiler/lib/ClientLib/PublicArguments.cpp
@@ -41,11 +41,12 @@ PublicArguments::serialize(std::ostream &ostream) {
                          "are not yet supported. Argument ")
              << iGate;
     }
+
     /*auto allocated = */ preparedArgs[iPreparedArgs++];
     auto aligned = (encrypted_scalars_t)preparedArgs[iPreparedArgs++];
     assert(aligned != nullptr);
     auto offset = (size_t)preparedArgs[iPreparedArgs++];
-    std::vector<int64_t> sizes; // includes lweSize as last dim
+    std::vector<size_t> sizes; // includes lweSize as last dim
     sizes.resize(rank + 1);
     for (auto dim = 0u; dim < sizes.size(); dim++) {
       // sizes are part of the client parameters signature
@@ -60,8 +61,13 @@ PublicArguments::serialize(std::ostream &ostream) {
     }
     // TODO: STRIDES
     auto values = aligned + offset;
-    serializeTensorData(sizes, values, ostream);
+
+    serializeTensorDataRaw(sizes,
+                           llvm::ArrayRef<uint64_t>{
+                               values, TensorData::getNumElements(sizes)},
+                           ostream);
   }
+
   return outcome::success();
 }
 
@@ -76,18 +82,24 @@ PublicArguments::unserializeArgs(std::istream &istream) {
     auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
     std::vector<int64_t> sizes = gate.shape.dimensions;
     sizes.push_back(lweSize);
-    ciphertextBuffers.push_back(unserializeTensorData(sizes, istream));
+    auto tdOrErr = unserializeTensorData(sizes, istream);
+
+    if (tdOrErr.has_error())
+      return tdOrErr.error();
+
+    ciphertextBuffers.push_back(std::move(tdOrErr.value()));
     auto &values_and_sizes = ciphertextBuffers.back();
+
     if (istream.fail()) {
       return StringError(
                  "PublicArguments::unserializeArgs: Failed to read argument ")
             << iGate;
     }
     preparedArgs.push_back(/*allocated*/ nullptr);
-    preparedArgs.push_back((void *)values_and_sizes.values.data());
+    preparedArgs.push_back(values_and_sizes.getValuesAsOpaquePointer());
    preparedArgs.push_back(/*offset*/ 0);
     // sizes
-    for (auto size : values_and_sizes.sizes) {
+    for (auto size : values_and_sizes.getDimensions()) {
       preparedArgs.push_back((void *)size);
     }
     // strides has been removed by serialization
@@ -120,10 +132,12 @@ PublicResult::unserialize(std::istream &istream) {
     auto lweSize = clientParameters.lweSecretKeyParam(gate).value().lweSize();
     std::vector<int64_t> sizes = gate.shape.dimensions;
     sizes.push_back(lweSize);
-    buffers.push_back(unserializeTensorData(sizes, istream));
-    if (istream.fail()) {
-      return StringError("Cannot read tensor data");
-    }
+    auto tdOrErr = unserializeTensorData(sizes, istream);
+
+    if (tdOrErr.has_error())
+      return tdOrErr.error();
+
+    buffers.push_back(std::move(tdOrErr.value()));
   }
   return outcome::success();
 }
@@ -134,7 +148,7 @@ PublicResult::serialize(std::ostream &ostream) {
     return StringError(
         "PublicResult::serialize: ostream should be in binary mode");
   }
-  for (auto tensorData : buffers) {
+  for (const TensorData &tensorData : buffers) {
     serializeTensorData(tensorData, ostream);
     if (ostream.fail()) {
       return StringError("Cannot write tensor data");
@@ -166,40 +180,85 @@ size_t global_index(size_t index[], size_t sizes[], size_t strides[],
   return g_index;
 }
 
-TensorData tensorDataFromScalar(uint64_t value) { return {{value}, {1}}; }
+TensorData tensorDataFromScalar(uint64_t value) {
+  return TensorData{llvm::ArrayRef<uint64_t>{value}, {1}};
+}
 
-TensorData tensorDataFromMemRef(size_t memref_rank,
-                                encrypted_scalars_t allocated,
-                                encrypted_scalars_t aligned, size_t offset,
-                                size_t *sizes, size_t *strides) {
-  TensorData result;
+static inline bool isReferenceToMLIRGlobalMemory(void *ptr) {
+  return reinterpret_cast<uintptr_t>(ptr) == 0xdeadbeef;
+}
+
+template <typename T>
+TensorData tensorDataFromMemRefTyped(size_t memref_rank, void *allocatedVoid,
+                                     void *alignedVoid, size_t offset,
+                                     size_t *sizes, size_t *strides) {
+  T *allocated = reinterpret_cast<T *>(allocatedVoid);
+  T *aligned = reinterpret_cast<T *>(alignedVoid);
+
+  // FIXME: handle sign correctly
+  TensorData result(llvm::ArrayRef<size_t>{sizes, memref_rank}, sizeof(T) * 8,
+                    false);
 
   assert(aligned != nullptr);
-  result.sizes.resize(memref_rank);
-  for (size_t r = 0; r < memref_rank; r++) {
-    result.sizes[r] = sizes[r];
-  }
+
   // ephemeral multi dim index to compute global strides
   size_t *index = new size_t[memref_rank];
   for (size_t r = 0; r < memref_rank; r++) {
     index[r] = 0;
   }
   auto len = result.length();
-  result.values.resize(len);
+
   // TODO: add a fast path for dense result (no real strides)
   for (size_t i = 0; i < len; i++) {
     int g_index = offset + global_index(index, sizes, strides, memref_rank);
-    result.values[i] = aligned[g_index];
+    result.getElementReference<T>(i) = aligned[g_index];
     next_coord_index(index, sizes, memref_rank);
   }
   delete[] index;
   // TEMPORARY: That quick and dirty but as this function is used only to
   // convert a result of the mlir program and as data are copied here, we
   // release the alocated pointer if it set.
-  if (allocated != nullptr) {
+
+  if (allocated != nullptr && !isReferenceToMLIRGlobalMemory(allocated)) {
     free(allocated);
   }
+
   return result;
 }
 
+TensorData tensorDataFromMemRef(size_t memref_rank, size_t element_width,
+                                bool is_signed, void *allocated, void *aligned,
+                                size_t offset, size_t *sizes, size_t *strides) {
+  size_t storage_width = TensorData::storageWidth(element_width);
+
+  switch (storage_width) {
+  case 64:
+    return (is_signed)
+               ? std::move(tensorDataFromMemRefTyped<int64_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides))
+               : std::move(tensorDataFromMemRefTyped<uint64_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides));
+  case 32:
+    return (is_signed)
+               ? std::move(tensorDataFromMemRefTyped<int32_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides))
+               : std::move(tensorDataFromMemRefTyped<uint32_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides));
+  case 16:
+    return (is_signed)
+               ? std::move(tensorDataFromMemRefTyped<int16_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides))
+               : std::move(tensorDataFromMemRefTyped<uint16_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides));
+  case 8:
+    return (is_signed)
+               ? std::move(tensorDataFromMemRefTyped<int8_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides))
+               : std::move(tensorDataFromMemRefTyped<uint8_t>(
+                     memref_rank, allocated, aligned, offset, sizes, strides));
+  default:
+    assert(false);
+  }
+}
+
 } // namespace clientlib
 } // namespace concretelang
diff --git a/compiler/lib/ClientLib/Serializers.cpp b/compiler/lib/ClientLib/Serializers.cpp
index f57b03266..b90e65c17 100644
--- a/compiler/lib/ClientLib/Serializers.cpp
+++ b/compiler/lib/ClientLib/Serializers.cpp
@@ -149,70 +149,135 @@ std::ostream &operator<<(std::ostream &ostream,
   return ostream;
 }
 
-std::ostream &serializeTensorData(uint64_t *values, size_t length,
-                                  std::ostream &ostream) {
-  if (incorrectMode(ostream)) {
-    return ostream;
-  }
-  writeSize(ostream, length);
-  for (size_t i = 0; i < length; i++) {
-    writeWord(ostream, values[i]);
-  }
-  return ostream;
+template <typename T>
+static std::istream &
+unserializeTensorDataElements(TensorData &values_and_sizes,
                              std::istream &istream) {
+  readWords(istream, values_and_sizes.getElementPointer<T>(0),
+            values_and_sizes.getNumElements());
+
+  return istream;
 }
 
-std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
+std::ostream &serializeTensorData(const TensorData &values_and_sizes,
                                   std::ostream &ostream) {
-  size_t length = 1;
-  for (auto size : sizes) {
-    length *= size;
-    writeSize(ostream, size);
+  switch (values_and_sizes.getElementType()) {
+  case ElementType::u64:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<uint64_t>(), ostream);
+  case ElementType::i64:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<int64_t>(), ostream);
+  case ElementType::u32:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<uint32_t>(), ostream);
+  case ElementType::i32:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<int32_t>(), ostream);
+  case ElementType::u16:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<uint16_t>(), ostream);
+  case ElementType::i16:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<int16_t>(), ostream);
+  case ElementType::u8:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<uint8_t>(), ostream);
+  case ElementType::i8:
+    return serializeTensorDataRaw(
+        values_and_sizes.getDimensions(),
+        values_and_sizes.getElements<int8_t>(), ostream);
   }
-  serializeTensorData(values, length, ostream);
-  assert(ostream.good());
-  return ostream;
-}
 
-std::ostream &serializeTensorData(TensorData &values_and_sizes,
-                                  std::ostream &ostream) {
-  std::vector<size_t> &sizes = values_and_sizes.sizes;
-  encrypted_scalars_t values = values_and_sizes.values.data();
-  return serializeTensorData(sizes, values, ostream);
+  assert(false && "Unhandled element type");
 }
 
-TensorData unserializeTensorData(
+outcome::checked<TensorData, StringError> unserializeTensorData(
     std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
                                          // accomodate non static sizes
     std::istream &istream) {
-  TensorData result;
+
   if (incorrectMode(istream)) {
-    return result;
+    return StringError("Stream is in incorrect mode");
   }
-  for (auto expectedSize : expectedSizes) {
-    size_t actualSize;
-    readSize(istream, actualSize);
-    if ((size_t)expectedSize != actualSize) {
+
+  uint64_t numDimensions;
+  readWord(istream, numDimensions);
+
+  std::vector<size_t> dims;
+
+  for (uint64_t i = 0; i < numDimensions; i++) {
+    int64_t dimSize;
+    readWord(istream, dimSize);
+
+    if (dimSize != expectedSizes[i]) {
       istream.setstate(std::ios::badbit);
+      return StringError("Dimension size did not match the expected "
+                         "dimension size");
     }
-    assert(actualSize > 0);
-    result.sizes.push_back(actualSize);
-    assert(result.sizes.back() > 0);
+
+    dims.push_back(dimSize);
   }
-  size_t expectedLen = result.length();
-  assert(expectedLen > 0);
-  // TODO: full read in one step
-  size_t actualLen;
-  readSize(istream, actualLen);
-  if (expectedLen != actualLen) {
-    istream.setstate(std::ios::badbit);
+
+  uint64_t elementWidth;
+  readWord(istream, elementWidth);
+
+  switch (elementWidth) {
+  case 64:
+  case 32:
+  case 16:
+  case 8:
+    break;
+  default:
+    return StringError("Element width must be either 64, 32, 16 or 8, but got ")
+           << elementWidth;
   }
-  assert(actualLen == expectedLen);
-  result.values.resize(actualLen);
-  for (uint64_t &value : result.values) {
-    value = 0;
-    readWord(istream, value);
+
+  uint8_t elementSignedness;
+  readWord(istream, elementSignedness);
+
+  if (elementSignedness != 0 && elementSignedness != 1) {
+    return StringError("Numerical value for element signedness must be either "
+                       "0 or 1, but got ")
+           << elementSignedness;
   }
-  return result;
+
+  TensorData result(dims, elementWidth, elementSignedness == 1);
+
+  switch (result.getElementType()) {
+  case ElementType::u64:
+    unserializeTensorDataElements<uint64_t>(result, istream);
+    break;
+  case ElementType::i64:
+    unserializeTensorDataElements<int64_t>(result, istream);
+    break;
+  case ElementType::u32:
+    unserializeTensorDataElements<uint32_t>(result, istream);
+    break;
+  case ElementType::i32:
+    unserializeTensorDataElements<int32_t>(result, istream);
+    break;
+  case ElementType::u16:
+    unserializeTensorDataElements<uint16_t>(result, istream);
+    break;
+  case ElementType::i16:
+    unserializeTensorDataElements<int16_t>(result, istream);
+    break;
+  case ElementType::u8:
+    unserializeTensorDataElements<uint8_t>(result, istream);
+    break;
+  case ElementType::i8:
+    unserializeTensorDataElements<int8_t>(result, istream);
+    break;
+  }
+
+  return std::move(result);
 }
 
 std::ostream &operator<<(std::ostream &ostream,
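Putting the two sides together, here is a round-trip sketch based only on the signatures in this diff; it assumes `incorrectMode` accepts a binary `std::stringstream`, which we have not verified:

```cpp
// Serialize a small i16 tensor and read it back, checking dimensions
// against expected sizes the way the callers in PublicArguments.cpp do.
#include <cassert>
#include <cstdint>
#include <sstream>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/ClientLib/Types.h"

using namespace concretelang::clientlib;

void roundTrip() {
  std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);

  int16_t raw[] = {-1, 2, -3, 4};
  TensorData in(llvm::ArrayRef<int16_t>(raw, 4),
                llvm::ArrayRef<size_t>{2, 2});
  serializeTensorData(in, ss);

  std::vector<int64_t> expected{2, 2};
  auto outOrErr = unserializeTensorData(expected, ss);
  assert(!outOrErr.has_error());
  assert(outOrErr.value().getElementValue<int16_t>(2) == -3);
}
```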
diff --git a/compiler/lib/ServerLib/DynamicRankCall.cpp b/compiler/lib/ServerLib/DynamicRankCall.cpp
index 197820d6d..2c2f9e73e 100644
--- a/compiler/lib/ServerLib/DynamicRankCall.cpp
+++ b/compiler/lib/ServerLib/DynamicRankCall.cpp
@@ -33,170 +33,202 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
 }
 
 TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
-                                         std::vector<void *> args,
-                                         size_t rank) {
+                                         std::vector<void *> args, size_t rank,
+                                         size_t element_width, bool is_signed) {
   using concretelang::clientlib::MemRefDescriptor;
   constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
   switch (rank) {
   case 1: {
     auto m = multi_arity_call(
        convert_fnptr<MemRefDescriptor<1> (*)(void *...)>(func), args);
-    return convert(1, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(1, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 2: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<2> (*)(void *...)>(func), args);
-    return convert(2, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(2, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 3: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<3> (*)(void *...)>(func), args);
-    return convert(3, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(3, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
  }
   case 4: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<4> (*)(void *...)>(func), args);
-    return convert(4, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(4, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 5: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<5> (*)(void *...)>(func), args);
-    return convert(5, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(5, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 6: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<6> (*)(void *...)>(func), args);
-    return convert(6, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(6, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 7: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<7> (*)(void *...)>(func), args);
-    return convert(7, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(7, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 8: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<8> (*)(void *...)>(func), args);
-    return convert(8, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(8, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 9: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<9> (*)(void *...)>(func), args);
-    return convert(9, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(9, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 10: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<10> (*)(void *...)>(func), args);
-    return convert(10, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(10, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 11: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<11> (*)(void *...)>(func), args);
-    return convert(11, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(11, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 12: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<12> (*)(void *...)>(func), args);
-    return convert(12, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(12, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 13: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<13> (*)(void *...)>(func), args);
-    return convert(13, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(13, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 14: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<14> (*)(void *...)>(func), args);
-    return convert(14, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(14, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 15: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<15> (*)(void *...)>(func), args);
-    return convert(15, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(15, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 16: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<16> (*)(void *...)>(func), args);
-    return convert(16, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(16, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 17: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<17> (*)(void *...)>(func), args);
-    return convert(17, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(17, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 18: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<18> (*)(void *...)>(func), args);
-    return convert(18, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(18, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 19: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<19> (*)(void *...)>(func), args);
-    return convert(19, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(19, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 20: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<20> (*)(void *...)>(func), args);
-    return convert(20, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(20, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 21: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<21> (*)(void *...)>(func), args);
-    return convert(21, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(21, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 22: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<22> (*)(void *...)>(func), args);
-    return convert(22, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(22, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 23: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<23> (*)(void *...)>(func), args);
-    return convert(23, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(23, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 24: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<24> (*)(void *...)>(func), args);
-    return convert(24, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(24, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 25: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<25> (*)(void *...)>(func), args);
-    return convert(25, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(25, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 26: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<26> (*)(void *...)>(func), args);
-    return convert(26, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(26, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 27: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<27> (*)(void *...)>(func), args);
-    return convert(27, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(27, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 28: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<28> (*)(void *...)>(func), args);
-    return convert(28, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(28, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 29: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<29> (*)(void *...)>(func), args);
-    return convert(29, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(29, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 30: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<30> (*)(void *...)>(func), args);
-    return convert(30, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(30, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 31: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<31> (*)(void *...)>(func), args);
-    return convert(31, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(31, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   case 32: {
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<32> (*)(void *...)>(func), args);
-    return convert(32, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert(32, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
   }
   default:
diff --git a/compiler/lib/ServerLib/ServerLambda.cpp b/compiler/lib/ServerLib/ServerLambda.cpp
index 7786b54eb..f4f627a67 100644
--- a/compiler/lib/ServerLib/ServerLambda.cpp
+++ b/compiler/lib/ServerLib/ServerLambda.cpp
@@ -82,9 +82,17 @@ ServerLambda::call(PublicArguments &args, EvaluationKeys &evaluationKeys) {
          "ServerLambda::call is implemented for only one output");
   auto output = args.clientParameters.outputs[0];
   auto rank = args.clientParameters.bufferShape(output).size();
-  auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank);
-  return clientlib::PublicResult::fromBuffers(clientParameters, {result});
-  ;
+
+  // FIXME: Handle sign correctly
+  size_t element_width = (output.isEncrypted()) ? 64 : output.shape.width;
+  auto result = multi_arity_call_dynamic_rank(func, preparedArgs, rank,
+                                              element_width, false);
+
+  std::vector<TensorData> results;
+  results.push_back(std::move(result));
+
+  return clientlib::PublicResult::fromBuffers(clientParameters,
+                                              std::move(results));
 }
 
 } // namespace serverlib
diff --git a/compiler/lib/ServerLib/genDynamicRankCall.py b/compiler/lib/ServerLib/genDynamicRankCall.py
index cc9959298..429db53ec 100644
--- a/compiler/lib/ServerLib/genDynamicRankCall.py
+++ b/compiler/lib/ServerLib/genDynamicRankCall.py
@@ -37,8 +37,8 @@ template <typename FnDstT, typename FnSrcT> FnDstT convert_fnptr(FnSrcT src) {
 }
 
 TensorData multi_arity_call_dynamic_rank(void *(*func)(void *...),
-                                         std::vector<void *> args,
-                                         size_t rank) {
+                                         std::vector<void *> args, size_t rank,
+                                         size_t element_width, bool is_signed) {
   using concretelang::clientlib::MemRefDescriptor;
   constexpr auto convert = concretelang::clientlib::tensorDataFromMemRef;
   switch (rank) {""")
@@ -48,7 +48,8 @@ for tensor_rank in range(1, 33):
     print(f"""  case {tensor_rank}: {{
     auto m = multi_arity_call(
         convert_fnptr<MemRefDescriptor<{tensor_rank}> (*)(void *...)>(func), args);
-    return convert({memref_rank}, m.allocated, m.aligned, m.offset, m.sizes, m.strides);
+    return convert({memref_rank}, element_width, is_signed, m.allocated, m.aligned,
+                   m.offset, m.sizes, m.strides);
  }}""")
 
 print("""
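The generator stamps out one `case` per rank because a memref descriptor's layout, and therefore the callee's return type, depends on the rank at compile time; a runtime `rank` value cannot select a template instantiation by itself. A reduced sketch of the same pattern, with illustrative types that are not the library's definitions:

```cpp
// Why dynamic-rank calls need a switch: each Desc<N> is a distinct type.
#include <cstddef>
#include <iostream>

template <size_t N> struct Desc {
  size_t sizes[N]; // layout differs for every N
};

template <size_t N> size_t numElements(const Desc<N> &d) {
  size_t n = 1;
  for (size_t i = 0; i < N; i++)
    n *= d.sizes[i];
  return n;
}

// The switch is the compile-time/runtime bridge that the generated
// multi_arity_call_dynamic_rank provides for ranks 1..32.
size_t numElementsDynamic(const void *desc, size_t rank) {
  switch (rank) {
  case 1:
    return numElements(*static_cast<const Desc<1> *>(desc));
  case 2:
    return numElements(*static_cast<const Desc<2> *>(desc));
  // ... one case per supported rank
  default:
    return 0; // unsupported rank
  }
}

int main() {
  Desc<2> d{{3, 4}};
  std::cout << numElementsDynamic(&d, 2) << "\n"; // prints 12
}
```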
diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp
index 7ae82089d..d0d42ef0a 100644
--- a/compiler/lib/Support/Jit.cpp
+++ b/compiler/lib/Support/Jit.cpp
@@ -101,7 +101,8 @@ JITLambda::call(clientlib::PublicArguments &args,
     return std::move(err);
   }
   std::vector<clientlib::TensorData> buffers;
-  return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
+  return clientlib::PublicResult::fromBuffers(args.clientParameters,
+                                              std::move(buffers));
 }
 #endif
 
@@ -119,11 +120,13 @@ JITLambda::call(clientlib::PublicArguments &args,
       numOutputs += numArgOfRankedMemrefCallingConvention(shape.size());
     }
   }
-  std::vector<void *> outputs(numOutputs);
+  std::vector<uint64_t> outputs(numOutputs);
+
   // Prepare the raw arguments of invokeRaw, i.e. a vector with pointer on
   // inputs and outputs.
-  std::vector<void *> rawArgs(args.preparedArgs.size() + 1 /*runtime context*/ +
-                              outputs.size());
+  std::vector<void *> rawArgs(
+      args.preparedArgs.size() + 1 /*runtime context*/ + 1 /* outputs */
+  );
   size_t i = 0;
   // Pointers on inputs
   for (auto &arg : args.preparedArgs) {
@@ -136,10 +139,10 @@ JITLambda::call(clientlib::PublicArguments &args,
   // is passed to the compiled function.
   auto rtCtxPtr = &runtimeContext;
   rawArgs[i++] = &rtCtxPtr;
-  // Pointers on outputs
-  for (auto &out : outputs) {
-    rawArgs[i++] = &out;
-  }
+
+  // Outputs
+  rawArgs[i++] = reinterpret_cast<void *>(outputs.data());
+
   // Invoke
   if (auto err = invokeRaw(rawArgs)) {
     return std::move(err);
@@ -165,12 +168,20 @@ JITLambda::call(clientlib::PublicArguments &args,
       outputOffset += rank;
       size_t *strides = (size_t *)&outputs[outputOffset];
       outputOffset += rank;
+
+      size_t elementWidth = (output.isEncrypted())
+                                ? clientlib::EncryptedScalarElementWidth
+                                : output.shape.width;
+
+      // FIXME: Handle sign correctly
       buffers.push_back(clientlib::tensorDataFromMemRef(
-          rank, allocated, aligned, offset, sizes, strides));
+          rank, elementWidth, false, allocated, aligned, offset, sizes,
+          strides));
     }
   }
 
-  return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
+  return clientlib::PublicResult::fromBuffers(args.clientParameters,
+                                              std::move(buffers));
 }
 
 } // namespace concretelang
diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_clear_tensor.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_clear_tensor.cc
index c678603a7..29749a106 100644
--- a/compiler/tests/end_to_end_tests/end_to_end_jit_clear_tensor.cc
+++ b/compiler/tests/end_to_end_tests/end_to_end_jit_clear_tensor.cc
@@ -5,7 +5,30 @@
 // 1D tensor //////////////////////////////////////////////////////////////////
 ///////////////////////////////////////////////////////////////////////////////
 
-TEST(End2EndJit_ClearTensor_1D, DISABLED_identity) {
+TEST(End2EndJit_ClearTensor_2D, constant_i8) {
+  checkedJit(lambda,
+             R"XXX(
+func.func @main() -> tensor<2x2xi8> {
+  %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi8>
+  return %cst : tensor<2x2xi8>
+}
+)XXX",
+             "main", true);
+
+  llvm::Expected<std::vector<uint8_t>> res =
+      lambda.operator()<std::vector<uint8_t>>();
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), (size_t)4);
+
+  EXPECT_EQ((*res)[0], 0);
+  EXPECT_EQ((*res)[1], 1);
+  EXPECT_EQ((*res)[2], 2);
+  EXPECT_EQ((*res)[3], 3);
+}
+
+TEST(End2EndJit_ClearTensor_1D, identity) {
   checkedJit(lambda,
              R"XXX(
 func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
@@ -37,6 +60,29 @@ func.func @main(%t: tensor<10xi64>) -> tensor<10xi64> {
   }
 }
 
+TEST(End2EndJit_ClearTensor_1D, identity_i8) {
+  checkedJit(lambda,
+             R"XXX(
+func.func @main(%t: tensor<10xi8>) -> tensor<10xi8> {
+  return %t : tensor<10xi8>
+}
+)XXX",
+             "main", true);
+
+  uint8_t arg[]{16, 21, 3, 127, 9, 17, 32, 18, 29, 104};
+
+  llvm::Expected<std::vector<uint8_t>> res =
+      lambda.operator()<std::vector<uint8_t>>(arg, ARRAY_SIZE(arg));
+
+  ASSERT_EXPECTED_SUCCESS(res);
+
+  ASSERT_EQ(res->size(), (size_t)10);
+
+  for (size_t i = 0; i < res->size(); i++) {
+    EXPECT_EQ(arg[i], res->operator[](i)) << "result differ at index " << i;
+  }
+}
+
 TEST(End2EndJit_ClearTensor_1D, extract_64) {
   checkedJit(lambda, R"XXX(
 func.func @main(%t: tensor<10xi64>, %i: index) -> i64{