mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
fix(dfr): deallocate receive buffers for evaluation keys on remote nodes.
This commit is contained in:
@@ -34,13 +34,16 @@ static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
|
||||
template <typename LweKeyType> struct KeyWrapper {
|
||||
LweKeyType *key;
|
||||
Buffer buffer;
|
||||
|
||||
KeyWrapper() : key(nullptr) {}
|
||||
KeyWrapper(LweKeyType *key) : key(key) {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
|
||||
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept
|
||||
: key(moved.key), buffer(moved.buffer) {}
|
||||
KeyWrapper(LweKeyType *key);
|
||||
KeyWrapper(const KeyWrapper &kw) : key(kw.key), buffer(kw.buffer) {}
|
||||
KeyWrapper &operator=(const KeyWrapper &rhs) {
|
||||
this->key = rhs.key;
|
||||
this->buffer = rhs.buffer;
|
||||
return *this;
|
||||
}
|
||||
friend class hpx::serialization::access;
|
||||
@@ -50,6 +53,13 @@ template <typename LweKeyType> struct KeyWrapper {
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
template <>
|
||||
KeyWrapper<LweKeyswitchKey_u64>::KeyWrapper(LweKeyswitchKey_u64 *key)
|
||||
: key(key), buffer(serialize_lwe_keyswitching_key_u64(key)) {}
|
||||
template <>
|
||||
KeyWrapper<LweBootstrapKey_u64>::KeyWrapper(LweBootstrapKey_u64 *key)
|
||||
: key(key), buffer(serialize_lwe_bootstrap_key_u64(key)) {}
|
||||
|
||||
template <typename LweKeyType>
|
||||
bool operator==(const KeyWrapper<LweKeyType> &lhs,
|
||||
const KeyWrapper<LweKeyType> &rhs) {
|
||||
@@ -60,7 +70,6 @@ template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
@@ -68,19 +77,16 @@ template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
ar >> buffer.length;
|
||||
buffer.pointer = new uint8_t[buffer.length];
|
||||
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
key = deserialize_lwe_bootstrap_key_u64({buffer.pointer, buffer.length});
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey_u64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
@@ -88,12 +94,10 @@ template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
ar >> buffer.length;
|
||||
buffer.pointer = new uint8_t[buffer.length];
|
||||
ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
key = deserialize_lwe_keyswitching_key_u64({buffer.pointer, buffer.length});
|
||||
}
|
||||
|
||||
/************************/
|
||||
@@ -121,10 +125,10 @@ struct RuntimeContextManager {
|
||||
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
auto kskFut = hpx::collectives::broadcast_to(
|
||||
"ksk_keystore", KeyWrapper<LweKeyswitchKey_u64>(ksk));
|
||||
auto bskFut = hpx::collectives::broadcast_to(
|
||||
"bsk_keystore", KeyWrapper<LweBootstrapKey_u64>(bsk));
|
||||
KeyWrapper<LweKeyswitchKey_u64> kskw(ksk);
|
||||
KeyWrapper<LweBootstrapKey_u64> bskw(bsk);
|
||||
hpx::collectives::broadcast_to("ksk_keystore", kskw);
|
||||
hpx::collectives::broadcast_to("bsk_keystore", bskw);
|
||||
} else {
|
||||
auto kskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey_u64>>(
|
||||
@@ -133,13 +137,16 @@ struct RuntimeContextManager {
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey_u64>>(
|
||||
"bsk_keystore");
|
||||
|
||||
KeyWrapper<LweKeyswitchKey_u64> kskw = kskFut.get();
|
||||
KeyWrapper<LweBootstrapKey_u64> bskw = bskFut.get();
|
||||
context = new mlir::concretelang::RuntimeContext();
|
||||
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(kskFut.get().key)),
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(kskw.key)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(
|
||||
bskFut.get().key)));
|
||||
new ::concretelang::clientlib::LweBootstrapKey(bskw.key)));
|
||||
delete (kskw.buffer.pointer);
|
||||
delete (bskw.buffer.pointer);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user