fix(compiler): fix key serialization for distributed computing in DFR.

Key serialization for transfers between nodes in clusters was broken
since the changes introduced to separate keys from key parameters and
introduction of support for multi-key (ref
cacffadbd2).

This commit restores functionality for distributing keys to non-shared
memory nodes.
This commit is contained in:
Antoniu Pop
2023-03-10 10:26:00 +00:00
committed by Antoniu Pop
parent 3456978c24
commit 9363c40753
3 changed files with 64 additions and 90 deletions

View File

@@ -8,6 +8,7 @@
#include <memory>
#include <mutex>
#include <stdlib.h>
#include <utility>
#include <hpx/include/runtime.hpp>
@@ -15,22 +16,26 @@
#include <hpx/modules/serialization.hpp>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/ClientLib/PublicArguments.h"
#include "concretelang/Common/Error.h"
namespace mlir {
namespace concretelang {
namespace dfr {
using namespace ::concretelang::clientlib;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
template <typename LweKeyType> struct KeyWrapper {
template <typename LweKeyType, typename KeyParamType> struct KeyWrapper {
std::vector<LweKeyType> keys;
KeyWrapper() {}
@@ -42,17 +47,38 @@ template <typename LweKeyType> struct KeyWrapper {
}
KeyWrapper(std::vector<LweKeyType> keyvec) : keys(keyvec) {}
friend class hpx::serialization::access;
// template <class Archive>
// void save(Archive &ar, const unsigned int version) const;
template <class Archive>
void serialize(Archive &ar, const unsigned int version) const {}
// template <class Archive> void load(Archive &ar, const unsigned int
// version); HPX_SERIALIZATION_SPLIT_MEMBER()
void save(Archive &ar, const unsigned int version) const {
ar << (size_t)keys.size();
for (auto k : keys) {
auto params = k.parameters();
size_t param_size = sizeof(KeyParamType);
ar << hpx::serialization::make_array((char *)&params, param_size);
ar << (size_t)k.size();
ar << hpx::serialization::make_array(k.buffer(), k.size());
}
}
template <class Archive> void load(Archive &ar, const unsigned int version) {
size_t num_keys;
ar >> num_keys;
for (uint i = 0; i < num_keys; ++i) {
KeyParamType params;
size_t param_size = sizeof(params);
ar >> hpx::serialization::make_array((char *)&params, param_size);
size_t key_size;
ar >> key_size;
auto buffer = std::make_shared<std::vector<uint64_t>>();
buffer->resize(key_size);
ar >> hpx::serialization::make_array(buffer->data(), key_size);
keys.push_back(LweKeyType(buffer, params));
}
}
HPX_SERIALIZATION_SPLIT_MEMBER()
};
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
const KeyWrapper<LweKeyType> &rhs) {
template <typename LweKeyType, typename KeyParamType>
bool operator==(const KeyWrapper<LweKeyType, KeyParamType> &lhs,
const KeyWrapper<LweKeyType, KeyParamType> &rhs) {
if (lhs.keys.size() != rhs.keys.size())
return false;
for (size_t i = 0; i < lhs.keys.size(); ++i)
@@ -61,54 +87,6 @@ bool operator==(const KeyWrapper<LweKeyType> &lhs,
return true;
}
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweBootstrapKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_bootstrap_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::save(Archive &ar,
// const unsigned int version) const {
// ar << buffer.length;
// ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
// }
// template <>
// template <class Archive>
// void KeyWrapper<LweKeyswitchKey>::load(Archive &ar,
// const unsigned int version) {
// DefaultSerializationEngine *engine;
// // No Freeing as it doesn't allocate anything.
// CAPI_ASSERT_ERROR(new_default_serialization_engine(&engine));
// ar >> buffer.length;
// buffer.pointer = new uint8_t[buffer.length];
// ar >> hpx::serialization::make_array(buffer.pointer, buffer.length);
// CAPI_ASSERT_ERROR(
// default_serialization_engine_deserialize_lwe_keyswitch_key_u64(
// engine, {buffer.pointer, buffer.length}, &key));
// }
/************************/
/* Context management. */
/************************/
@@ -132,35 +110,21 @@ struct RuntimeContextManager {
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw(
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw(
context->getKeys().getKeyswitchKeys());
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw(
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw(
context->getKeys().getBootstrapKeys());
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw(
context->getKeys().getPackingKeyswitchKeys());
hpx::collectives::broadcast_to("ksk_keystore", kskw);
hpx::collectives::broadcast_to("bsk_keystore", bskw);
hpx::collectives::broadcast_to("pksk_keystore", pkskw);
} else {
auto kskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey>>(
"ksk_keystore");
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam>>("ksk_keystore");
auto bskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::LweBootstrapKey>>(
"bsk_keystore");
auto pkskFut = hpx::collectives::broadcast_from<
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey>>(
"pksk_keystore");
KeyWrapper<::concretelang::clientlib::LweKeyswitchKey> kskw =
kskFut.get();
KeyWrapper<::concretelang::clientlib::LweBootstrapKey> bskw =
bskFut.get();
KeyWrapper<::concretelang::clientlib::PackingKeyswitchKey> pkskw =
pkskFut.get();
KeyWrapper<LweBootstrapKey, BootstrapKeyParam>>("bsk_keystore");
KeyWrapper<LweKeyswitchKey, KeyswitchKeyParam> kskw = kskFut.get();
KeyWrapper<LweBootstrapKey, BootstrapKeyParam> bskw = bskFut.get();
context = new mlir::concretelang::RuntimeContext(
::concretelang::clientlib::EvaluationKeys(kskw.keys, bskw.keys,
pkskw.keys));
EvaluationKeys(kskw.keys, bskw.keys, {}));
}
}

View File

@@ -66,6 +66,14 @@ struct WorkFunctionRegistry {
return ret;
}
void clearRegistry() {
std::lock_guard<std::mutex> guard(registry_guard);
ptr_to_name_registry.clear();
name_to_ptr_registry.clear();
fnid = 0;
}
private:
void registerWorkFunction(const void *fn, std::string name) {
@@ -81,7 +89,6 @@ private:
}
std::string registerAnonymousWorkFunction(const void *fn) {
static std::atomic<unsigned int> fnid{0};
std::string name = "_dfr_jit_wfnname_" + std::to_string(fnid++);
registerWorkFunction(fn, name);
return name;
@@ -89,6 +96,7 @@ private:
private:
std::mutex registry_guard;
std::atomic<unsigned int> fnid{0};
std::map<const void *, std::string> ptr_to_name_registry;
std::map<std::string, const void *> name_to_ptr_registry;
};

View File

@@ -317,6 +317,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
num_nodes = hpx::get_num_localities().get();
new WorkFunctionRegistry();
new RuntimeContextManager();
_dfr_jit_phase_barrier = new hpx::lcos::barrier("phase_barrier", num_nodes,
hpx::get_locality_id());
_dfr_startup_barrier = new hpx::lcos::barrier("startup_barrier", num_nodes,
@@ -357,20 +358,21 @@ void _dfr_start(int64_t use_dfr_p, void *ctx) {
// cancelled function which registers the work functions.
if (!_dfr_is_root_node() && !_dfr_is_jit())
_dfr_stop_impl();
}
// If DFR is used and a runtime context is needed, and execution is
// distributed, then broadcast from root to all compute nodes.
if (use_dfr_p && (num_nodes > 1) && (ctx || !_dfr_is_root_node())) {
BEGIN_TIME(&broadcast_timer);
new RuntimeContextManager();
_dfr_node_level_runtime_context_manager->setContext(ctx);
// If DFR is used and a runtime context is needed, and execution is
// distributed, then broadcast from root to all compute nodes.
if (num_nodes > 1 && (ctx || !_dfr_is_root_node())) {
BEGIN_TIME(&broadcast_timer);
_dfr_node_level_runtime_context_manager->setContext(ctx);
}
// If this is not JIT, then the remote nodes never reach _dfr_stop,
// so root should not instantiate this barrier.
if (_dfr_is_root_node() && _dfr_is_jit())
_dfr_startup_barrier->wait();
END_TIME(&broadcast_timer, "Key broadcasting");
if (num_nodes > 1 && ctx) {
END_TIME(&broadcast_timer, "Key broadcasting");
}
}
BEGIN_TIME(&compute_timer);
}
@@ -396,11 +398,11 @@ void _dfr_stop(int64_t use_dfr_p) {
// gain as the root node would be waiting for the end of computation
// on all remote nodes before reaching here anyway (dataflow
// dependences).
if (_dfr_is_jit()) {
if (_dfr_is_jit())
_dfr_jit_phase_barrier->wait();
}
_dfr_node_level_runtime_context_manager->clearContext();
_dfr_node_level_work_function_registry->clearRegistry();
}
}
END_TIME(&compute_timer, "Compute");