enhance(testlib): Fix the runtime testlib tools to handle the ciphertext bufferization and the new compiler concrete bufferized API

This commit is contained in:
Quentin Bourgerie
2022-02-11 14:31:52 +01:00
committed by Quentin Bourgerie
parent b1d6b7e653
commit 9627864d23
4 changed files with 84 additions and 65 deletions

View File

@@ -13,8 +13,7 @@ namespace concretelang {
Arguments::~Arguments() {
for (auto ct : allocatedCiphertexts) {
int err;
free_lwe_ciphertext_u64(&err, ct);
free(ct);
}
for (auto ctBuffer : ciphertextBuffers) {
free(ctBuffer);
@@ -46,15 +45,28 @@ llvm::Error Arguments::pushArg(uint64_t arg) {
return llvm::Error::success();
}
// encrypted scalar: allocate, encrypt and push
LweCiphertext_u64 *ctArg;
if (auto err = keySet.allocate_lwe(pos, &ctArg)) {
uint64_t *ctArg;
uint64_t ctSize = 0;
if (auto err = keySet.allocate_lwe(pos, &ctArg, ctSize)) {
return err;
}
allocatedCiphertexts.push_back(ctArg);
if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
return err;
}
preparedArgs.push_back((void *)ctArg);
// Note: Since we bufferized lwe ciphertext take care of memref calling
// convention
// allocated
preparedArgs.push_back(nullptr);
// aligned
preparedArgs.push_back(ctArg);
// offset
preparedArgs.push_back((void *)0);
// size
preparedArgs.push_back((void *)ctSize);
// stride
preparedArgs.push_back((void *)1);
return llvm::Error::success();
}
@@ -106,16 +118,16 @@ llvm::Error Arguments::pushArg(size_t width, void *data,
const uint8_t *data8 = (const uint8_t *)data;
// Allocate a buffer for ciphertexts of size of tensor
auto ctBuffer = (LweCiphertext_u64 **)malloc(input.shape.size *
sizeof(LweCiphertext_u64 *));
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
auto ctBuffer =
(uint64_t *)malloc(input.shape.size * lweSize * sizeof(uint64_t));
ciphertextBuffers.push_back(ctBuffer);
// Allocate ciphertexts and encrypt, for every values in tensor
for (size_t i = 0; i < input.shape.size; i++) {
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
return err;
}
allocatedCiphertexts.push_back(ctBuffer[i]);
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
for (size_t i = 0, offset = 0; i < input.shape.size;
i++, offset += lweSize) {
if (auto err =
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
return err;
}
}
@@ -132,10 +144,24 @@ llvm::Error Arguments::pushArg(size_t width, void *data,
for (size_t i = 0; i < shape.size(); i++) {
preparedArgs.push_back((void *)shape[i]);
}
// strides - FIXME make it works
// strides is an array of size equals to numDim
for (size_t i = 0; i < shape.size(); i++) {
preparedArgs.push_back((void *)0);
// If encrypted +1 for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
preparedArgs.push_back(
(void *)(keySet.getInputLweSecretKeyParam(pos).size + 1));
}
// Set the stride for each dimension, equal to the product of the
// following dimensions.
int64_t stride = 1;
// If encrypted +1 set the stride for the lwe size rank
if (keySet.isInputEncrypted(pos)) {
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
}
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
preparedArgs.push_back((void *)stride);
stride *= shape[i];
}
if (keySet.isInputEncrypted(pos)) {
preparedArgs.push_back((void *)1);
}
return llvm::Error::success();
}