From 9627864d2387ece8af1ffdcd33a5079bf4e1a971 Mon Sep 17 00:00:00 2001 From: Quentin Bourgerie Date: Fri, 11 Feb 2022 14:31:52 +0100 Subject: [PATCH] enhance(testlib): Fix the runtime testlib tools to handle the ciphertext bufferization and the new compiler concrete bufferized API --- .../include/concretelang/TestLib/Arguments.h | 4 +- .../concretelang/TestLib/DynamicLambda.h | 2 +- compiler/lib/TestLib/Arguments.cpp | 60 ++++++++++---- compiler/lib/TestLib/DynamicLambda.cpp | 83 +++++++++---------- 4 files changed, 84 insertions(+), 65 deletions(-) diff --git a/compiler/include/concretelang/TestLib/Arguments.h b/compiler/include/concretelang/TestLib/Arguments.h index f7e58a47d..25c356e3c 100644 --- a/compiler/include/concretelang/TestLib/Arguments.h +++ b/compiler/include/concretelang/TestLib/Arguments.h @@ -91,9 +91,9 @@ private: std::vector preparedArgs; // Store allocated lwe ciphertexts (for free) - std::vector allocatedCiphertexts; + std::vector allocatedCiphertexts; // Store buffers of ciphertexts - std::vector ciphertextBuffers; + std::vector ciphertextBuffers; KeySet &keySet; RuntimeContext context; diff --git a/compiler/include/concretelang/TestLib/DynamicLambda.h b/compiler/include/concretelang/TestLib/DynamicLambda.h index 80d98d567..6b41b30aa 100644 --- a/compiler/include/concretelang/TestLib/DynamicLambda.h +++ b/compiler/include/concretelang/TestLib/DynamicLambda.h @@ -89,7 +89,7 @@ protected: ClientParameters clientParameters; std::shared_ptr keySet; - void *(*func)(void *...); + void *func; // Retain module and open shared lib alive std::shared_ptr module; }; diff --git a/compiler/lib/TestLib/Arguments.cpp b/compiler/lib/TestLib/Arguments.cpp index bf796d0ae..f4ccaf562 100644 --- a/compiler/lib/TestLib/Arguments.cpp +++ b/compiler/lib/TestLib/Arguments.cpp @@ -13,8 +13,7 @@ namespace concretelang { Arguments::~Arguments() { for (auto ct : allocatedCiphertexts) { - int err; - free_lwe_ciphertext_u64(&err, ct); + free(ct); } for (auto ctBuffer : ciphertextBuffers) { free(ctBuffer); @@ -46,15 +45,28 @@ llvm::Error Arguments::pushArg(uint64_t arg) { return llvm::Error::success(); } // encrypted scalar: allocate, encrypt and push - LweCiphertext_u64 *ctArg; - if (auto err = keySet.allocate_lwe(pos, &ctArg)) { + uint64_t *ctArg; + uint64_t ctSize = 0; + if (auto err = keySet.allocate_lwe(pos, &ctArg, ctSize)) { return err; } allocatedCiphertexts.push_back(ctArg); if (auto err = keySet.encrypt_lwe(pos, ctArg, arg)) { return err; } - preparedArgs.push_back((void *)ctArg); + // Note: Since we bufferized lwe ciphertext take care of memref calling + // convention + // allocated + preparedArgs.push_back(nullptr); + // aligned + preparedArgs.push_back(ctArg); + // offset + preparedArgs.push_back((void *)0); + // size + preparedArgs.push_back((void *)ctSize); + // stride + preparedArgs.push_back((void *)1); + return llvm::Error::success(); } @@ -106,16 +118,16 @@ llvm::Error Arguments::pushArg(size_t width, void *data, const uint8_t *data8 = (const uint8_t *)data; // Allocate a buffer for ciphertexts of size of tensor - auto ctBuffer = (LweCiphertext_u64 **)malloc(input.shape.size * - sizeof(LweCiphertext_u64 *)); + auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1; + auto ctBuffer = + (uint64_t *)malloc(input.shape.size * lweSize * sizeof(uint64_t)); ciphertextBuffers.push_back(ctBuffer); // Allocate ciphertexts and encrypt, for every values in tensor - for (size_t i = 0; i < input.shape.size; i++) { - if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) { - return err; - } - allocatedCiphertexts.push_back(ctBuffer[i]); - if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) { + for (size_t i = 0, offset = 0; i < input.shape.size; + i++, offset += lweSize) { + + if (auto err = + this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) { return err; } } @@ -132,10 +144,24 @@ llvm::Error Arguments::pushArg(size_t width, void *data, for (size_t i = 0; i < shape.size(); i++) { preparedArgs.push_back((void *)shape[i]); } - // strides - FIXME make it works - // strides is an array of size equals to numDim - for (size_t i = 0; i < shape.size(); i++) { - preparedArgs.push_back((void *)0); + // If encrypted +1 for the lwe size rank + if (keySet.isInputEncrypted(pos)) { + preparedArgs.push_back( + (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1)); + } + // Set the stride for each dimension, equal to the product of the + // following dimensions. + int64_t stride = 1; + // If encrypted +1 set the stride for the lwe size rank + if (keySet.isInputEncrypted(pos)) { + stride *= keySet.getInputLweSecretKeyParam(pos).size + 1; + } + for (ssize_t i = shape.size() - 1; i >= 0; i--) { + preparedArgs.push_back((void *)stride); + stride *= shape[i]; + } + if (keySet.isInputEncrypted(pos)) { + preparedArgs.push_back((void *)1); } return llvm::Error::success(); } diff --git a/compiler/lib/TestLib/DynamicLambda.cpp b/compiler/lib/TestLib/DynamicLambda.cpp index b91cea706..816839683 100644 --- a/compiler/lib/TestLib/DynamicLambda.cpp +++ b/compiler/lib/TestLib/DynamicLambda.cpp @@ -14,27 +14,27 @@ namespace mlir { namespace concretelang { template struct MemRefDescriptor { - LweCiphertext_u64 **allocated; - LweCiphertext_u64 **aligned; + uint64_t *allocated; + uint64_t *aligned; size_t offset; size_t sizes[N]; size_t strides[N]; }; -llvm::Expected> decryptSlice(LweCiphertext_u64 **aligned, - KeySet &keySet, size_t start, - size_t size, - size_t stride = 1) { - stride = (stride == 0) ? 1 : stride; +llvm::Expected> +decryptSlice(KeySet &keySet, uint64_t *aligned, size_t size) { + auto pos = 0; std::vector result(size); + auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1; for (size_t i = 0; i < size; i++) { - size_t offset = start + i * stride; - auto err = keySet.decrypt_lwe(0, aligned[offset], result[i]); + size_t offset = i * lweSize; + auto err = keySet.decrypt_lwe(pos, aligned + offset, result[i]); if (err) { return StreamStringError() << "cannot decrypt result #" << i << ", err:" << err; } } + return result; } @@ -53,8 +53,7 @@ DynamicLambda::load(std::shared_ptr module, DynamicLambda lambda; lambda.module = module; // prevent module and library handler from being destroyed - lambda.func = - (void *(*)(void *, ...))dlsym(module->libraryHandle, funcName.c_str()); + lambda.func = dlsym(module->libraryHandle, funcName.c_str()); if (auto err = dlerror()) { return StreamStringError("Cannot open lambda: ") << err; @@ -93,13 +92,13 @@ llvm::Expected invoke(DynamicLambda &lambda, return StreamStringError("the function doesn't return a scalar"); } // Scalar encrypted result - auto fCasted = (LweCiphertext_u64 * (*)(void *...))(lambda.func); - ; - LweCiphertext_u64 *lweResult = + auto fCasted = (MemRefDescriptor<1>(*)(void *...))(lambda.func); + MemRefDescriptor<1> lweResult = mlir::concretelang::call(fCasted, args.preparedArgs); uint64_t decryptedResult; - if (auto err = lambda.keySet->decrypt_lwe(0, lweResult, decryptedResult)) { + if (auto err = + lambda.keySet->decrypt_lwe(0, lweResult.aligned, decryptedResult)) { return std::move(err); } return decryptedResult; @@ -112,15 +111,15 @@ DynamicLambda::invokeMemRefDecriptor(const Arguments &args) { if (output.shape.size == 0) { return StreamStringError("the function doesn't return a tensor"); } - if (output.shape.dimensions.size() != Rank) { + if (output.shape.dimensions.size() != Rank - 1) { return StreamStringError("the function doesn't return a tensor of rank ") - << Rank; + << Rank - 1; } // Tensor encrypted result auto fCasted = (MemRefDescriptor(*)(void *...))(func); auto encryptedResult = mlir::concretelang::call(fCasted, args.preparedArgs); - for (size_t dim = 0; dim < Rank; dim++) { + for (size_t dim = 0; dim < Rank - 1; dim++) { size_t actual_size = encryptedResult.sizes[dim]; size_t expected_size = output.shape.dimensions[dim]; if (actual_size != expected_size) { @@ -134,35 +133,32 @@ DynamicLambda::invokeMemRefDecriptor(const Arguments &args) { template <> llvm::Expected> invoke>(DynamicLambda &lambda, const Arguments &args) { - auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<1>(args); - if (!encryptedResultOrErr) { - return encryptedResultOrErr.takeError(); - } - auto &encryptedResult = encryptedResultOrErr.get(); - auto &keySet = lambda.keySet; - return decryptSlice(encryptedResult.aligned, *keySet, encryptedResult.offset, - encryptedResult.sizes[0], encryptedResult.strides[0]); -} - -template <> -llvm::Expected>> -invoke>>(DynamicLambda &lambda, - const Arguments &args) { auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<2>(args); if (!encryptedResultOrErr) { return encryptedResultOrErr.takeError(); } auto &encryptedResult = encryptedResultOrErr.get(); auto &keySet = lambda.keySet; + return decryptSlice(*keySet, encryptedResult.aligned, + encryptedResult.sizes[0]); +} + +template <> +llvm::Expected>> +invoke>>(DynamicLambda &lambda, + const Arguments &args) { + auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args); + if (!encryptedResultOrErr) { + return encryptedResultOrErr.takeError(); + } + auto &encryptedResult = encryptedResultOrErr.get(); std::vector> result; result.reserve(encryptedResult.sizes[0]); for (size_t i = 0; i < encryptedResult.sizes[0]; i++) { - // TODO : strides - int offset = encryptedResult.offset + i * encryptedResult.sizes[1]; - auto slice = - decryptSlice(encryptedResult.aligned, *keySet, offset, - encryptedResult.sizes[1], encryptedResult.strides[1]); + int offset = encryptedResult.offset + i * encryptedResult.strides[1]; + auto slice = decryptSlice(*lambda.keySet, encryptedResult.aligned + offset, + encryptedResult.sizes[1]); if (!slice) { return StreamStringError(llvm::toString(slice.takeError())); } @@ -175,7 +171,7 @@ template <> llvm::Expected>>> invoke>>>(DynamicLambda &lambda, const Arguments &args) { - auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<3>(args); + auto encryptedResultOrErr = lambda.invokeMemRefDecriptor<4>(args); if (!encryptedResultOrErr) { return encryptedResultOrErr.takeError(); } @@ -188,13 +184,10 @@ invoke>>>(DynamicLambda &lambda, std::vector> result1; result1.reserve(encryptedResult.sizes[1]); for (size_t j = 0; j < encryptedResult.sizes[1]; j++) { - // TODO : strides - int offset = encryptedResult.offset + - i * encryptedResult.sizes[1] * encryptedResult.sizes[2] + - j * encryptedResult.sizes[2]; - auto slice = - decryptSlice(encryptedResult.aligned, *keySet, offset, - encryptedResult.sizes[2], encryptedResult.strides[2]); + int offset = encryptedResult.offset + (i * encryptedResult.sizes[1] + j) * + encryptedResult.strides[1]; + auto slice = decryptSlice(*keySet, encryptedResult.aligned + offset, + encryptedResult.sizes[2]); if (!slice) { return StreamStringError(llvm::toString(slice.takeError())); }