enhance(runtime): Fix the runtime tools to handle ciphertext bufferization and the new Concrete compiler bufferized API
Committed by Quentin Bourgerie
Parent: 8a9cce64e3 · Commit: 4e8a9d1077
@@ -12,9 +12,9 @@ pip install pybind11
 Build concrete library:
 
 ```sh
-git clone https://github.com/zama-ai/concrete
-cd concrete
-git checkout feature/core_c_api
+git clone https://github.com/zama-ai/concrete_internal
+cd concrete_internal
+git checkout compiler_c_api
 cd concrete-ffi
 RUSTFLAGS="-C target-cpu=native" cargo build --release
 ```

@@ -23,7 +23,7 @@ Generate the compiler build system, in the `build` directory
 
 ```sh
 export LLVM_PROJECT="PATH_TO_LLVM_PROJECT"
-export CONCRETE_PROJECT="PATH_TO_CONCRETE_PROJECT"
+export CONCRETE_PROJECT="PATH_TO_CONCRETE_INTERNAL_PROJECT"
 make build-initialized
 ```
@@ -37,16 +37,36 @@ public:
 
   // isInputEncrypted return true if the input at the given pos is encrypted.
   bool isInputEncrypted(size_t pos);
-  // allocate a lwe ciphertext for the argument at argPos.
-  llvm::Error allocate_lwe(size_t argPos, LweCiphertext_u64 **ciphertext);
+
+  // getInputLweSecretKeyParam returns the parameters of the lwe secret key for
+  // the input at the given `pos`.
+  // The input must be encrypted.
+  LweSecretKeyParam getInputLweSecretKeyParam(size_t pos) {
+    auto gate = inputGate(pos);
+    auto inputSk = this->secretKeys.find(gate.encryption->secretKeyID);
+    return inputSk->second.first;
+  }
+
+  // getOutputLweSecretKeyParam returns the parameters of the lwe secret key
+  // for the given output.
+  LweSecretKeyParam getOutputLweSecretKeyParam(size_t pos) {
+    auto gate = outputGate(pos);
+    auto outputSk = this->secretKeys.find(gate.encryption->secretKeyID);
+    return outputSk->second.first;
+  }
+
+  // allocate a lwe ciphertext buffer for the argument at argPos, set the size
+  // of the allocated buffer.
+  llvm::Error allocate_lwe(size_t argPos, uint64_t **ciphertext,
+                           uint64_t &size);
 
   // encrypt the input to the ciphertext for the argument at argPos.
-  llvm::Error encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
-                          uint64_t input);
+  llvm::Error encrypt_lwe(size_t argPos, uint64_t *ciphertext, uint64_t input);
 
   // isOutputEncrypted return true if the output at the given pos is encrypted.
   bool isOutputEncrypted(size_t pos);
   // decrypt the ciphertext to the output for the argument at argPos.
-  llvm::Error decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
+  llvm::Error decrypt_lwe(size_t argPos, uint64_t *ciphertext,
                           uint64_t &output);
 
   size_t numInputs() { return inputs.size(); }
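Taken together, the new KeySet entry points replace opaque `LweCiphertext_u64` handles with plain `uint64_t` buffers. A minimal sketch of the intended round trip through this API (hypothetical positions and values, error handling abbreviated; not code from this commit):

```cpp
#include <cstdlib>

// Sketch: encrypt argument 0, later decrypt output 0, with the buffer API.
llvm::Error roundTrip(mlir::concretelang::KeySet &keySet) {
  uint64_t *ct = nullptr;
  uint64_t ctSize = 0;
  // Allocates a flat buffer of lweDimension + 1 words and reports its size.
  if (auto err = keySet.allocate_lwe(0, &ct, ctSize))
    return err;
  // Encodes and encrypts the cleartext value 5 into the buffer.
  if (auto err = keySet.encrypt_lwe(0, ct, 5))
    return err;
  // ... hand (ct, ctSize) to the compiled circuit as a memref ...
  uint64_t result = 0;
  if (auto err = keySet.decrypt_lwe(0, ct, result))
    return err;
  free(ct); // allocate_lwe now hands out malloc'ed memory
  return llvm::Error::success();
}
```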
@@ -20,10 +20,9 @@ typedef struct RuntimeContext {
   std::map<std::string, LweBootstrapKey_u64 *> bsk;
 
   ~RuntimeContext() {
-    int err;
     for (const auto &key : bsk) {
       if (key.first != "_concretelang_base_context_bsk")
-        free_lwe_bootstrap_key_u64(&err, key.second);
+        free_lwe_bootstrap_key_u64(key.second);
     }
   }
 } RuntimeContext;
@@ -8,10 +8,48 @@
 
 #include "concrete-ffi.h"
 
-ForeignPlaintextList_u64 *
-runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated,
-                                   uint64_t *aligned, uint64_t offset,
-                                   uint64_t size_dim0, uint64_t stride_dim0,
-                                   uint64_t size, uint32_t precision);
+struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
+    uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
+    uint64_t stride, uint32_t precision);
+
+void memref_add_lwe_ciphertexts_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned,
+    uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride);
+
+void memref_add_plaintext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t plaintext);
+
+void memref_mul_cleartext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t cleartext);
+
+void memref_negate_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride);
+
+void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
+                              uint64_t *out_allocated, uint64_t *out_aligned,
+                              uint64_t out_offset, uint64_t out_size,
+                              uint64_t out_stride, uint64_t *ct0_allocated,
+                              uint64_t *ct0_aligned, uint64_t ct0_offset,
+                              uint64_t ct0_size, uint64_t ct0_stride);
+
+void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
+                              uint64_t *out_allocated, uint64_t *out_aligned,
+                              uint64_t out_offset, uint64_t out_size,
+                              uint64_t out_stride, uint64_t *ct0_allocated,
+                              uint64_t *ct0_aligned, uint64_t ct0_offset,
+                              uint64_t ct0_size, uint64_t ct0_stride,
+                              struct GlweCiphertext_u64 *accumulator);
 
 #endif
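These wrappers follow the MLIR ranked-memref calling convention: each `memref<Nxi64>` operand is flattened into five scalar arguments, `(allocated, aligned, offset, size, stride)`. A minimal sketch of calling one of them directly from C++, with an illustrative LWE size (not code from this commit):

```cpp
#include <cstdint>
#include <vector>

void addTwoCiphertexts() {
  const uint64_t lweSize = 601; // illustrative: lweDimension 600, plus body
  std::vector<uint64_t> out(lweSize), ct0(lweSize), ct1(lweSize);
  // Each memref<601xi64> expands to (allocated, aligned, offset, size, stride).
  memref_add_lwe_ciphertexts_u64(
      out.data(), out.data(), /*out_offset=*/0, lweSize, /*out_stride=*/1,
      ct0.data(), ct0.data(), /*ct0_offset=*/0, lweSize, /*ct0_stride=*/1,
      ct1.data(), ct1.data(), /*ct1_offset=*/0, lweSize, /*ct1_stride=*/1);
}
```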
@@ -96,13 +96,13 @@ public:
   // Store the values of outputs
   std::vector<void *> outputs;
   // Store the input gates description and the offset of the argument.
-  std::vector<std::tuple<CircuitGate, size_t /*offet*/>> inputGates;
+  std::vector<std::tuple<CircuitGate, size_t /*offset*/>> inputGates;
   // Store the outputs gates description and the offset of the argument.
-  std::vector<std::tuple<CircuitGate, size_t /*offet*/>> outputGates;
+  std::vector<std::tuple<CircuitGate, size_t /*offset*/>> outputGates;
   // Store allocated lwe ciphertexts (for free)
-  std::vector<LweCiphertext_u64 *> allocatedCiphertexts;
+  std::vector<uint64_t *> allocatedCiphertexts;
   // Store buffers of ciphertexts
-  std::vector<LweCiphertext_u64 **> ciphertextBuffers;
+  std::vector<uint64_t *> ciphertextBuffers;
 
   KeySet &keySet;
   RuntimeContext context;
@@ -6,31 +6,20 @@
 #include "concretelang/ClientLib/KeySet.h"
 #include "concretelang/Support/Error.h"
 
-#define CAPI_ERR_TO_LLVM_ERROR(s, msg)                                        \
-  {                                                                           \
-    int err;                                                                  \
-    s;                                                                        \
-    if (err != 0) {                                                           \
-      return llvm::make_error<llvm::StringError>(                             \
-          msg, llvm::inconvertibleErrorCode());                               \
-    }                                                                         \
-  }
-
 namespace mlir {
 namespace concretelang {
 
 KeySet::~KeySet() {
-  int err;
   for (auto it : secretKeys) {
-    free_lwe_secret_key_u64(&err, it.second.second);
+    free_lwe_secret_key_u64(it.second.second);
   }
   for (auto it : bootstrapKeys) {
-    free_lwe_bootstrap_key_u64(&err, it.second.second);
+    free_lwe_bootstrap_key_u64(it.second.second);
   }
   for (auto it : keyswitchKeys) {
-    free_lwe_keyswitch_key_u64(&err, it.second.second);
+    free_lwe_keyswitch_key_u64(it.second.second);
   }
-  free_encryption_generator(&err, encryptionRandomGenerator);
+  free_encryption_generator(encryptionRandomGenerator);
 }
 
 llvm::Expected<std::unique_ptr<KeySet>>
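For context on what the deleted macro did: it declared a local `err`, evaluated the wrapped C API call `s`, and turned a non-zero error code into an `llvm::Error`. Expanding one of the call sites removed below makes the before/after pattern clear (a hand expansion, not code from the commit):

```cpp
// Old: CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
//                             "cannot free random generator");
// expanded to:
{
  int err;
  free_secret_generator(&err, generator);
  if (err != 0) {
    return llvm::make_error<llvm::StringError>(
        "cannot free random generator", llvm::inconvertibleErrorCode());
  }
}
// New: the bufferized concrete-ffi entry points no longer report errors
// through an out-parameter, so the call becomes simply:
// free_secret_generator(generator);
```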
@@ -97,10 +86,8 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
     }
   }
 
-  CAPI_ERR_TO_LLVM_ERROR(
-      this->encryptionRandomGenerator =
-          allocate_encryption_generator(&err, seed_msb, seed_lsb),
-      "cannot allocate encryption generator");
+  this->encryptionRandomGenerator =
+      allocate_encryption_generator(seed_msb, seed_lsb);
 
   return llvm::Error::success();
 }

@@ -108,13 +95,11 @@ llvm::Error KeySet::setupEncryptionMaterial(ClientParameters &params,
 llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
                                            uint64_t seed_msb,
                                            uint64_t seed_lsb) {
 
   {
     // Generate LWE secret keys
     SecretRandomGenerator *generator;
-    CAPI_ERR_TO_LLVM_ERROR(
-        generator = allocate_secret_generator(&err, seed_msb, seed_lsb),
-        "cannot allocate random generator");
+    generator = allocate_secret_generator(seed_msb, seed_lsb);
     for (auto secretKeyParam : params.secretKeys) {
       auto e = this->generateSecretKey(secretKeyParam.first,
                                        secretKeyParam.second, generator);
@@ -122,14 +107,12 @@ llvm::Error KeySet::generateKeysFromParams(ClientParameters &params,
         return std::move(e);
       }
     }
-    CAPI_ERR_TO_LLVM_ERROR(free_secret_generator(&err, generator),
-                           "cannot free random generator");
+    free_secret_generator(generator);
   }
   // Allocate the encryption random generator
-  CAPI_ERR_TO_LLVM_ERROR(
-      this->encryptionRandomGenerator =
-          allocate_encryption_generator(&err, seed_msb, seed_lsb),
-      "cannot allocate encryption generator");
+  this->encryptionRandomGenerator =
+      allocate_encryption_generator(seed_msb, seed_lsb);
   // Generate bootstrap and keyswitch keys
   {
     for (auto bootstrapKeyParam : params.bootstrapKeys) {

@@ -170,12 +153,9 @@ llvm::Error KeySet::generateSecretKey(LweSecretKeyID id,
                                       LweSecretKeyParam param,
                                       SecretRandomGenerator *generator) {
   LweSecretKey_u64 *sk;
-  CAPI_ERR_TO_LLVM_ERROR(
-      sk = allocate_lwe_secret_key_u64(&err, {param.size + 1}),
-      "cannot allocate secret key");
+  sk = allocate_lwe_secret_key_u64({param.size});
 
-  CAPI_ERR_TO_LLVM_ERROR(fill_lwe_secret_key_u64(&err, sk, generator),
-                         "cannot fill secret key with random generator");
+  fill_lwe_secret_key_u64(sk, generator);
 
   secretKeys[id] = {param, sk};
@@ -207,11 +187,9 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
 
   uint64_t polynomialSize = total_dimension / param.glweDimension;
 
-  CAPI_ERR_TO_LLVM_ERROR(
-      bsk = allocate_lwe_bootstrap_key_u64(
-          &err, {param.level}, {param.baseLog}, {param.glweDimension + 1},
-          {inputSk->second.first.size + 1}, {polynomialSize}),
-      "cannot allocate bootstrap key");
+  bsk = allocate_lwe_bootstrap_key_u64(
+      {param.level}, {param.baseLog}, {param.glweDimension},
+      {inputSk->second.first.size}, {polynomialSize});
 
   // Store the bootstrap key
   bootstrapKeys[id] = {param, bsk};

@@ -219,23 +197,16 @@ llvm::Error KeySet::generateBootstrapKey(BootstrapKeyID id,
   // Convert the output lwe key to glwe key
   GlweSecretKey_u64 *glwe_sk;
 
-  CAPI_ERR_TO_LLVM_ERROR(
-      glwe_sk = allocate_glwe_secret_key_u64(&err, {param.glweDimension + 1},
-                                             {polynomialSize}),
-      "cannot allocate glwe key for initiliazation of bootstrap key");
+  glwe_sk =
+      allocate_glwe_secret_key_u64({param.glweDimension}, {polynomialSize});
 
-  CAPI_ERR_TO_LLVM_ERROR(fill_glwe_secret_key_with_lwe_secret_key_u64(
-                             &err, glwe_sk, outputSk->second.second),
-                         "cannot fill glwe key with big key");
+  fill_glwe_secret_key_with_lwe_secret_key_u64(glwe_sk,
+                                               outputSk->second.second);
 
   // Initialize the bootstrap key
-  CAPI_ERR_TO_LLVM_ERROR(
-      fill_lwe_bootstrap_key_u64(&err, bsk, inputSk->second.second, glwe_sk,
-                                 generator, {param.variance}),
-      "cannot fill bootstrap key");
-  CAPI_ERR_TO_LLVM_ERROR(
-      free_glwe_secret_key_u64(&err, glwe_sk),
-      "cannot free glwe key for initiliazation of bootstrap key")
+  fill_lwe_bootstrap_key_u64(bsk, inputSk->second.second, glwe_sk, generator,
+                             {param.variance});
+  free_glwe_secret_key_u64(glwe_sk);
   return llvm::Error::success();
 }
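A pattern that repeats through these key-generation hunks: the old C API took LWE/GLWE *sizes* (`dimension + 1`, mask plus body), while the new API takes *dimensions*, so every `{param.size + 1}` and `{param.glweDimension + 1}` loses its `+ 1`. The raw buffer length, by contrast, still needs the body word (restated as a sketch, not code from the commit):

```cpp
#include <cstdint>

// The allocation API now takes the dimension; the +1 is implicit there,
// but a raw ciphertext buffer must still hold dimension + 1 words.
uint64_t ciphertextWords(uint64_t lweDimension) {
  return lweDimension + 1; // mask coefficients plus the body
}
```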
@@ -257,33 +228,32 @@ llvm::Error KeySet::generateKeyswitchKey(KeyswitchKeyID id,
   }
   // Allocate the keyswitch key
   LweKeyswitchKey_u64 *ksk;
-  CAPI_ERR_TO_LLVM_ERROR(
-      ksk = allocate_lwe_keyswitch_key_u64(&err, {param.level}, {param.baseLog},
-                                           {inputSk->second.first.size + 1},
-                                           {outputSk->second.first.size + 1}),
-      "cannot allocate keyswitch key");
+  ksk = allocate_lwe_keyswitch_key_u64({param.level}, {param.baseLog},
+                                       {inputSk->second.first.size},
+                                       {outputSk->second.first.size});
 
   // Store the keyswitch key
   keyswitchKeys[id] = {param, ksk};
   // Initialize the keyswitch key
-  CAPI_ERR_TO_LLVM_ERROR(
-      fill_lwe_keyswitch_key_u64(&err, ksk, inputSk->second.second,
-                                 outputSk->second.second, generator,
-                                 {param.variance}),
-      "cannot fill bootsrap key");
+  fill_lwe_keyswitch_key_u64(ksk, inputSk->second.second,
+                             outputSk->second.second, generator,
+                             {param.variance});
   return llvm::Error::success();
 }
 
-llvm::Error KeySet::allocate_lwe(size_t argPos,
-                                 LweCiphertext_u64 **ciphertext) {
+llvm::Error KeySet::allocate_lwe(size_t argPos, uint64_t **ciphertext,
+                                 uint64_t &size) {
   if (argPos >= inputs.size()) {
     return llvm::make_error<llvm::StringError>(
         "allocate_lwe position of argument is too high",
         llvm::inconvertibleErrorCode());
   }
   auto inputSk = inputs[argPos];
-  CAPI_ERR_TO_LLVM_ERROR(*ciphertext = allocate_lwe_ciphertext_u64(
-                             &err, {std::get<1>(inputSk).size + 1}),
-                         "cannot allocate ciphertext");
+  size = std::get<1>(inputSk).size + 1;
+  *ciphertext = (uint64_t *)malloc(sizeof(uint64_t) * size);
   return llvm::Error::success();
 }
@@ -297,7 +267,7 @@ bool KeySet::isOutputEncrypted(size_t argPos) {
          std::get<0>(outputs[argPos]).encryption.hasValue();
 }
 
-llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
+llvm::Error KeySet::encrypt_lwe(size_t argPos, uint64_t *ciphertext,
                                 uint64_t input) {
   if (argPos >= inputs.size()) {
     return llvm::make_error<llvm::StringError>(

@@ -311,19 +281,15 @@ llvm::Error KeySet::encrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
         llvm::inconvertibleErrorCode());
   }
   // Encode - TODO we could check if the input value is in the right range
-  Plaintext_u64 plaintext = {
-      input << (64 -
-                (std::get<0>(inputSk).encryption->encoding.precision + 1))};
-  // Encrypt
-  CAPI_ERR_TO_LLVM_ERROR(
-      encrypt_lwe_u64(&err, std::get<2>(inputSk), ciphertext, plaintext,
-                      encryptionRandomGenerator,
-                      {std::get<0>(inputSk).encryption->variance}),
-      "cannot encrypt");
+  uint64_t plaintext =
+      input << (64 - (std::get<0>(inputSk).encryption->encoding.precision + 1));
+  encrypt_lwe_u64(std::get<2>(inputSk), ciphertext, plaintext,
+                  encryptionRandomGenerator,
+                  {std::get<0>(inputSk).encryption->variance});
   return llvm::Error::success();
 }
 
-llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
+llvm::Error KeySet::decrypt_lwe(size_t argPos, uint64_t *ciphertext,
                                 uint64_t &output) {
 
   if (argPos >= outputs.size()) {
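The encoding shifts the message into the top bits of the 64-bit word, reserving one padding bit above the `precision` message bits. A worked restatement of the formula, assuming precision p = 3 (not code from the commit):

```cpp
#include <cstdint>

uint64_t encode(uint64_t input, uint64_t precision) {
  // Place the message in the most significant bits, one bit below the top.
  return input << (64 - (precision + 1));
}
// encode(5, 3) == 5ULL << 60: bits 60..62 carry the value 5,
// and bit 63 is left as padding.
```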
@@ -337,14 +303,10 @@ llvm::Error KeySet::decrypt_lwe(size_t argPos, LweCiphertext_u64 *ciphertext,
         "decrypt_lwe: the positional argument is not encrypted",
         llvm::inconvertibleErrorCode());
   }
-  // Decrypt
-  Plaintext_u64 plaintext = {0};
-  CAPI_ERR_TO_LLVM_ERROR(
-      decrypt_lwe_u64(&err, std::get<2>(outputSk), ciphertext, &plaintext),
-      "cannot decrypt");
+  uint64_t plaintext = decrypt_lwe_u64(std::get<2>(outputSk), ciphertext);
   // Decode
   size_t precision = std::get<0>(outputSk).encryption->encoding.precision;
-  output = plaintext._0 >> (64 - precision - 2);
+  output = plaintext >> (64 - precision - 2);
   size_t carry = output % 2;
   output = ((output >> 1) + carry) % (1 << (precision + 1));
   return llvm::Error::success();
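Decoding inverts the encoding with a round-to-nearest step: it keeps one extra bit below the message as a carry, rounds, and reduces modulo the message space. The same arithmetic as the new code above, packaged as a standalone sketch:

```cpp
#include <cstdint>

uint64_t decode(uint64_t plaintext, uint64_t precision) {
  // Keep the message bits plus one rounding bit below them.
  uint64_t output = plaintext >> (64 - precision - 2);
  uint64_t carry = output % 2; // the rounding bit
  // Round to nearest, then reduce modulo 2^(precision + 1).
  return ((output >> 1) + carry) % (1ull << (precision + 1));
}
// With precision 3, a noisy plaintext near 5ULL << 60 decodes back to 5.
```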
@@ -142,13 +142,10 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   auto contextType =
       mlir::concretelang::Concrete::ContextType::get(rewriter.getContext());
 
-  auto errType = mlir::IndexType::get(rewriter.getContext());
-
   // Insert forward declaration of allocate lwe ciphertext
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 rewriter.getIndexType(),
                                             },

@@ -163,7 +160,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 genericLweCiphertextType,
                                                 genericLweCiphertextType,
                                                 genericLweCiphertextType,

@@ -179,7 +175,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 genericLweCiphertextType,
                                                 genericLweCiphertextType,
                                                 genericPlaintextType,

@@ -195,7 +190,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 genericLweCiphertextType,
                                                 genericLweCiphertextType,
                                                 genericCleartextType,

@@ -211,7 +205,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(
         rewriter.getContext(),
-        {errType, genericLweCiphertextType, genericLweCiphertextType}, {});
+        {genericLweCiphertextType, genericLweCiphertextType}, {});
     if (insertForwardDeclaration(op, rewriter, "negate_lwe_ciphertext_u64",
                                  funcType)
             .failed()) {

@@ -231,7 +225,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 genericBSKType,
                                                 genericLweCiphertextType,
                                                 genericLweCiphertextType,

@@ -256,7 +249,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 // ksk
                                                 genericKSKType,
                                                 // output ct

@@ -274,7 +266,6 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
                                             {
-                                                errType,
                                                 rewriter.getI32Type(),
                                                 rewriter.getI32Type(),
                                             },

@@ -287,9 +278,9 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   }
   // Insert forward declaration of the alloc_plaintext_list function
   {
-    auto funcType = mlir::FunctionType::get(rewriter.getContext(),
-                                            {errType, rewriter.getI32Type()},
-                                            {genericPlaintextListType});
+    auto funcType =
+        mlir::FunctionType::get(rewriter.getContext(), {rewriter.getI32Type()},
+                                {genericPlaintextListType});
     if (insertForwardDeclaration(op, rewriter, "allocate_plaintext_list_u64",
                                  funcType)
             .failed()) {

@@ -300,7 +291,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   {
     auto funcType = mlir::FunctionType::get(
         rewriter.getContext(),
-        {errType, genericPlaintextListType, genericForeignPlaintextList}, {});
+        {genericPlaintextListType, genericForeignPlaintextList}, {});
     if (insertForwardDeclaration(
             op, rewriter, "fill_plaintext_list_with_expansion_u64", funcType)
             .failed()) {

@@ -310,7 +301,7 @@ mlir::LogicalResult insertForwardDeclarations(mlir::Operation *op,
   // Insert forward declaration of the add_plaintext_list_glwe function
   {
     auto funcType = mlir::FunctionType::get(rewriter.getContext(),
-                                            {errType, genericGlweCiphertextType,
+                                            {genericGlweCiphertextType,
                                              genericGlweCiphertextType,
                                              genericPlaintextListType},
                                             {});
@@ -356,24 +347,20 @@ struct ConcreteOpToConcreteCAPICallPattern : public mlir::OpRewritePattern<Op> {
         resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
     // Replace the operation with a call to the `funcName`
     {
-      // Create the err value
-      auto errOp = rewriter.create<mlir::arith::ConstantOp>(
-          op.getLoc(), rewriter.getIndexAttr(0));
       // Get the size from the dimension
       int64_t lweDimension = lweResultType.getDimension();
-      int64_t lweSize = lweDimension + 1;
-      mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
-          op.getLoc(), rewriter.getIndexAttr(lweSize));
+      mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
+          op.getLoc(), rewriter.getIndexAttr(lweDimension));
       // Add the call to the allocation
-      mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
+      mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
       auto allocGeneric = rewriter.create<mlir::CallOp>(
           op.getLoc(), allocName,
           getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
       // Construct operands for the operation.
-      // errOp doesn't need to be casted to something generic, allocGeneric
-      // already is. All the rest will be converted if needed
-      mlir::SmallVector<mlir::Value, 4> newOperands{errOp,
-                                                    allocGeneric.getResult(0)};
+      // allocGeneric is already generic; all the rest will be converted if
+      // needed
+      mlir::SmallVector<mlir::Value, 4> newOperands{allocGeneric.getResult(0)};
       for (mlir::Value operand : op->getOperands()) {
         mlir::Type operandType = operand.getType();
         mlir::Type castedType = getGenericType(operandType);
@@ -420,16 +407,13 @@ struct ConcreteZeroOpPattern
     mlir::Type resultType = op->getResultTypes().front();
     auto lweResultType =
         resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>();
-    // Create the err value
-    auto errOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(0));
     // Get the size from the dimension
     int64_t lweDimension = lweResultType.getDimension();
-    int64_t lweSize = lweDimension + 1;
-    mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(lweSize));
+    mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getIndexAttr(lweDimension));
     // Allocate a fresh new ciphertext
-    mlir::SmallVector<mlir::Value> allocOperands{errOp, lweSizeOp};
+    mlir::SmallVector<mlir::Value> allocOperands{lweDimensionOp};
     auto allocGeneric = rewriter.create<mlir::CallOp>(
         op.getLoc(), "allocate_lwe_ciphertext_u64",
         getGenericLweCiphertextType(rewriter.getContext()), allocOperands);
@@ -506,7 +490,6 @@ struct GlweFromTableOpPattern
   matchAndRewrite(mlir::concretelang::Concrete::GlweFromTable op,
                   mlir::PatternRewriter &rewriter) const override {
     ConcreteToConcreteCAPITypeConverter typeConverter;
-    auto errType = mlir::IndexType::get(rewriter.getContext());
 
     // TODO: move this to insertForwardDeclarations
     // issue: can't define function with tensor<*xtype> that accept ranked

@@ -516,28 +499,22 @@ struct GlweFromTableOpPattern
     {
       auto funcType = mlir::FunctionType::get(
           rewriter.getContext(),
-          {errType, op->getOperandTypes().front(), rewriter.getI64Type(),
-           rewriter.getI32Type()},
+          {op->getOperandTypes().front(), rewriter.getI32Type()},
           {getGenericForeignPlaintextListType(rewriter.getContext())});
-      if (insertForwardDeclaration(
-              op, rewriter, "runtime_foreign_plaintext_list_u64", funcType)
+      if (insertForwardDeclaration(op, rewriter,
+                                   "memref_runtime_foreign_plaintext_list_u64",
+                                   funcType)
               .failed()) {
         return mlir::failure();
       }
     }
 
-    auto errOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(0));
-    // Get the size from the dimension
-    int64_t glweDimension =
-        op->getAttr("glweDimension").cast<mlir::IntegerAttr>().getInt();
-    int64_t glweSize = glweDimension + 1;
-    mlir::Value glweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getI32IntegerAttr(glweSize));
     // allocate two glwe to build accumulator
     auto polySizeOp = rewriter.create<mlir::arith::ConstantOp>(
         op.getLoc(), op->getAttr("polynomialSize"));
-    mlir::SmallVector<mlir::Value> allocGlweOperands{errOp, glweSizeOp,
+    auto glweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), op->getAttr("glweDimension"));
+    mlir::SmallVector<mlir::Value> allocGlweOperands{glweDimensionOp,
                                                      polySizeOp};
     // first accumulator would replace the op since it's the returned value
     auto accumulatorOp = rewriter.replaceOpWithNewOp<mlir::CallOp>(

@@ -548,8 +525,7 @@ struct GlweFromTableOpPattern
         op.getLoc(), "allocate_glwe_ciphertext_u64",
         getGenericGlweCiphertextType(rewriter.getContext()), allocGlweOperands);
     // allocate plaintext list
-    mlir::SmallVector<mlir::Value> allocPlaintextListOperands{errOp,
-                                                              polySizeOp};
+    mlir::SmallVector<mlir::Value> allocPlaintextListOperands{polySizeOp};
     auto plaintextListOp = rewriter.create<mlir::CallOp>(
         op.getLoc(), "allocate_plaintext_list_u64",
         getGenericPlaintextListType(rewriter.getContext()),

@@ -559,27 +535,23 @@ struct GlweFromTableOpPattern
         op->getOperandTypes().front().cast<mlir::RankedTensorType>();
     assert(rankedTensorType.getRank() == 1 &&
            "table lookup must be of a single dimension");
-    auto sizeOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(),
-        rewriter.getI64IntegerAttr(rankedTensorType.getDimSize(0)));
     auto precisionOp =
         rewriter.create<mlir::arith::ConstantOp>(op.getLoc(), op->getAttr("p"));
     mlir::SmallVector<mlir::Value> ForeignPlaintextListOperands{
-        errOp, op->getOperand(0), sizeOp, precisionOp};
+        op->getOperand(0), precisionOp};
     auto foreignPlaintextListOp = rewriter.create<mlir::CallOp>(
-        op.getLoc(), "runtime_foreign_plaintext_list_u64",
+        op.getLoc(), "memref_runtime_foreign_plaintext_list_u64",
        getGenericForeignPlaintextListType(rewriter.getContext()),
        ForeignPlaintextListOperands);
     // fill plaintext list
     mlir::SmallVector<mlir::Value> FillPlaintextListOperands{
-        errOp, plaintextListOp.getResult(0),
-        foreignPlaintextListOp.getResult(0)};
+        plaintextListOp.getResult(0), foreignPlaintextListOp.getResult(0)};
     rewriter.create<mlir::CallOp>(
         op.getLoc(), "fill_plaintext_list_with_expansion_u64",
         mlir::TypeRange({}), FillPlaintextListOperands);
     // add plaintext list and glwe to build final accumulator for pbs
     mlir::SmallVector<mlir::Value> AddPlaintextListGlweOperands{
-        errOp, accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
+        accumulatorOp.getResult(0), _accumulatorOp.getResult(0),
         plaintextListOp.getResult(0)};
     rewriter.create<mlir::CallOp>(
         op.getLoc(), "add_plaintext_list_glwe_ciphertext_u64",
@@ -626,18 +598,15 @@ struct ConcreteBootstrapLweOpPattern
   matchAndRewrite(mlir::concretelang::Concrete::BootstrapLweOp op,
                   mlir::PatternRewriter &rewriter) const override {
     auto resultType = op->getResultTypes().front();
-    auto errOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(0));
     // Get the size from the dimension
     int64_t outputLweDimension =
         resultType.cast<mlir::concretelang::Concrete::LweCiphertextType>()
             .getDimension();
-    int64_t outputLweSize = outputLweDimension + 1;
     mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(outputLweSize));
+        op.getLoc(), rewriter.getIndexAttr(outputLweDimension));
     // allocate the result lwe ciphertext, should be of a generic type, to cast
     // before return
-    mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
+    mlir::SmallVector<mlir::Value> allocLweCtOperands{lweSizeOp};
     auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
         op.getLoc(), "allocate_lwe_ciphertext_u64",
         getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);

@@ -662,7 +631,7 @@ struct ConcreteBootstrapLweOpPattern
                                       op.getOperand(1))
                            .getResult(0);
     mlir::SmallVector<mlir::Value> bootstrapOperands{
-        errOp, getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
+        getBskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
         lweToBootstrap, accumulator};
     rewriter.create<mlir::CallOp>(op.getLoc(), "bootstrap_lwe_u64",
                                   mlir::TypeRange({}), bootstrapOperands);
@@ -690,20 +659,17 @@ struct ConcreteKeySwitchLweOpPattern
   mlir::LogicalResult
   matchAndRewrite(mlir::concretelang::Concrete::KeySwitchLweOp op,
                   mlir::PatternRewriter &rewriter) const override {
-    auto errOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(0));
     // Get the size from the dimension
     int64_t lweDimension =
         op.getResult()
             .getType()
             .cast<mlir::concretelang::Concrete::LweCiphertextType>()
             .getDimension();
-    int64_t lweSize = lweDimension + 1;
-    mlir::Value lweSizeOp = rewriter.create<mlir::arith::ConstantOp>(
-        op.getLoc(), rewriter.getIndexAttr(lweSize));
+    mlir::Value lweDimensionOp = rewriter.create<mlir::arith::ConstantOp>(
+        op.getLoc(), rewriter.getIndexAttr(lweDimension));
     // allocate the result lwe ciphertext, should be of a generic type, to cast
     // before return
-    mlir::SmallVector<mlir::Value> allocLweCtOperands{errOp, lweSizeOp};
+    mlir::SmallVector<mlir::Value> allocLweCtOperands{lweDimensionOp};
     auto allocateGenericLweCtOp = rewriter.create<mlir::CallOp>(
         op.getLoc(), "allocate_lwe_ciphertext_u64",
         getGenericLweCiphertextType(rewriter.getContext()), allocLweCtOperands);

@@ -721,7 +687,7 @@ struct ConcreteKeySwitchLweOpPattern
                                       op.getOperand())
                            .getResult(0);
     mlir::SmallVector<mlir::Value> keyswitchOperands{
-        errOp, getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
+        getKskOp.getResult(0), allocateGenericLweCtOp.getResult(0),
         lweToKeyswitch};
     rewriter.create<mlir::CallOp>(op.getLoc(), "keyswitch_lwe_u64",
                                   mlir::TypeRange({}), keyswitchOperands);
@@ -14,9 +14,9 @@ if(CONCRETELANG_PARALLEL_EXECUTION_ENABLED)
   install(TARGETS DFRuntime EXPORT DFRuntime)
   install(EXPORT DFRuntime DESTINATION "./")
 
-  target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx)
+  target_link_libraries(ConcretelangRuntime Concrete pthread m dl HPX::hpx $<TARGET_OBJECTS:mlir_c_runner_utils>)
 else()
-  target_link_libraries(ConcretelangRuntime Concrete pthread m dl)
+  target_link_libraries(ConcretelangRuntime Concrete pthread m dl $<TARGET_OBJECTS:mlir_c_runner_utils>)
 endif()
 
 install(TARGETS ConcretelangRuntime EXPORT ConcretelangRuntime)
@@ -19,7 +19,6 @@ get_keyswitch_key(mlir::concretelang::RuntimeContext *context) {
 LweBootstrapKey_u64 *
 get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
 #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
-  int err;
   std::string threadName = hpx::get_thread_name();
   auto bskIt = context->bsk.find(threadName);
   if (bskIt == context->bsk.end()) {

@@ -27,10 +26,8 @@ get_bootstrap_key(mlir::concretelang::RuntimeContext *context) {
               .insert(std::pair<std::string, LweBootstrapKey_u64 *>(
                   threadName,
                   clone_lwe_bootstrap_key_u64(
-                      &err, context->bsk["_concretelang_base_context_bsk"])))
+                      context->bsk["_concretelang_base_context_bsk"])))
               .first;
-    if (err != 0)
-      fprintf(stderr, "Runtime: cloning bootstrap key failed.\n");
   }
 #else
   std::string baseName = "_concretelang_base_context_bsk";
@@ -1,20 +1,89 @@
 #include "concretelang/Runtime/wrappers.h"
+#include <assert.h>
 #include <stdio.h>
 
-ForeignPlaintextList_u64 *
-runtime_foreign_plaintext_list_u64(int *err, uint64_t *allocated,
-                                   uint64_t *aligned, uint64_t offset,
-                                   uint64_t size_dim0, uint64_t stride_dim0,
-                                   uint64_t size, uint32_t precision) {
-  if (stride_dim0 != 1) {
-    fprintf(stderr, "Runtime: stride not equal to 1, check "
-                    "runtime_foreign_plaintext_list_u64");
-  }
+struct ForeignPlaintextList_u64 *memref_runtime_foreign_plaintext_list_u64(
+    uint64_t *allocated, uint64_t *aligned, uint64_t offset, uint64_t size,
+    uint64_t stride, uint32_t precision) {
+
+  assert(stride == 1 && "Runtime: stride not equal to 1, check "
+                        "runtime_foreign_plaintext_list_u64");
+
   // Encode table values in u64
   uint64_t *encoded_table = malloc(size * sizeof(uint64_t));
   for (uint64_t i = 0; i < size; i++) {
     encoded_table[i] = (aligned + offset)[i] << (64 - precision - 1);
   }
-  return foreign_plaintext_list_u64(err, encoded_table, size);
+  return foreign_plaintext_list_u64(encoded_table, size);
+  // TODO: is it safe to free after creating plaintext_list?
 }
+
+void memref_add_lwe_ciphertexts_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t *ct1_allocated, uint64_t *ct1_aligned,
+    uint64_t ct1_offset, uint64_t ct1_size, uint64_t ct1_stride) {
+  assert(out_size == ct0_size && out_size == ct1_size &&
+         "size of lwe buffer are incompatible");
+  LweDimension lwe_dimension = {out_size - 1};
+  add_two_lwe_ciphertexts_u64(out_aligned + out_offset,
+                              ct0_aligned + ct0_offset,
+                              ct1_aligned + ct1_offset, lwe_dimension);
+}
+
+void memref_add_plaintext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t plaintext) {
+  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
+  LweDimension lwe_dimension = {out_size - 1};
+  add_plaintext_to_lwe_ciphertext_u64(out_aligned + out_offset,
+                                      ct0_aligned + ct0_offset, plaintext,
+                                      lwe_dimension);
+}
+
+void memref_mul_cleartext_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride, uint64_t cleartext) {
+  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
+  LweDimension lwe_dimension = {out_size - 1};
+  mul_cleartext_lwe_ciphertext_u64(out_aligned + out_offset,
+                                   ct0_aligned + ct0_offset, cleartext,
+                                   lwe_dimension);
+}
+
+void memref_negate_lwe_ciphertext_u64(
+    uint64_t *out_allocated, uint64_t *out_aligned, uint64_t out_offset,
+    uint64_t out_size, uint64_t out_stride, uint64_t *ct0_allocated,
+    uint64_t *ct0_aligned, uint64_t ct0_offset, uint64_t ct0_size,
+    uint64_t ct0_stride) {
+  assert(out_size == ct0_size && "size of lwe buffer are incompatible");
+  LweDimension lwe_dimension = {out_size - 1};
+  neg_lwe_ciphertext_u64(out_aligned + out_offset, ct0_aligned + ct0_offset,
+                         lwe_dimension);
+}
+
+void memref_keyswitch_lwe_u64(struct LweKeyswitchKey_u64 *keyswitch_key,
+                              uint64_t *out_allocated, uint64_t *out_aligned,
+                              uint64_t out_offset, uint64_t out_size,
+                              uint64_t out_stride, uint64_t *ct0_allocated,
+                              uint64_t *ct0_aligned, uint64_t ct0_offset,
+                              uint64_t ct0_size, uint64_t ct0_stride) {
+  bufferized_keyswitch_lwe_u64(keyswitch_key, out_aligned + out_offset,
+                               ct0_aligned + ct0_offset);
+}
+
+void memref_bootstrap_lwe_u64(struct LweBootstrapKey_u64 *bootstrap_key,
+                              uint64_t *out_allocated, uint64_t *out_aligned,
+                              uint64_t out_offset, uint64_t out_size,
+                              uint64_t out_stride, uint64_t *ct0_allocated,
+                              uint64_t *ct0_aligned, uint64_t ct0_offset,
+                              uint64_t ct0_size, uint64_t ct0_stride,
+                              struct GlweCiphertext_u64 *accumulator) {
+  bufferized_bootstrap_lwe_u64(bootstrap_key, out_aligned + out_offset,
+                               ct0_aligned + ct0_offset, accumulator);
+}
@@ -78,6 +78,12 @@ llvm::Error JITLambda::invoke(Argument &args) {
            << actualInputs << " arguments instead of " << expectedInputs;
   }
 
+// memref is a struct which is flattened to allocated and aligned pointers, an
+// offset, and two arrays of rank size for the sizes and strides.
+uint64_t numArgOfRankedMemrefCallingConvention(uint64_t rank) {
+  return 3 + 2 * rank;
+}
+
 JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
   // Setting the inputs
   auto numInputs = 0;
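The new helper spells out the flattened-memref convention used everywhere below: a rank-r memref lowers to two pointers and an offset, plus r sizes and r strides. A quick sanity check on the counts (hypothetical gates):

```cpp
// allocated ptr + aligned ptr + offset -> 3 scalars
// sizes[0..rank) + strides[0..rank)    -> 2 * rank scalars
//
// rank 1: an encrypted scalar, memref<lweSize x i64>          -> 5 arguments
// rank 3: an encrypted 4x5 tensor, memref<4x5x lweSize x i64> -> 9 arguments
static_assert(3 + 2 * 1 == 5, "rank-1 memref lowers to 5 scalars");
static_assert(3 + 2 * 3 == 9, "rank-3 memref lowers to 9 scalars");
```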
@@ -86,16 +92,20 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
     auto offset = numInputs;
     auto gate = keySet.inputGate(i);
     inputGates.push_back({gate, offset});
-    if (keySet.inputGate(i).shape.dimensions.empty()) {
+    if (gate.shape.dimensions.empty()) {
       // scalar gate
-      numInputs = numInputs + 1;
+      if (gate.encryption.hasValue()) {
+        // encrypted is a memref<lweSizexi64>
+        numInputs = numInputs + numArgOfRankedMemrefCallingConvention(1);
+      } else {
+        numInputs = numInputs + 1;
+      }
       continue;
     }
     // memref gate, as we follow the standard calling convention
-    numInputs = numInputs + 3;
-    // Offsets and strides are array of size N where N is the number of
-    // dimension of the tensor.
-    numInputs = numInputs + 2 * keySet.inputGate(i).shape.dimensions.size();
+    auto rank = keySet.inputGate(i).shape.dimensions.size() +
+                (keySet.isInputEncrypted(i) ? 1 /* for lwe size */ : 0);
+    numInputs = numInputs + numArgOfRankedMemrefCallingConvention(rank);
   }
   // Reserve for the context argument
   numInputs = numInputs + 1;

@@ -111,19 +121,21 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
     outputGates.push_back({gate, offset});
     if (gate.shape.dimensions.empty()) {
       // scalar gate
-      numOutputs = numOutputs + 1;
+      if (gate.encryption.hasValue()) {
+        // encrypted is a memref<lweSizexi64>
+        numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(1);
+      } else {
+        numOutputs = numOutputs + 1;
+      }
       continue;
     }
     // memref gate, as we follow the standard calling convention
-    numOutputs = numOutputs + 3;
-    // Offsets and strides are array of size N where N is the number of
-    // dimension of the tensor.
-    numOutputs =
-        numOutputs + 2 * keySet.outputGate(i).shape.dimensions.size();
+    auto rank = keySet.outputGate(i).shape.dimensions.size() +
+                (keySet.isOutputEncrypted(i) ? 1 /* for lwe size */ : 0);
+    numOutputs = numOutputs + numArgOfRankedMemrefCallingConvention(rank);
   }
   outputs = std::vector<void *>(numOutputs);
 }
 
 // The raw argument contains pointers to inputs and pointers to store the
 // results
 rawArg = std::vector<void *>(inputs.size() + outputs.size(), nullptr);
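Putting the counting logic together for a hypothetical circuit makes the new totals concrete (a sketch, not code from the commit):

```cpp
// Counting the raw JIT arguments for a circuit with three inputs:
uint64_t exampleInputCount() {
  auto memrefArgs = [](uint64_t rank) { return 3 + 2 * rank; };
  uint64_t n = 0;
  n += 1;             // input 0: clear scalar -> one plain slot
  n += memrefArgs(1); // input 1: encrypted scalar -> memref<lweSize x i64>
  n += memrefArgs(3); // input 2: encrypted 4x5 tensor -> rank 2 + 1 lwe dim
  n += 1;             // trailing runtime context argument
  return n;           // 1 + 5 + 9 + 1 = 16
}
```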
@@ -139,9 +151,8 @@ JITLambda::Argument::Argument(KeySet &keySet) : keySet(keySet) {
 }
 
 JITLambda::Argument::~Argument() {
-  int err;
   for (auto ct : allocatedCiphertexts) {
-    free_lwe_ciphertext_u64(&err, ct);
+    free(ct);
   }
   for (auto buffer : ciphertextBuffers) {
     free(buffer);
@@ -185,16 +196,31 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, uint64_t arg) {
     return llvm::Error::success();
   }
   // Else if is encrypted, allocate ciphertext and encrypt.
-  LweCiphertext_u64 *ctArg;
-  if (auto err = this->keySet.allocate_lwe(pos, &ctArg)) {
+  uint64_t *ctArg;
+  uint64_t ctSize;
+  if (auto err = this->keySet.allocate_lwe(pos, &ctArg, ctSize)) {
     return std::move(err);
   }
   allocatedCiphertexts.push_back(ctArg);
   if (auto err = this->keySet.encrypt_lwe(pos, ctArg, arg)) {
     return std::move(err);
   }
-  inputs[offset] = ctArg;
+  // memref calling convention
+  // allocated
+  inputs[offset] = nullptr;
+  // aligned
+  inputs[offset + 1] = ctArg;
+  // offset
+  inputs[offset + 2] = (void *)0;
+  // size
+  inputs[offset + 3] = (void *)ctSize;
+  // stride
+  inputs[offset + 4] = (void *)1;
   rawArg[offset] = &inputs[offset];
+  rawArg[offset + 1] = &inputs[offset + 1];
+  rawArg[offset + 2] = &inputs[offset + 2];
+  rawArg[offset + 3] = &inputs[offset + 3];
+  rawArg[offset + 4] = &inputs[offset + 4];
   return llvm::Error::success();
 }
@@ -279,17 +305,18 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
         llvm::inconvertibleErrorCode());
   }
 
-  // Allocate a buffer for ciphertexts.
-  auto ctBuffer = (LweCiphertext_u64 **)malloc(info.shape.size *
-                                               sizeof(LweCiphertext_u64 *));
+  // Allocate a buffer for ciphertexts, the size of the buffer is the number
+  // of elements of the tensor * the size of the lwe ciphertext
+  auto lweSize = keySet.getInputLweSecretKeyParam(pos).size + 1;
+  uint64_t *ctBuffer =
+      (uint64_t *)malloc(info.shape.size * lweSize * sizeof(uint64_t));
   ciphertextBuffers.push_back(ctBuffer);
-  // Allocate ciphertexts and encrypt
-  for (size_t i = 0; i < info.shape.size; i++) {
-    if (auto err = this->keySet.allocate_lwe(pos, &ctBuffer[i])) {
-      return std::move(err);
-    }
-    allocatedCiphertexts.push_back(ctBuffer[i]);
-    if (auto err = this->keySet.encrypt_lwe(pos, ctBuffer[i], data8[i])) {
+  // Encrypt ciphertexts
+  for (size_t i = 0, offset = 0; i < info.shape.size;
+       i++, offset += lweSize) {
+    if (auto err =
+            this->keySet.encrypt_lwe(pos, ctBuffer + offset, data8[i])) {
       return std::move(err);
     }
   }
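The buffer change here is the heart of the bufferization: a tensor of N encrypted values is now one contiguous allocation of N * lweSize words, with element i starting at offset i * lweSize, instead of N individually allocated ciphertexts. Sketched out (illustrative, not code from the commit):

```cpp
#include <cstdint>
#include <cstdlib>

// One flat allocation; ciphertext i lives at buf + i * lweSize.
uint64_t *allocPackedCiphertexts(uint64_t numElements, uint64_t lweSize) {
  return (uint64_t *)malloc(numElements * lweSize * sizeof(uint64_t));
}
```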
@@ -316,17 +343,27 @@ llvm::Error JITLambda::Argument::setArg(size_t pos, size_t width,
     rawArg[offset] = &inputs[offset];
     offset++;
   }
+  // If encrypted +1 for the lwe size rank
+  if (keySet.isInputEncrypted(pos)) {
+    inputs[offset] = (void *)(keySet.getInputLweSecretKeyParam(pos).size + 1);
+    rawArg[offset] = &inputs[offset];
+    offset++;
+  }
 
   // Set the stride for each dimension, equal to the product of the
   // following dimensions.
   int64_t stride = 1;
+
+  // If encrypted +1 set the stride for the lwe size rank
+  if (keySet.isInputEncrypted(pos)) {
+    inputs[offset + shape.size()] = (void *)stride;
+    rawArg[offset + shape.size()] = &inputs[offset];
+    stride *= keySet.getInputLweSecretKeyParam(pos).size + 1;
+  }
   for (ssize_t i = shape.size() - 1; i >= 0; i--) {
     inputs[offset + i] = (void *)stride;
     rawArg[offset + i] = &inputs[offset + i];
     stride *= shape[i];
   }
 
   offset += shape.size();
 
   return llvm::Error::success();
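Strides here are row-major: each dimension's stride is the product of all sizes inner to it, and for an encrypted tensor the LWE size acts as an extra innermost dimension. A worked example for an encrypted shape [4][5] with lweSize = 601 (illustrative values):

```cpp
#include <cstdint>
#include <vector>

// Row-major strides, innermost dimension first.
std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &sizes) {
  std::vector<int64_t> strides(sizes.size());
  int64_t stride = 1;
  for (int64_t i = (int64_t)sizes.size() - 1; i >= 0; i--) {
    strides[i] = stride;
    stride *= sizes[i];
  }
  return strides;
}
// The logical memref is memref<4x5x601xi64>:
// rowMajorStrides({4, 5, 601}) == {3005, 601, 1}
```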
@@ -349,7 +386,7 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, uint64_t &res) {
     return llvm::Error::success();
   }
   // Else if is encrypted, decrypt
-  LweCiphertext_u64 *ct = (LweCiphertext_u64 *)(outputs[offset]);
+  uint64_t *ct = (uint64_t *)(outputs[offset + 1]);
   if (auto err = this->keySet.decrypt_lwe(pos, ct, res)) {
     return std::move(err);
   }
@@ -463,8 +500,10 @@ llvm::Error JITLambda::Argument::getResult(size_t pos, void *res,
     }
   } else {
     // decrypt and fill the result buffer
-    for (size_t i = 0; i < numElements; i++) {
-      LweCiphertext_u64 *ct = ((LweCiphertext_u64 **)alignedBytes)[i];
+    auto lweSize = keySet.getOutputLweSecretKeyParam(pos).size + 1;
+
+    for (size_t i = 0, o = 0; i < numElements; i++, o += lweSize) {
+      uint64_t *ct = ((uint64_t *)alignedBytes) + o;
       if (auto err = this->keySet.decrypt_lwe(pos, ct, ((uint64_t *)res)[i])) {
         return std::move(err);
       }
@@ -58,7 +58,6 @@ JitCompilerEngine::buildLambda(llvm::StringRef s, llvm::StringRef funcName,
   std::unique_ptr<llvm::MemoryBuffer> mb = llvm::MemoryBuffer::getMemBuffer(s);
   llvm::Expected<JitCompilerEngine::Lambda> res =
       this->buildLambda(std::move(mb), funcName, cache, runtimeLibPath);
-
   return std::move(res);
 }