fix(compiler): Guard decompression of seeded keys to avoid access conflicts in parallel programs

This commit is contained in:
Bourgerie Quentin
2024-02-21 13:55:54 +01:00
committed by Quentin Bourgerie
parent 760e5ef02a
commit 7651cb1129
2 changed files with 37 additions and 8 deletions

View File

@@ -10,6 +10,7 @@
#include "concretelang/Common/Csprng.h"
#include "concretelang/Common/Protocol.h"
#include <memory>
#include <mutex>
#include <stdlib.h>
#include <vector>
@@ -87,7 +88,8 @@ public:
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::LweBootstrapKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info){};
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
/// @brief Initialize the key from the protocol message.
static LweBootstrapKey
@@ -107,7 +109,9 @@ public:
private:
LweBootstrapKey(Message<concreteprotocol::LweBootstrapKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info){};
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
LweBootstrapKey() = delete;
/// @brief The buffer of the seeded key if needed.
@@ -118,6 +122,12 @@ private:
/// @brief The metadata of the bootrap key.
Message<concreteprotocol::LweBootstrapKeyInfo> info;
/// @brief Mutex to guard the decompression
std::shared_ptr<std::mutex> decompress_mutext;
/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
};
class LweKeyswitchKey {
@@ -130,7 +140,8 @@ public:
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
info(info){};
info(info), decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
/// @brief Initialize the key from the protocol message.
static LweKeyswitchKey
@@ -150,7 +161,9 @@ public:
private:
LweKeyswitchKey(Message<concreteprotocol::LweKeyswitchKeyInfo> info)
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
buffer(std::make_shared<std::vector<uint64_t>>()), info(info){};
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
decompress_mutext(std::make_shared<std::mutex>()),
decompressed(false){};
/// @brief The buffer of the seeded key if needed.
std::shared_ptr<std::vector<uint64_t>> seededBuffer;
@@ -160,6 +173,12 @@ private:
/// @brief The metadata of the bootrap key.
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
/// @brief Mutex to guard the decompression
std::shared_ptr<std::mutex> decompress_mutext;
/// @brief A boolean that indicates if the decompression is done or not
bool decompressed;
};
class PackingKeyswitchKey {

View File

@@ -193,8 +193,7 @@ Message<concreteprotocol::LweBootstrapKey> LweBootstrapKey::toProto() const {
}
const std::vector<uint64_t> &LweBootstrapKey::getBuffer() {
if (buffer->size() == 0)
decompress();
decompress();
return *buffer;
}
@@ -220,6 +219,11 @@ void LweBootstrapKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_bootstrap_key_size_u64(
params.getLevelCount(), params.getGlweDimension(),
@@ -230,6 +234,7 @@ void LweBootstrapKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getPolynomialSize(), params.getGlweDimension(),
params.getLevelCount(), params.getBaseLog(), seed);
decompressed = true;
return;
}
default:
@@ -313,8 +318,7 @@ LweKeyswitchKey::getInfo() const {
}
const std::vector<uint64_t> &LweKeyswitchKey::getBuffer() {
if (buffer->size() == 0)
decompress();
decompress();
return *buffer;
}
@@ -335,6 +339,11 @@ void LweKeyswitchKey::decompress() {
case concreteprotocol::Compression::NONE:
return;
case concreteprotocol::Compression::SEED: {
if (decompressed)
return;
const std::lock_guard<std::mutex> guard(*decompress_mutext);
if (decompressed)
return;
auto params = info.asReader().getParams();
buffer->resize(concrete_cpu_keyswitch_key_size_u64(
params.getLevelCount(), params.getInputLweDimension(),
@@ -345,6 +354,7 @@ void LweKeyswitchKey::decompress() {
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
params.getOutputLweDimension(), params.getLevelCount(),
params.getBaseLog(), seed);
decompressed = true;
return;
}
default: