mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(compiler): Add support for clear result tensors with element width != 64 bits
Returning tensors with elements whose width is not equal to 64 results in garbled data. This commit extends the `TensorData` class used to represent tensors in JIT compilation with support for signed / unsigned elements of 8/16/32 and 64 bits, such that all clear text tensors with up to 64 bits can be represented accurately.
This commit is contained in:
@@ -31,32 +31,34 @@ EncryptedArguments::pushArg(uint64_t arg, KeySet &keySet) {
|
||||
preparedArgs.push_back((void *)arg);
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
std::vector<int64_t> shape = keySet.clientParameters().bufferShape(input);
|
||||
|
||||
// Allocate empty
|
||||
ciphertextBuffers.resize(ciphertextBuffers.size() + 1);
|
||||
ciphertextBuffers.emplace_back(shape, clientlib::EncryptedScalarElementType);
|
||||
TensorData &values_and_sizes = ciphertextBuffers.back();
|
||||
values_and_sizes.sizes = keySet.clientParameters().bufferShape(input);
|
||||
values_and_sizes.values.resize(keySet.clientParameters().bufferSize(input));
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(pos, values_and_sizes.values.data(), arg));
|
||||
|
||||
OUTCOME_TRYV(keySet.encrypt_lwe(
|
||||
pos, values_and_sizes.getElementPointer<decrypted_scalar_t>(0), arg));
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back((void *)values_and_sizes.values.data());
|
||||
preparedArgs.push_back((void *)values_and_sizes.getValuesAsOpaquePointer());
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// sizes
|
||||
for (auto size : values_and_sizes.sizes) {
|
||||
for (auto size : values_and_sizes.getDimensions()) {
|
||||
preparedArgs.push_back((void *)size);
|
||||
}
|
||||
// strides
|
||||
int64_t stride = values_and_sizes.length();
|
||||
for (size_t i = 0; i < values_and_sizes.sizes.size() - 1; i++) {
|
||||
auto size = values_and_sizes.sizes[i];
|
||||
int64_t stride = TensorData::getNumElements(shape);
|
||||
for (size_t size : values_and_sizes.getDimensions()) {
|
||||
stride = (size == 0 ? 0 : (stride / size));
|
||||
preparedArgs.push_back((void *)stride);
|
||||
}
|
||||
preparedArgs.push_back((void *)1);
|
||||
|
||||
return outcome::success();
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user