mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(compiler): Add support for clear result tensors with element width != 64 bits
Returning tensors whose elements are not 64 bits wide results in garbled data. This commit extends the `TensorData` class, which is used to represent tensors in JIT compilation, with support for signed and unsigned elements of 8, 16, 32 and 64 bits, so that all clear-text tensors with elements of up to 64 bits can be represented accurately.
This commit is contained in:
@@ -149,70 +149,135 @@ std::ostream &operator<<(std::ostream &ostream,
|
||||
return ostream;
|
||||
}
|
||||
|
||||
std::ostream &serializeTensorData(uint64_t *values, size_t length,
|
||||
std::ostream &ostream) {
|
||||
if (incorrectMode(ostream)) {
|
||||
return ostream;
|
||||
}
|
||||
writeSize(ostream, length);
|
||||
for (size_t i = 0; i < length; i++) {
|
||||
writeWord(ostream, values[i]);
|
||||
}
|
||||
return ostream;
|
||||
/// Fills the element storage of `values_and_sizes` from `istream`.
///
/// Calls `readWords` with the pointer returned by
/// `values_and_sizes.getElementPointer<T>(0)` as destination and
/// `values_and_sizes.getNumElements()` as count, then returns `istream`.
/// `readWords` is defined outside this view; presumably it reads that many
/// values of type `T` -- confirm against its definition.
///
/// The stream state is not inspected here.
template <typename T>
|
||||
static std::istream &unserializeTensorDataElements(TensorData &values_and_sizes,
|
||||
std::istream &istream) {
|
||||
readWords(istream, values_and_sizes.getElementPointer<T>(0),
|
||||
values_and_sizes.getNumElements());
|
||||
|
||||
return istream;
|
||||
}
|
||||
|
||||
std::ostream &serializeTensorData(std::vector<int64_t> &sizes, uint64_t *values,
|
||||
std::ostream &serializeTensorData(const TensorData &values_and_sizes,
|
||||
std::ostream &ostream) {
|
||||
size_t length = 1;
|
||||
for (auto size : sizes) {
|
||||
length *= size;
|
||||
writeSize(ostream, size);
|
||||
switch (values_and_sizes.getElementType()) {
|
||||
case ElementType::u64:
|
||||
return serializeTensorDataRaw<uint64_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint64_t>(), ostream);
|
||||
case ElementType::i64:
|
||||
return serializeTensorDataRaw<int64_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int64_t>(), ostream);
|
||||
case ElementType::u32:
|
||||
return serializeTensorDataRaw<uint32_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint32_t>(), ostream);
|
||||
case ElementType::i32:
|
||||
return serializeTensorDataRaw<int32_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int32_t>(), ostream);
|
||||
case ElementType::u16:
|
||||
return serializeTensorDataRaw<uint16_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint16_t>(), ostream);
|
||||
case ElementType::i16:
|
||||
return serializeTensorDataRaw<int16_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int16_t>(), ostream);
|
||||
case ElementType::u8:
|
||||
return serializeTensorDataRaw<uint8_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<uint8_t>(), ostream);
|
||||
case ElementType::i8:
|
||||
return serializeTensorDataRaw<int8_t>(
|
||||
values_and_sizes.getDimensions(),
|
||||
values_and_sizes.getElements<int8_t>(), ostream);
|
||||
}
|
||||
serializeTensorData(values, length, ostream);
|
||||
assert(ostream.good());
|
||||
return ostream;
|
||||
|
||||
assert(false && "Unhandled element type");
|
||||
}
|
||||
|
||||
/// Convenience overload that serializes a `TensorData` to `ostream`.
///
/// Binds a reference to the `sizes` member and takes `values.data()` of
/// `values_and_sizes`, then forwards both to the
/// `serializeTensorData(sizes, values, ostream)` overload and returns its
/// result.
std::ostream &serializeTensorData(TensorData &values_and_sizes,
|
||||
std::ostream &ostream) {
|
||||
std::vector<int64_t> &sizes = values_and_sizes.sizes;
|
||||
encrypted_scalars_t values = values_and_sizes.values.data();
|
||||
return serializeTensorData(sizes, values, ostream);
|
||||
}
|
||||
|
||||
TensorData unserializeTensorData(
|
||||
outcome::checked<TensorData, StringError> unserializeTensorData(
|
||||
std::vector<int64_t> &expectedSizes, // includes lweSize, unsigned to
|
||||
// accommodate non-static sizes
|
||||
std::istream &istream) {
|
||||
TensorData result;
|
||||
|
||||
if (incorrectMode(istream)) {
|
||||
return result;
|
||||
return StringError("Stream is in incorrect mode");
|
||||
}
|
||||
for (auto expectedSize : expectedSizes) {
|
||||
size_t actualSize;
|
||||
readSize(istream, actualSize);
|
||||
if ((size_t)expectedSize != actualSize) {
|
||||
|
||||
uint64_t numDimensions;
|
||||
readWord(istream, numDimensions);
|
||||
|
||||
std::vector<size_t> dims;
|
||||
|
||||
for (uint64_t i = 0; i < numDimensions; i++) {
|
||||
int64_t dimSize;
|
||||
readWord(istream, dimSize);
|
||||
|
||||
if (dimSize != expectedSizes[i]) {
|
||||
istream.setstate(std::ios::badbit);
|
||||
return StringError("Number of dimensions did not match the number of "
|
||||
"expected dimensions");
|
||||
}
|
||||
assert(actualSize > 0);
|
||||
result.sizes.push_back(actualSize);
|
||||
assert(result.sizes.back() > 0);
|
||||
|
||||
dims.push_back(dimSize);
|
||||
}
|
||||
size_t expectedLen = result.length();
|
||||
assert(expectedLen > 0);
|
||||
// TODO: full read in one step
|
||||
size_t actualLen;
|
||||
readSize(istream, actualLen);
|
||||
if (expectedLen != actualLen) {
|
||||
istream.setstate(std::ios::badbit);
|
||||
|
||||
uint64_t elementWidth;
|
||||
readWord(istream, elementWidth);
|
||||
|
||||
switch (elementWidth) {
|
||||
case 64:
|
||||
case 32:
|
||||
case 16:
|
||||
case 8:
|
||||
break;
|
||||
default:
|
||||
return StringError("Element width must be either 64, 32, 16 or 8, but got ")
|
||||
<< elementWidth;
|
||||
}
|
||||
assert(actualLen == expectedLen);
|
||||
result.values.resize(actualLen);
|
||||
for (uint64_t &value : result.values) {
|
||||
value = 0;
|
||||
readWord(istream, value);
|
||||
|
||||
uint8_t elementSignedness;
|
||||
readWord(istream, elementSignedness);
|
||||
|
||||
if (elementSignedness != 0 && elementSignedness != 1) {
|
||||
return StringError("Numerical value for element signedness must be either "
|
||||
"0 or 1, but got ")
|
||||
<< elementSignedness;
|
||||
}
|
||||
return result;
|
||||
|
||||
TensorData result(dims, elementWidth, elementSignedness == 1);
|
||||
|
||||
switch (result.getElementType()) {
|
||||
case ElementType::u64:
|
||||
unserializeTensorDataElements<uint64_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i64:
|
||||
unserializeTensorDataElements<int64_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u32:
|
||||
unserializeTensorDataElements<uint32_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i32:
|
||||
unserializeTensorDataElements<int32_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u16:
|
||||
unserializeTensorDataElements<uint16_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i16:
|
||||
unserializeTensorDataElements<int16_t>(result, istream);
|
||||
break;
|
||||
case ElementType::u8:
|
||||
unserializeTensorDataElements<uint8_t>(result, istream);
|
||||
break;
|
||||
case ElementType::i8:
|
||||
unserializeTensorDataElements<int8_t>(result, istream);
|
||||
break;
|
||||
}
|
||||
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
std::ostream &operator<<(std::ostream &ostream,
|
||||
|
||||
Reference in New Issue
Block a user