fix(compiler): Add support for clear result tensors with element width != 64 bits

Returning tensors with elements whose width is not equal to 64 results
in garbled data. This commit extends the `TensorData` class used to
represent tensors in JIT compilation with support for signed /
unsigned elements of 8/16/32 and 64 bits, such that all clear text
tensors with up to 64 bits can be represented accurately.
This commit is contained in:
Andi Drebes
2022-09-09 16:04:09 +02:00
committed by rudy-6-4
parent f1833f06f2
commit 8255d3e190
13 changed files with 1048 additions and 197 deletions

View File

@@ -31,32 +31,34 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
preparedArgs.push_back((void *)arg);
return outcome::success();
}
std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);
// Allocate empty
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
ciphertextBuffers.emplace_back(shape, clientlib::EncryptedScalarElementType);
TensorData &values_and_sizes = ciphertextBuffers.back();
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
OUTCOME_TRYV(keySet.encrypt_lwe(
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back((void *)values_and_sizes.values.data());
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
// offset
preparedArgs.push_back((void *)0);
// sizes
for (auto size : values_and_sizes.sizes) {
for (auto size : values_and_sizes.getDimensions()) {
preparedArgs.push_back((void *)size);
}
// strides
int64_t stride = values_and_sizes.length();
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
auto size = values_and_sizes.sizes[i];
int64_t stride = TensorData::getNumElements(shape);
for (size_t size : values_and_sizes.getDimensions()) {
stride = (size == 0 ? 0 : (stride / size));
preparedArgs.push_back((void *)stride);
}
preparedArgs.push_back((void *)1);
return outcome::success();
}