mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
enhance(testlib): Fix the runtime testlib tools to handle the ciphertext bufferization and the new compiler concrete bufferized API
This commit is contained in:
committed by
Quentin Bourgerie
parent
b1d6b7e653
commit
9627864d23
@@ -91,9 +91,9 @@ private:
|
||||
std::vector<void *> preparedArgs;
|
||||
|
||||
// Store allocated lwe ciphertexts (for free)
|
||||
std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
|
||||
std::vector<uint64_t *> allocatedCiphertexts;
|
||||
// Store buffers of ciphertexts
|
||||
std::vector<LweCiphertext_u64 **> ciphertextBuffers;
|
||||
std::vector<uint64_t *> ciphertextBuffers;
|
||||
|
||||
KeySet &keySet;
|
||||
RuntimeContext context;
|
||||
|
||||
@@ -89,7 +89,7 @@ protected:
|
||||
|
||||
ClientParameters clientParameters;
|
||||
std::shared_ptr<KeySet> keySet;
|
||||
void *(*func)(void *...);
|
||||
void *func;
|
||||
// Retain module and open shared lib alive
|
||||
std::shared_ptr<DynamicModule> module;
|
||||
};
|
||||
|
||||
@@ -13,8 +13,7 @@ namespace concretelang {
|
||||
|
||||
Arguments::~Arguments() {
|
||||
for (auto ct : allocatedCiphertexts) {
|
||||
int err;
|
||||
free_lwe_ciphertext_u64(&err, ct);
|
||||
free(ct);
|
||||
}
|
||||
for (auto ctBuffer : ciphertextBuffers) {
|
||||
free(ctBuffer);
|
||||
@@ -46,15 +45,28 @@ llvm::Error Arguments::pushArg(uint64_t arg) {
|
||||
return llvm::Error::success();
|
||||
}
|
||||
// encrypted scalar: allocate, encrypt and push
|
||||
LweCiphertext_u64 *ctArg;
|
||||
if (auto err = keySet.allocate_lwe(pos, &ctArg)) {
|
||||
uint64_t *ctArg;
|
||||
uint64_t ctSize = 0;
|
||||
if (auto err = keySet.allocate_lwe(pos, &ctArg, ctSize)) {
|
||||
return err;
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctArg);
|
||||
if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) {
|
||||
return err;
|
||||
}
|
||||
preparedArgs.push_back((void *)ctArg);
|
||||
// Note: Since we bufferized lwe ciphertext take care of memref calling
|
||||
// convention
|
||||
// allocated
|
||||
preparedArgs.push_back(nullptr);
|
||||
// aligned
|
||||
preparedArgs.push_back(ctArg);
|
||||
// offset
|
||||
preparedArgs.push_back((void *)0);
|
||||
// size
|
||||
preparedArgs.push_back((void *)ctSize);
|
||||
// stride
|
||||
preparedArgs.push_back((void *)1);
|
||||
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
@@ -106,16 +118,16 @@ llvm::Error Arguments::pushArg(size_t width, void *data,
|
||||
const uint8_t *data8 = (const uint8_t *)data;
|
||||
|
||||
// Allocate a buffer for ciphertexts of size of tensor
|
||||
auto ctBuffer = (LweCiphertext_u64 **)malloc(input.shape.size *
|
||||
sizeof(LweCiphertext_u64 *));
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
auto ctBuffer =
|
||||
(uint64_t *)malloc(input.shape.size * lweSize * sizeof(uint64_t));
|
||||
ciphertextBuffers.push_back(ctBuffer);
|
||||
// Allocate ciphertexts and encrypt, for every values in tensor
|
||||
for (size_t i = 0; i < input.shape.size; i++) {
|
||||
if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
|
||||
return err;
|
||||
}
|
||||
allocatedCiphertexts.push_back(ctBuffer[i]);
|
||||
if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
|
||||
for (size_t i = 0, offset = 0; i < input.shape.size;
|
||||
i++, offset += lweSize) {
|
||||
|
||||
if (auto err =
|
||||
this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
|
||||
return err;
|
||||
}
|
||||
}
|
||||
@@ -132,10 +144,24 @@ llvm::Error Arguments::pushArg(size_t width, void *data,
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
preparedArgs.push_back((void *)shape[i]);
|
||||
}
|
||||
// strides - FIXME make it works
|
||||
// strides is an array of size equals to numDim
|
||||
for (size_t i = 0; i < shape.size(); i++) {
|
||||
preparedArgs.push_back((void *)0);
|
||||
// If encrypted +1 for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
preparedArgs.push_back(
|
||||
(void *)(keySet.getInputLweSecretKeyParam(pos).size + 1));
|
||||
}
|
||||
// Set the stride for each dimension, equal to the product of the
|
||||
// following dimensions.
|
||||
int64_t stride = 1;
|
||||
// If encrypted +1 set the stride for the lwe size rank
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
}
|
||||
for (ssize_t i = shape.size() - 1; i >= 0; i--) {
|
||||
preparedArgs.push_back((void *)stride);
|
||||
stride *= shape[i];
|
||||
}
|
||||
if (keySet.isInputEncrypted(pos)) {
|
||||
preparedArgs.push_back((void *)1);
|
||||
}
|
||||
return llvm::Error::success();
|
||||
}
|
||||
|
||||
@@ -14,27 +14,27 @@ namespace mlir {
|
||||
namespace concretelang {
|
||||
|
||||
template <size_t N> struct MemRefDescriptor {
|
||||
LweCiphertext_u64 **allocated;
|
||||
LweCiphertext_u64 **aligned;
|
||||
uint64_t *allocated;
|
||||
uint64_t *aligned;
|
||||
size_t offset;
|
||||
size_t sizes[N];
|
||||
size_t strides[N];
|
||||
};
|
||||
|
||||
llvm::Expected<std::vector<uint64_t>> decryptSlice(LweCiphertext_u64 **aligned,
|
||||
KeySet &keySet, size_t start,
|
||||
size_t size,
|
||||
size_t stride = 1) {
|
||||
stride = (stride == 0) ? 1 : stride;
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
decryptSlice(KeySet &keySet, uint64_t *aligned, size_t size) {
|
||||
auto pos = 0;
|
||||
std::vector<uint64_t> result(size);
|
||||
auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
size_t offset = start + i * stride;
|
||||
auto err = keySet.decrypt_lwe(0, aligned[offset], result[i]);
|
||||
size_t offset = i * lweSize;
|
||||
auto err = keySet.decrypt_lwe(pos, aligned + offset, result[i]);
|
||||
if (err) {
|
||||
return StreamStringError()
|
||||
<< "cannot decrypt result #" << i << ", err:" << err;
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -53,8 +53,7 @@ DynamicLambda::load(std::shared_ptr<DynamicModule> module,
|
||||
DynamicLambda lambda;
|
||||
lambda.module =
|
||||
module; // prevent module and library handler from being destroyed
|
||||
lambda.func =
|
||||
(void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str());
|
||||
lambda.func = dlsym(module->libraryHandle, funcName.c_str());
|
||||
|
||||
if (auto err = dlerror()) {
|
||||
return StreamStringError("Cannot open lambda: ") << err;
|
||||
@@ -93,13 +92,13 @@ llvm::Expected<uint64_t> invoke<uint64_t>(DynamicLambda &lambda,
|
||||
return StreamStringError("the function doesn't return a scalar");
|
||||
}
|
||||
// Scalar encrypted result
|
||||
auto fCasted = (LweCiphertext_u64 * (*)(void *...))(lambda.func);
|
||||
;
|
||||
LweCiphertext_u64 *lweResult =
|
||||
auto fCasted = (MemRefDescriptor<1>(*)(void *...))(lambda.func);
|
||||
MemRefDescriptor<1> lweResult =
|
||||
mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
uint64_t decryptedResult;
|
||||
if (auto err = lambda.keySet->decrypt_lwe(0, lweResult, decryptedResult)) {
|
||||
if (auto err =
|
||||
lambda.keySet->decrypt_lwe(0, lweResult.aligned, decryptedResult)) {
|
||||
return std::move(err);
|
||||
}
|
||||
return decryptedResult;
|
||||
@@ -112,15 +111,15 @@ DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
|
||||
if (output.shape.size == 0) {
|
||||
return StreamStringError("the function doesn't return a tensor");
|
||||
}
|
||||
if (output.shape.dimensions.size() != Rank) {
|
||||
if (output.shape.dimensions.size() != Rank - 1) {
|
||||
return StreamStringError("the function doesn't return a tensor of rank ")
|
||||
<< Rank;
|
||||
<< Rank - 1;
|
||||
}
|
||||
// Tensor encrypted result
|
||||
auto fCasted = (MemRefDescriptor<Rank>(*)(void *...))(func);
|
||||
auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs);
|
||||
|
||||
for (size_t dim = 0; dim < Rank; dim++) {
|
||||
for (size_t dim = 0; dim < Rank - 1; dim++) {
|
||||
size_t actual_size = encryptedResult.sizes[dim];
|
||||
size_t expected_size = output.shape.dimensions[dim];
|
||||
if (actual_size != expected_size) {
|
||||
@@ -134,35 +133,32 @@ DynamicLambda::invokeMemRefDecriptor(const Arguments &args) {
|
||||
template <>
|
||||
llvm::Expected<std::vector<uint64_t>>
|
||||
invoke<std::vector<uint64_t>>(DynamicLambda &lambda, const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<1>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
return decryptSlice(encryptedResult.aligned, *keySet, encryptedResult.offset,
|
||||
encryptedResult.sizes[0], encryptedResult.strides[0]);
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
auto &keySet = lambda.keySet;
|
||||
return decryptSlice(*keySet, encryptedResult.aligned,
|
||||
encryptedResult.sizes[0]);
|
||||
}
|
||||
|
||||
template <>
|
||||
llvm::Expected<std::vector<std::vector<uint64_t>>>
|
||||
invoke<std::vector<std::vector<uint64_t>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
auto &encryptedResult = encryptedResultOrErr.get();
|
||||
|
||||
std::vector<std::vector<uint64_t>> result;
|
||||
result.reserve(encryptedResult.sizes[0]);
|
||||
for (size_t i = 0; i < encryptedResult.sizes[0]; i++) {
|
||||
// TODO : strides
|
||||
int offset = encryptedResult.offset + i * encryptedResult.sizes[1];
|
||||
auto slice =
|
||||
decryptSlice(encryptedResult.aligned, *keySet, offset,
|
||||
encryptedResult.sizes[1], encryptedResult.strides[1]);
|
||||
int offset = encryptedResult.offset + i * encryptedResult.strides[1];
|
||||
auto slice = decryptSlice(*lambda.keySet, encryptedResult.aligned + offset,
|
||||
encryptedResult.sizes[1]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
@@ -175,7 +171,7 @@ template <>
|
||||
llvm::Expected<std::vector<std::vector<std::vector<uint64_t>>>>
|
||||
invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
const Arguments &args) {
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args);
|
||||
auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<4>(args);
|
||||
if (!encryptedResultOrErr) {
|
||||
return encryptedResultOrErr.takeError();
|
||||
}
|
||||
@@ -188,13 +184,10 @@ invoke<std::vector<std::vector<std::vector<uint64_t>>>>(DynamicLambda &lambda,
|
||||
std::vector<std::vector<uint64_t>> result1;
|
||||
result1.reserve(encryptedResult.sizes[1]);
|
||||
for (size_t j = 0; j < encryptedResult.sizes[1]; j++) {
|
||||
// TODO : strides
|
||||
int offset = encryptedResult.offset +
|
||||
i * encryptedResult.sizes[1] * encryptedResult.sizes[2] +
|
||||
j * encryptedResult.sizes[2];
|
||||
auto slice =
|
||||
decryptSlice(encryptedResult.aligned, *keySet, offset,
|
||||
encryptedResult.sizes[2], encryptedResult.strides[2]);
|
||||
int offset = encryptedResult.offset + (i * encryptedResult.sizes[1] + j) *
|
||||
encryptedResult.strides[1];
|
||||
auto slice = decryptSlice(*keySet, encryptedResult.aligned + offset,
|
||||
encryptedResult.sizes[2]);
|
||||
if (!slice) {
|
||||
return StreamStringError(llvm::toString(slice.takeError()));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user