diff --git a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h index e5448a7d0..434d18079 100644 --- a/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h +++ b/compilers/concrete-compiler/compiler/include/concretelang/Common/Keys.h @@ -10,6 +10,7 @@ #include "concretelang/Common/Csprng.h" #include "concretelang/Common/Protocol.h" #include +#include #include #include @@ -87,7 +88,8 @@ public: LweBootstrapKey(std::shared_ptr> buffer, Message info) : seededBuffer(std::make_shared>()), buffer(buffer), - info(info){}; + info(info), decompress_mutext(std::make_shared()), + decompressed(false){}; /// @brief Initialize the key from the protocol message. static LweBootstrapKey @@ -107,7 +109,9 @@ public: private: LweBootstrapKey(Message info) : seededBuffer(std::make_shared>()), - buffer(std::make_shared>()), info(info){}; + buffer(std::make_shared>()), info(info), + decompress_mutext(std::make_shared()), + decompressed(false){}; LweBootstrapKey() = delete; /// @brief The buffer of the seeded key if needed. @@ -118,6 +122,12 @@ private: /// @brief The metadata of the bootrap key. Message info; + + /// @brief Mutex to guard the decompression + std::shared_ptr decompress_mutext; + + /// @brief A boolean that indicates if the decompression is done or not + bool decompressed; }; class LweKeyswitchKey { @@ -130,7 +140,8 @@ public: LweKeyswitchKey(std::shared_ptr> buffer, Message info) : seededBuffer(std::make_shared>()), buffer(buffer), - info(info){}; + info(info), decompress_mutext(std::make_shared()), + decompressed(false){}; /// @brief Initialize the key from the protocol message. static LweKeyswitchKey @@ -150,7 +161,9 @@ public: private: LweKeyswitchKey(Message info) : seededBuffer(std::make_shared>()), - buffer(std::make_shared>()), info(info){}; + buffer(std::make_shared>()), info(info), + decompress_mutext(std::make_shared()), + decompressed(false){}; /// @brief The buffer of the seeded key if needed. std::shared_ptr> seededBuffer; @@ -160,6 +173,12 @@ private: /// @brief The metadata of the bootrap key. Message info; + + /// @brief Mutex to guard the decompression + std::shared_ptr decompress_mutext; + + /// @brief A boolean that indicates if the decompression is done or not + bool decompressed; }; class PackingKeyswitchKey { diff --git a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp index add19359f..64f04cb2e 100644 --- a/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp +++ b/compilers/concrete-compiler/compiler/lib/Common/Keys.cpp @@ -193,8 +193,7 @@ Message LweBootstrapKey::toProto() const { } const std::vector &LweBootstrapKey::getBuffer() { - if (buffer->size() == 0) - decompress(); + decompress(); return *buffer; } @@ -220,6 +219,11 @@ void LweBootstrapKey::decompress() { case concreteprotocol::Compression::NONE: return; case concreteprotocol::Compression::SEED: { + if (decompressed) + return; + const std::lock_guard guard(*decompress_mutext); + if (decompressed) + return; auto params = info.asReader().getParams(); buffer->resize(concrete_cpu_bootstrap_key_size_u64( params.getLevelCount(), params.getGlweDimension(), @@ -230,6 +234,7 @@ void LweBootstrapKey::decompress() { buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(), params.getPolynomialSize(), params.getGlweDimension(), params.getLevelCount(), params.getBaseLog(), seed); + decompressed = true; return; } default: @@ -313,8 +318,7 @@ LweKeyswitchKey::getInfo() const { } const std::vector &LweKeyswitchKey::getBuffer() { - if (buffer->size() == 0) - decompress(); + decompress(); return *buffer; } @@ -335,6 +339,11 @@ void LweKeyswitchKey::decompress() { case concreteprotocol::Compression::NONE: return; case concreteprotocol::Compression::SEED: { + if (decompressed) + return; + const std::lock_guard guard(*decompress_mutext); + if (decompressed) + return; auto params = info.asReader().getParams(); buffer->resize(concrete_cpu_keyswitch_key_size_u64( params.getLevelCount(), params.getInputLweDimension(), @@ -345,6 +354,7 @@ void LweKeyswitchKey::decompress() { buffer->data(), seededBuffer->data() + 2, params.getInputLweDimension(), params.getOutputLweDimension(), params.getLevelCount(), params.getBaseLog(), seed); + decompressed = true; return; } default: