diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index d7b73a772..2d63a4186 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -34,13 +34,16 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager; template struct KeyWrapper { LweKeyType *key; + Buffer buffer; KeyWrapper() : key(nullptr) {} - KeyWrapper(LweKeyType *key) : key(key) {} - KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {} - KeyWrapper(const KeyWrapper &kw) : key(kw.key) {} + KeyWrapper(KeyWrapper &&moved) noexcept + : key(moved.key), buffer(moved.buffer) {} + KeyWrapper(LweKeyType *key); + KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {} KeyWrapper &operator=(const KeyWrapper &rhs) { this->key = rhs.key; + this->buffer = rhs.buffer; return *this; } friend class hpx::serialization::access; @@ -50,6 +53,13 @@ template struct KeyWrapper { HPX_SERIALIZATION_SPLIT_MEMBER() }; +template <> +KeyWrapper::KeyWrapper(LweKeyswitchKey_u64 *key) + : key(key), buffer(serialize_lwe_keyswitching_key_u64(key)) {} +template <> +KeyWrapper::KeyWrapper(LweBootstrapKey_u64 *key) + : key(key), buffer(serialize_lwe_bootstrap_key_u64(key)) {} + template bool operator==(const KeyWrapper &lhs, const KeyWrapper &rhs) { @@ -60,7 +70,6 @@ template <> template void KeyWrapper::save(Archive &ar, const unsigned int version) const { - Buffer buffer = serialize_lwe_bootstrap_key_u64(key); ar << buffer.length; ar << hpx::serialization::make_array(buffer.pointer, buffer.length); } @@ -68,19 +77,16 @@ template <> template void KeyWrapper::load(Archive &ar, const unsigned int version) { - size_t length; - ar >> length; - uint8_t *pointer = new uint8_t[length]; - ar >> hpx::serialization::make_array(pointer, length); - BufferView buffer = {(const uint8_t *)pointer, length}; - key = deserialize_lwe_bootstrap_key_u64(buffer); + ar >> buffer.length; + buffer.pointer = new uint8_t[buffer.length]; + ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); + key = deserialize_lwe_bootstrap_key_u64({buffer.pointer, buffer.length}); } template <> template void KeyWrapper::save(Archive &ar, const unsigned int version) const { - Buffer buffer = serialize_lwe_keyswitching_key_u64(key); ar << buffer.length; ar << hpx::serialization::make_array(buffer.pointer, buffer.length); } @@ -88,12 +94,10 @@ template <> template void KeyWrapper::load(Archive &ar, const unsigned int version) { - size_t length; - ar >> length; - uint8_t *pointer = new uint8_t[length]; - ar >> hpx::serialization::make_array(pointer, length); - BufferView buffer = {(const uint8_t *)pointer, length}; - key = deserialize_lwe_keyswitching_key_u64(buffer); + ar >> buffer.length; + buffer.pointer = new uint8_t[buffer.length]; + ar >> hpx::serialization::make_array(buffer.pointer, buffer.length); + key = deserialize_lwe_keyswitching_key_u64({buffer.pointer, buffer.length}); } /************************/ @@ -121,10 +125,10 @@ struct RuntimeContextManager { LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context); LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context); - auto kskFut = hpx::collectives::broadcast_to( - "ksk_keystore", KeyWrapper(ksk)); - auto bskFut = hpx::collectives::broadcast_to( - "bsk_keystore", KeyWrapper(bsk)); + KeyWrapper kskw(ksk); + KeyWrapper bskw(bsk); + hpx::collectives::broadcast_to("ksk_keystore", kskw); + hpx::collectives::broadcast_to("bsk_keystore", bskw); } else { auto kskFut = hpx::collectives::broadcast_from>( @@ -133,13 +137,16 @@ struct RuntimeContextManager { hpx::collectives::broadcast_from>( "bsk_keystore"); + KeyWrapper kskw = kskFut.get(); + KeyWrapper bskw = bskFut.get(); context = new mlir::concretelang::RuntimeContext(); context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys( std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>( - new ::concretelang::clientlib::LweKeyswitchKey(kskFut.get().key)), + new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)), std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>( - new ::concretelang::clientlib::LweBootstrapKey( - bskFut.get().key))); + new ::concretelang::clientlib::LweBootstrapKey(bskw.key))); + delete (kskw.buffer.pointer); + delete (bskw.buffer.pointer); } }