fix(dfr): deallocate receive buffers for evaluation keys on remote nodes.

This commit is contained in:
Antoniu Pop
2022-07-29 08:52:47 +01:00
committed by Antoniu Pop
parent 2f5f9f6cf1
commit dd2b2b9ce9

View File

@@ -34,13 +34,16 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
template <typename LweKeyType> struct KeyWrapper {
LweKeyType *key;
Buffer buffer;
KeyWrapper() : key(nullptr) {}
KeyWrapper(LweKeyType *key) : key(key) {}
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
KeyWrapper(KeyWrapper &&moved) noexcept
: key(moved.key), buffer(moved.buffer) {}
KeyWrapper(LweKeyType *key);
KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {}
KeyWrapper &operator=(const KeyWrapper &rhs) {
this->key = rhs.key;
this->buffer = rhs.buffer;
return *this;
}
friend class hpx::serialization::access;
@@ -50,6 +53,13 @@ template <typename LweKeyType> struct KeyWrapper {
HPX_SERIALIZATION_SPLIT_MEMBER()
};
template <>
KeyWrapper<LweKeyswitchKey_u64>::KeyWrapper(LweKeyswitchKey_u64 *key)
: key(key), buffer(serialize_lwe_keyswitching_key_u64(key)) {}
template <>
KeyWrapper<LweBootstrapKey_u64>::KeyWrapper(LweBootstrapKey_u64 *key)
: key(key), buffer(serialize_lwe_bootstrap_key_u64(key)) {}
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
const KeyWrapper<LweKeyType> &rhs) {
@@ -60,7 +70,6 @@ template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
const unsigned int version) const {
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
@@ -68,19 +77,16 @@ template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
const unsigned int version) {
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_bootstrap_key_u64(buffer);
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
key = deserialize_lwe_bootstrap_key_u64({buffer.pointer, buffer.length});
}
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey_u64>::save(Archive &ar,
const unsigned int version) const {
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
@@ -88,12 +94,10 @@ template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
const unsigned int version) {
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_keyswitching_key_u64(buffer);
ar >> buffer.length;
buffer.pointer = new uint8_t[buffer.length];
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
key = deserialize_lwe_keyswitching_key_u64({buffer.pointer, buffer.length});
}
/************************/
@@ -121,10 +125,10 @@ struct RuntimeContextManager {
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
auto kskFut = hpx::collectives::broadcast_to(
"ksk_keystore", KeyWrapper<LweKeyswitchKey_u64>(ksk));
auto bskFut = hpx::collectives::broadcast_to(
"bsk_keystore", KeyWrapper<LweBootstrapKey_u64>(bsk));
KeyWrapper<LweKeyswitchKey_u64> kskw(ksk);
KeyWrapper<LweBootstrapKey_u64> bskw(bsk);
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
} else {
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey_u64>>(
@@ -133,13 +137,16 @@ struct RuntimeContextManager {
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey_u64>>(
"bsk_keystore");
KeyWrapper<LweKeyswitchKey_u64> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey_u64> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext();
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(kskFut.get().key)),
new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(
bskFut.get().key)));
new ::concretelang::clientlib::LweBootstrapKey(bskw.key)));
delete (kskw.buffer.pointer);
delete (bskw.buffer.pointer);
}
}