mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-17 08:01:20 -05:00
fix(compiler): Guard decompression of seeded keys to avoid access conflicts in parallel programs
This commit is contained in:
committed by
Quentin Bourgerie
parent
760e5ef02a
commit
7651cb1129
@@ -10,6 +10,7 @@
|
||||
#include "concretelang/Common/Csprng.h"
|
||||
#include "concretelang/Common/Protocol.h"
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <stdlib.h>
|
||||
#include <vector>
|
||||
|
||||
@@ -87,7 +88,8 @@ public:
|
||||
LweBootstrapKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::LweBootstrapKeyInfo> info)
|
||||
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
|
||||
info(info){};
|
||||
info(info), decompress_mutext(std::make_shared<std::mutex>()),
|
||||
decompressed(false){};
|
||||
|
||||
/// @brief Initialize the key from the protocol message.
|
||||
static LweBootstrapKey
|
||||
@@ -107,7 +109,9 @@ public:
|
||||
private:
|
||||
LweBootstrapKey(Message<concreteprotocol::LweBootstrapKeyInfo> info)
|
||||
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
|
||||
buffer(std::make_shared<std::vector<uint64_t>>()), info(info){};
|
||||
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
|
||||
decompress_mutext(std::make_shared<std::mutex>()),
|
||||
decompressed(false){};
|
||||
LweBootstrapKey() = delete;
|
||||
|
||||
/// @brief The buffer of the seeded key if needed.
|
||||
@@ -118,6 +122,12 @@ private:
|
||||
|
||||
/// @brief The metadata of the bootrap key.
|
||||
Message<concreteprotocol::LweBootstrapKeyInfo> info;
|
||||
|
||||
/// @brief Mutex to guard the decompression
|
||||
std::shared_ptr<std::mutex> decompress_mutext;
|
||||
|
||||
/// @brief A boolean that indicates if the decompression is done or not
|
||||
bool decompressed;
|
||||
};
|
||||
|
||||
class LweKeyswitchKey {
|
||||
@@ -130,7 +140,8 @@ public:
|
||||
LweKeyswitchKey(std::shared_ptr<std::vector<uint64_t>> buffer,
|
||||
Message<concreteprotocol::LweKeyswitchKeyInfo> info)
|
||||
: seededBuffer(std::make_shared<std::vector<uint64_t>>()), buffer(buffer),
|
||||
info(info){};
|
||||
info(info), decompress_mutext(std::make_shared<std::mutex>()),
|
||||
decompressed(false){};
|
||||
|
||||
/// @brief Initialize the key from the protocol message.
|
||||
static LweKeyswitchKey
|
||||
@@ -150,7 +161,9 @@ public:
|
||||
private:
|
||||
LweKeyswitchKey(Message<concreteprotocol::LweKeyswitchKeyInfo> info)
|
||||
: seededBuffer(std::make_shared<std::vector<uint64_t>>()),
|
||||
buffer(std::make_shared<std::vector<uint64_t>>()), info(info){};
|
||||
buffer(std::make_shared<std::vector<uint64_t>>()), info(info),
|
||||
decompress_mutext(std::make_shared<std::mutex>()),
|
||||
decompressed(false){};
|
||||
|
||||
/// @brief The buffer of the seeded key if needed.
|
||||
std::shared_ptr<std::vector<uint64_t>> seededBuffer;
|
||||
@@ -160,6 +173,12 @@ private:
|
||||
|
||||
/// @brief The metadata of the bootrap key.
|
||||
Message<concreteprotocol::LweKeyswitchKeyInfo> info;
|
||||
|
||||
/// @brief Mutex to guard the decompression
|
||||
std::shared_ptr<std::mutex> decompress_mutext;
|
||||
|
||||
/// @brief A boolean that indicates if the decompression is done or not
|
||||
bool decompressed;
|
||||
};
|
||||
|
||||
class PackingKeyswitchKey {
|
||||
|
||||
@@ -193,8 +193,7 @@ Message<concreteprotocol::LweBootstrapKey> LweBootstrapKey::toProto() const {
|
||||
}
|
||||
|
||||
const std::vector<uint64_t> &LweBootstrapKey::getBuffer() {
|
||||
if (buffer->size() == 0)
|
||||
decompress();
|
||||
decompress();
|
||||
return *buffer;
|
||||
}
|
||||
|
||||
@@ -220,6 +219,11 @@ void LweBootstrapKey::decompress() {
|
||||
case concreteprotocol::Compression::NONE:
|
||||
return;
|
||||
case concreteprotocol::Compression::SEED: {
|
||||
if (decompressed)
|
||||
return;
|
||||
const std::lock_guard<std::mutex> guard(*decompress_mutext);
|
||||
if (decompressed)
|
||||
return;
|
||||
auto params = info.asReader().getParams();
|
||||
buffer->resize(concrete_cpu_bootstrap_key_size_u64(
|
||||
params.getLevelCount(), params.getGlweDimension(),
|
||||
@@ -230,6 +234,7 @@ void LweBootstrapKey::decompress() {
|
||||
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
|
||||
params.getPolynomialSize(), params.getGlweDimension(),
|
||||
params.getLevelCount(), params.getBaseLog(), seed);
|
||||
decompressed = true;
|
||||
return;
|
||||
}
|
||||
default:
|
||||
@@ -313,8 +318,7 @@ LweKeyswitchKey::getInfo() const {
|
||||
}
|
||||
|
||||
const std::vector<uint64_t> &LweKeyswitchKey::getBuffer() {
|
||||
if (buffer->size() == 0)
|
||||
decompress();
|
||||
decompress();
|
||||
return *buffer;
|
||||
}
|
||||
|
||||
@@ -335,6 +339,11 @@ void LweKeyswitchKey::decompress() {
|
||||
case concreteprotocol::Compression::NONE:
|
||||
return;
|
||||
case concreteprotocol::Compression::SEED: {
|
||||
if (decompressed)
|
||||
return;
|
||||
const std::lock_guard<std::mutex> guard(*decompress_mutext);
|
||||
if (decompressed)
|
||||
return;
|
||||
auto params = info.asReader().getParams();
|
||||
buffer->resize(concrete_cpu_keyswitch_key_size_u64(
|
||||
params.getLevelCount(), params.getInputLweDimension(),
|
||||
@@ -345,6 +354,7 @@ void LweKeyswitchKey::decompress() {
|
||||
buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(),
|
||||
params.getOutputLweDimension(), params.getLevelCount(),
|
||||
params.getBaseLog(), seed);
|
||||
decompressed = true;
|
||||
return;
|
||||
}
|
||||
default:
|
||||
|
||||
Reference in New Issue
Block a user