fix(dfr): broadcast evaluation keys early to avoid locking in HPX helper threads.

This commit is contained in:
Antoniu Pop
2022-07-22 15:49:54 +01:00
committed by Antoniu Pop
parent fbca52f4a0
commit 1bb3d04059
5 changed files with 377 additions and 551 deletions

View File

@@ -69,21 +69,19 @@ struct OpaqueInputData {
std::vector<size_t> _param_sizes,
std::vector<uint64_t> _param_types,
std::vector<size_t> _output_sizes,
std::vector<uint64_t> _output_types, bool _alloc_p = false)
std::vector<uint64_t> _output_types)
: wfn_name(_wfn_name), params(std::move(_params)),
param_sizes(std::move(_param_sizes)),
param_types(std::move(_param_types)),
output_sizes(std::move(_output_sizes)),
output_types(std::move(_output_types)), alloc_p(_alloc_p),
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
output_types(std::move(_output_types)), ksk_id(0), bsk_id(0) {}
OpaqueInputData(const OpaqueInputData &oid)
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
param_sizes(std::move(oid.param_sizes)),
param_types(std::move(oid.param_types)),
output_sizes(std::move(oid.output_sizes)),
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
output_types(std::move(oid.output_types)), ksk_id(oid.ksk_id),
bsk_id(oid.bsk_id) {}
friend class hpx::serialization::access;
@@ -91,7 +89,6 @@ struct OpaqueInputData {
ar >> wfn_name;
ar >> param_sizes >> param_types;
ar >> output_sizes >> output_types;
ar >> source_locality;
for (size_t p = 0; p < param_sizes.size(); ++p) {
char *param;
_dfr_checked_aligned_alloc((void **)&param, 64, param_sizes[p]);
@@ -118,15 +115,13 @@ struct OpaqueInputData {
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
ar >> bsk_id >> ksk_id;
// The copied pointer is meaningless - TODO: if the context
// can change dynamically (e.g., different evaluation keys)
// then this needs updating by passing key ids and retrieving
// adequate keys for the context.
delete ((char *)params[p]);
// TODO: this might be relaxed with newer versions of HPX.
// Do not set the context here as remote operations are
// unstable when initiated within a HPX helper thread.
params[p] =
(void *)
_dfr_node_level_runtime_context_manager->getContextAddress();
(void *)_dfr_node_level_runtime_context_manager->getContext();
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -134,14 +129,12 @@ struct OpaqueInputData {
"Error: invalid task argument type.");
}
}
alloc_p = true;
}
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
ar << wfn_name;
ar << param_sizes << param_types;
ar << output_sizes << output_types;
ar << source_locality;
for (size_t p = 0; p < params.size(); ++p) {
// Save the first level of the data structure - if the parameter
// is a tensor/memref, there is a second level.
@@ -161,18 +154,8 @@ struct OpaqueInputData {
mref.data + mref.offset * elementSize, size * elementSize);
} break;
case _DFR_TASK_ARG_CONTEXT: {
mlir::concretelang::RuntimeContext *context =
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
<< (uint64_t)bsk << "\n"
<< std::flush;
_dfr_register_bsk(bsk, (uint64_t)bsk);
_dfr_register_ksk(ksk, (uint64_t)ksk);
ar << (uint64_t)bsk << (uint64_t)ksk;
// Nothing to do now - TODO: pass key ids if these are not
// unique for a computation.
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -189,8 +172,6 @@ struct OpaqueInputData {
std::vector<uint64_t> param_types;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
bool alloc_p = false;
hpx::naming::id_type source_locality;
uint64_t ksk_id;
uint64_t bsk_id;
};
@@ -199,13 +180,13 @@ struct OpaqueOutputData {
OpaqueOutputData() = default;
OpaqueOutputData(std::vector<void *> outputs,
std::vector<size_t> output_sizes,
std::vector<uint64_t> output_types, bool alloc_p = false)
std::vector<uint64_t> output_types)
: outputs(std::move(outputs)), output_sizes(std::move(output_sizes)),
output_types(std::move(output_types)), alloc_p(alloc_p) {}
output_types(std::move(output_types)) {}
OpaqueOutputData(const OpaqueOutputData &ood)
: outputs(std::move(ood.outputs)),
output_sizes(std::move(ood.output_sizes)),
output_types(std::move(ood.output_types)), alloc_p(ood.alloc_p) {}
output_types(std::move(ood.output_types)) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
@@ -246,7 +227,6 @@ struct OpaqueOutputData {
"Error: invalid task argument type.");
}
}
alloc_p = true;
}
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
@@ -283,7 +263,6 @@ struct OpaqueOutputData {
std::vector<void *> outputs;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
bool alloc_p = false;
};
struct GenericComputeServer : component_base<GenericComputeServer> {
@@ -295,12 +274,6 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.wfn_name);
std::vector<void *> outputs;
if (inputs.source_locality != hpx::find_here() &&
(inputs.ksk_id || inputs.bsk_id)) {
_dfr_node_level_runtime_context_manager->getContext(
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
}
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
inputs.output_sizes.size());
hpx::cout << std::flush;
@@ -735,7 +708,7 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
}
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
std::move(inputs.output_types), inputs.alloc_p);
std::move(inputs.output_types));
}
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);

View File

@@ -29,14 +29,9 @@ template <typename T> struct KeyManager;
struct RuntimeContextManager;
// Internal-linkage (anonymous namespace) globals holding node-level state.
namespace {
// NOTE(review): presumably a dynamic-library handle; its use is not visible
// in this part of the file - confirm.
static void *dl_handle;
// Set to `this` by the KeyManager<LweBootstrapKey_u64> constructor.
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
// Set to `this` by the KeyManager<LweKeyswitchKey_u64> constructor.
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
// Set to `this` by the RuntimeContextManager constructor.
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id);
template <typename LweKeyType> struct KeyWrapper {
LweKeyType *key;
@@ -44,6 +39,10 @@ template <typename LweKeyType> struct KeyWrapper {
KeyWrapper(LweKeyType *key) : key(key) {}
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
// Copy-assignment: takes over the raw key pointer held by `rhs`. Only the
// pointer is copied; the key object it refers to is not duplicated.
KeyWrapper &operator=(const KeyWrapper &rhs) {
  key = rhs.key;
  return *this;
}
friend class hpx::serialization::access;
template <class Archive>
void save(Archive &ar, const unsigned int version) const;
@@ -51,6 +50,12 @@ template <typename LweKeyType> struct KeyWrapper {
HPX_SERIALIZATION_SPLIT_MEMBER()
};
// Two wrappers compare equal when they hold the same key pointer (identity
// comparison; the contents of the keys are not inspected).
template <typename LweKeyType>
bool operator==(const KeyWrapper<LweKeyType> &lhs,
                const KeyWrapper<LweKeyType> &rhs) {
  LweKeyType *const left = lhs.key;
  LweKeyType *const right = rhs.key;
  return left == right;
}
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
@@ -91,137 +96,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
key = deserialize_lwe_keyswitching_key_u64(buffer);
}
// Mutex-protected store that maps a 64-bit key id to a KeyWrapper around a
// key of type `LweKeyType`.
template <typename LweKeyType> struct KeyManager {
  KeyManager() {}

  // Returns the key registered under `key_id`. Only declared here; the
  // definitions are the explicit specialisations for the bootstrap and
  // keyswitch key types, which fetch the key from locality `loc` when it is
  // not present in the local store.
  LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id);

  // Returns a copy of the wrapper stored under `key_id`.
  // Raises an HPX exception (hpx::no_success) when the id is not registered:
  // the caller named this node as the source of the key, so a miss is an
  // error rather than a reason to look elsewhere.
  KeyWrapper<LweKeyType> fetch_key(const uint64_t key_id) {
    std::lock_guard<std::mutex> guard(keystore_guard);
    auto keyit = keystore.find(key_id);
    if (keyit != keystore.end())
      return keyit->second;
    // If this node does not contain this key, this is an error
    // (location was supplied as source for this key).
    HPX_THROW_EXCEPTION(
        hpx::no_success, "fetch_key",
        "Error: could not find key to be fetched on source location.");
  }

  // Stores `key` under `key_id` unless an entry with that id already exists,
  // in which case the existing entry is kept and `key` is ignored.
  void register_key(LweKeyType *key, uint64_t key_id) {
    std::lock_guard<std::mutex> guard(keystore_guard);
    // emplace() leaves an existing entry untouched and otherwise inserts in a
    // single lookup. It either succeeds or throws (allocation failure), so no
    // check against end() is needed afterwards.
    keystore.emplace(key_id, KeyWrapper<LweKeyType>(key));
  }

  // Removes every entry from the store. Only the wrappers are dropped; the
  // keys they point to are not freed here.
  void clear_keys() {
    std::lock_guard<std::mutex> guard(keystore_guard);
    keystore.clear();
  }

private:
  // Guards every access to `keystore`.
  std::mutex keystore_guard;
  std::map<uint64_t, KeyWrapper<LweKeyType>> keystore;
};
// Looks up the bootstrap key registered under `key_id` in this node's
// bootstrap-key manager and returns its wrapper.
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
  KeyManager<LweBootstrapKey_u64> *manager = _dfr_node_level_bsk_manager;
  return manager->fetch_key(key_id);
}
// Looks up the keyswitch key registered under `key_id` in this node's
// keyswitch-key manager and returns its wrapper.
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
  KeyManager<LweKeyswitchKey_u64> *manager = _dfr_node_level_ksk_manager;
  return manager->fetch_key(key_id);
}
} // namespace dfr
} // namespace concretelang
} // namespace mlir
// Expose the two fetch functions as HPX plain actions. The get_key
// specialisations invoke them as `_dfr_fetch_*_action` with a target locality
// and a key id to obtain a key that is not in the local store.
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_ksk, _dfr_fetch_ksk_action)
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_bsk, _dfr_fetch_bsk_action)
namespace mlir {
namespace concretelang {
namespace dfr {
// Specialised constructor: records this instance in the node-level
// bootstrap-key manager pointer, through which the _dfr_*_bsk helpers reach
// the store.
template <> KeyManager<LweBootstrapKey_u64>::KeyManager() {
_dfr_node_level_bsk_manager = this;
}
// Returns the bootstrap key registered under `key_id`.
//
// If the key is already in the local store it is returned directly.
// Otherwise it is fetched from locality `loc` through the
// `_dfr_fetch_bsk_action` HPX action, registered locally under the same id
// and returned. Raises an HPX exception (hpx::no_success) when the fetched
// wrapper holds no key.
template <>
LweBootstrapKey_u64 *
KeyManager<LweBootstrapKey_u64>::get_key(hpx::naming::id_type loc,
                                         const uint64_t key_id) {
  {
    // Complete the lookup while the lock is held: the iterator must not be
    // compared or dereferenced after the lock is released, since
    // register_key/clear_keys may modify the map concurrently.
    std::lock_guard<std::mutex> guard(keystore_guard);
    auto keyit = keystore.find(key_id);
    if (keyit != keystore.end())
      return keyit->second.key;
  }
  // The lock is released before the fetch: _dfr_register_bsk re-acquires it
  // through register_key and std::mutex is not recursive.
  _dfr_fetch_bsk_action fetch;
  KeyWrapper<LweBootstrapKey_u64> &&bskw = fetch(loc, key_id);
  if (bskw.key == nullptr) {
    HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
                        "Error: Bootstrap key not found on root node.");
  }
  _dfr_register_bsk(bskw.key, key_id);
  return bskw.key;
}
// Specialised constructor: records this instance in the node-level
// keyswitch-key manager pointer, through which the _dfr_*_ksk helpers reach
// the store.
template <> KeyManager<LweKeyswitchKey_u64>::KeyManager() {
_dfr_node_level_ksk_manager = this;
}
// Returns the keyswitch key registered under `key_id`.
//
// If the key is already in the local store it is returned directly.
// Otherwise it is fetched from locality `loc` through the
// `_dfr_fetch_ksk_action` HPX action, registered locally under the same id
// and returned. Raises an HPX exception (hpx::no_success) when the fetched
// wrapper holds no key.
template <>
LweKeyswitchKey_u64 *
KeyManager<LweKeyswitchKey_u64>::get_key(hpx::naming::id_type loc,
                                         const uint64_t key_id) {
  {
    // Complete the lookup while the lock is held: the iterator must not be
    // compared or dereferenced after the lock is released, since
    // register_key/clear_keys may modify the map concurrently.
    std::lock_guard<std::mutex> guard(keystore_guard);
    auto keyit = keystore.find(key_id);
    if (keyit != keystore.end())
      return keyit->second.key;
  }
  // The lock is released before the fetch: _dfr_register_ksk re-acquires it
  // through register_key and std::mutex is not recursive.
  _dfr_fetch_ksk_action fetch;
  KeyWrapper<LweKeyswitchKey_u64> &&kskw = fetch(loc, key_id);
  if (kskw.key == nullptr) {
    HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
                        "Error: Keyswitching key not found on root node.");
  }
  _dfr_register_ksk(kskw.key, key_id);
  return kskw.key;
}
/************************/
/* Key management API. */
/************************/
// Registers bootstrap key `key` under `key_id` in the node-level
// bootstrap-key manager.
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id) {
  KeyManager<LweBootstrapKey_u64> *manager = _dfr_node_level_bsk_manager;
  manager->register_key(key, key_id);
}
// Registers keyswitch key `key` under `key_id` in the node-level
// keyswitch-key manager.
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id) {
  KeyManager<LweKeyswitchKey_u64> *manager = _dfr_node_level_ksk_manager;
  manager->register_key(key, key_id);
}
// Returns the bootstrap key with id `key_id` from the node-level
// bootstrap-key manager, passing `loc` as the locality to fetch it from when
// it is not held locally.
LweBootstrapKey_u64 *_dfr_get_bsk(hpx::naming::id_type loc, uint64_t key_id) {
  KeyManager<LweBootstrapKey_u64> *manager = _dfr_node_level_bsk_manager;
  return manager->get_key(loc, key_id);
}
// Returns the keyswitch key with id `key_id` from the node-level
// keyswitch-key manager, passing `loc` as the locality to fetch it from when
// it is not held locally.
LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
  KeyManager<LweKeyswitchKey_u64> *manager = _dfr_node_level_ksk_manager;
  return manager->get_key(loc, key_id);
}
/************************/
/* Context management. */
/************************/
@@ -230,58 +104,52 @@ struct RuntimeContextManager {
// TODO: this is only ok so long as we don't change keys. Once we
// use multiple keys, should have a map.
RuntimeContext *context;
std::mutex context_guard;
uint64_t ksk_id;
uint64_t bsk_id;
// Starts with no key ids and no context, and records this instance in the
// node-level runtime context manager pointer.
RuntimeContextManager() {
ksk_id = 0;
bsk_id = 0;
context = nullptr;
_dfr_node_level_runtime_context_manager = this;
}
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
hpx::naming::id_type source_locality) {
std::cout << "GetContext on node " << hpx::get_locality_id()
<< " with context " << context << " " << bsk_id << " " << ksk_id
<< "\n"
<< std::flush;
if (context != nullptr) {
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
<< bsk << "\n"
<< std::flush;
assert(ksk == ksk_id && bsk == bsk_id &&
"Context manager can only used with single keys for now.");
void setContext(void *ctx) {
assert(context == nullptr &&
"Only one RuntimeContext can be used at a time.");
// Root node broadcasts the evaluation keys and each remote
// instantiates a local RuntimeContext.
if (_dfr_is_root_node()) {
RuntimeContext *context = (RuntimeContext *)ctx;
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
auto kskFut = hpx::collectives::broadcast_to(
"ksk_keystore", KeyWrapper<LweKeyswitchKey_u64>(ksk));
auto bskFut = hpx::collectives::broadcast_to(
"bsk_keystore", KeyWrapper<LweBootstrapKey_u64>(bsk));
} else {
assert(ksk_id == 0 && bsk_id == 0 &&
"Context empty but context manager has key ids.");
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
std::lock_guard<std::mutex> guard(context_guard);
if (context == nullptr) {
auto ctx = new RuntimeContext();
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
ksk_id = ksk;
bsk_id = bsk;
context = ctx;
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
<< std::flush;
} else {
std::cout << " GOT context after LOCK on node "
<< hpx::get_locality_id() << " with context " << context
<< " " << bsk_id << " " << ksk_id << "\n"
<< std::flush;
}
auto kskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey_u64>>(
"ksk_keystore");
auto bskFut =
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey_u64>>(
"bsk_keystore");
context = new mlir::concretelang::RuntimeContext();
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(kskFut.get().key)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(
bskFut.get().key)));
}
return context;
}
RuntimeContext **getContextAddress() { return &context; }
RuntimeContext **getContext() { return &context; }
// Destroys the held RuntimeContext, if any, and resets the manager to the
// "no context" state.
void clearContext() {
  // Deleting a null pointer is a no-op, so no explicit null test is required.
  delete context;
  context = nullptr;
}
};
} // namespace dfr

View File

@@ -26,6 +26,7 @@ void _dfr_deallocate_future(void *);
void _dfr_deallocate_future_data(void *);
/* Initialisation & termination. */
void _dfr_start_c(void *);
void _dfr_start();
void _dfr_stop();