mirror of
https://github.com/zama-ai/concrete.git
synced 2026-04-17 03:00:54 -04:00
fix(dfr): broadcast evaluation keys early to avoid locking in HPX helper threads.
This commit is contained in:
@@ -69,21 +69,19 @@ struct OpaqueInputData {
|
||||
std::vector<size_t> _param_sizes,
|
||||
std::vector<uint64_t> _param_types,
|
||||
std::vector<size_t> _output_sizes,
|
||||
std::vector<uint64_t> _output_types, bool _alloc_p = false)
|
||||
std::vector<uint64_t> _output_types)
|
||||
: wfn_name(_wfn_name), params(std::move(_params)),
|
||||
param_sizes(std::move(_param_sizes)),
|
||||
param_types(std::move(_param_types)),
|
||||
output_sizes(std::move(_output_sizes)),
|
||||
output_types(std::move(_output_types)), alloc_p(_alloc_p),
|
||||
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
|
||||
output_types(std::move(_output_types)), ksk_id(0), bsk_id(0) {}
|
||||
|
||||
// Copy constructor: initializes every member from the matching member of `oid`.
// NOTE(review): `oid` is a const reference, so each std::move(oid.member)
// yields a const rvalue and the member is copied rather than moved.
// NOTE(review): this span appears to interleave two revisions of the
// initializer list (one with alloc_p/source_locality, one without), which is
// why output_types and ksk_id show up twice — confirm which revision applies.
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
param_sizes(std::move(oid.param_sizes)),
|
||||
param_types(std::move(oid.param_types)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
|
||||
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
|
||||
output_types(std::move(oid.output_types)), ksk_id(oid.ksk_id),
|
||||
bsk_id(oid.bsk_id) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
@@ -91,7 +89,6 @@ struct OpaqueInputData {
|
||||
ar >> wfn_name;
|
||||
ar >> param_sizes >> param_types;
|
||||
ar >> output_sizes >> output_types;
|
||||
ar >> source_locality;
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
char *param;
|
||||
_dfr_checked_aligned_alloc((void **)¶m, 64, param_sizes[p]);
|
||||
@@ -118,15 +115,13 @@ struct OpaqueInputData {
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
ar >> bsk_id >> ksk_id;
|
||||
|
||||
// The copied pointer is meaningless - TODO: if the context
|
||||
// can change dynamically (e.g., different evaluation keys)
|
||||
// then this needs updating by passing key ids and retrieving
|
||||
// adequate keys for the context.
|
||||
delete ((char *)params[p]);
|
||||
// TODO: this might be relaxed with newer versions of HPX.
|
||||
// Do not set the context here as remote operations are
|
||||
// unstable when initiated within a HPX helper thread.
|
||||
params[p] =
|
||||
(void *)
|
||||
_dfr_node_level_runtime_context_manager->getContextAddress();
|
||||
(void *)_dfr_node_level_runtime_context_manager->getContext();
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -134,14 +129,12 @@ struct OpaqueInputData {
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar << wfn_name;
|
||||
ar << param_sizes << param_types;
|
||||
ar << output_sizes << output_types;
|
||||
ar << source_locality;
|
||||
for (size_t p = 0; p < params.size(); ++p) {
|
||||
// Save the first level of the data structure - if the parameter
|
||||
// is a tensor/memref, there is a second level.
|
||||
@@ -161,18 +154,8 @@ struct OpaqueInputData {
|
||||
mref.data + mref.offset * elementSize, size * elementSize);
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
|
||||
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
|
||||
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
|
||||
<< (uint64_t)bsk << "\n"
|
||||
<< std::flush;
|
||||
_dfr_register_bsk(bsk, (uint64_t)bsk);
|
||||
_dfr_register_ksk(ksk, (uint64_t)ksk);
|
||||
ar << (uint64_t)bsk << (uint64_t)ksk;
|
||||
// Nothing to do now - TODO: pass key ids if these are not
|
||||
// unique for a computation.
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -189,8 +172,6 @@ struct OpaqueInputData {
|
||||
std::vector<uint64_t> param_types;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
hpx::naming::id_type source_locality;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
};
|
||||
@@ -199,13 +180,13 @@ struct OpaqueOutputData {
|
||||
OpaqueOutputData() = default;
|
||||
OpaqueOutputData(std::vector<void *> outputs,
|
||||
std::vector<size_t> output_sizes,
|
||||
std::vector<uint64_t> output_types, bool alloc_p = false)
|
||||
std::vector<uint64_t> output_types)
|
||||
: outputs(std::move(outputs)), output_sizes(std::move(output_sizes)),
|
||||
output_types(std::move(output_types)), alloc_p(alloc_p) {}
|
||||
output_types(std::move(output_types)) {}
|
||||
// Copy constructor: initializes outputs, output_sizes and output_types from
// `ood`.
// NOTE(review): `ood` is a const reference, so each std::move(ood.member)
// yields a const rvalue and the member is copied rather than moved; `outputs`
// is a vector of raw pointers, so only the pointer values are duplicated.
// NOTE(review): the two trailing output_types(...) lines look like the
// variants of the same initializer line with and without alloc_p — confirm
// which revision applies.
OpaqueOutputData(const OpaqueOutputData &ood)
|
||||
: outputs(std::move(ood.outputs)),
|
||||
output_sizes(std::move(ood.output_sizes)),
|
||||
output_types(std::move(ood.output_types)), alloc_p(ood.alloc_p) {}
|
||||
output_types(std::move(ood.output_types)) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
@@ -246,7 +227,6 @@ struct OpaqueOutputData {
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
@@ -283,7 +263,6 @@ struct OpaqueOutputData {
|
||||
std::vector<void *> outputs;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
};
|
||||
|
||||
struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
@@ -295,12 +274,6 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.wfn_name);
|
||||
std::vector<void *> outputs;
|
||||
|
||||
if (inputs.source_locality != hpx::find_here() &&
|
||||
(inputs.ksk_id || inputs.bsk_id)) {
|
||||
_dfr_node_level_runtime_context_manager->getContext(
|
||||
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
|
||||
}
|
||||
|
||||
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
|
||||
inputs.output_sizes.size());
|
||||
hpx::cout << std::flush;
|
||||
@@ -735,7 +708,7 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
}
|
||||
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
|
||||
std::move(inputs.output_types), inputs.alloc_p);
|
||||
std::move(inputs.output_types));
|
||||
}
|
||||
|
||||
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);
|
||||
|
||||
@@ -29,14 +29,9 @@ template <typename T> struct KeyManager;
|
||||
struct RuntimeContextManager;
|
||||
namespace {
|
||||
static void *dl_handle;
|
||||
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
|
||||
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
|
||||
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
} // namespace
|
||||
|
||||
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
|
||||
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id);
|
||||
|
||||
template <typename LweKeyType> struct KeyWrapper {
|
||||
LweKeyType *key;
|
||||
|
||||
@@ -44,6 +39,10 @@ template <typename LweKeyType> struct KeyWrapper {
|
||||
KeyWrapper(LweKeyType *key) : key(key) {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
|
||||
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
|
||||
KeyWrapper &operator=(const KeyWrapper &rhs) {
|
||||
this->key = rhs.key;
|
||||
return *this;
|
||||
}
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const;
|
||||
@@ -51,6 +50,12 @@ template <typename LweKeyType> struct KeyWrapper {
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
template <typename LweKeyType>
|
||||
bool operator==(const KeyWrapper<LweKeyType> &lhs,
|
||||
const KeyWrapper<LweKeyType> &rhs) {
|
||||
return lhs.key == rhs.key;
|
||||
}
|
||||
|
||||
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
|
||||
@@ -91,137 +96,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
|
||||
key = deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
}
|
||||
|
||||
template <typename LweKeyType> struct KeyManager {
|
||||
KeyManager() {}
|
||||
LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id);
|
||||
|
||||
KeyWrapper<LweKeyType> fetch_key(const uint64_t key_id) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
|
||||
auto keyit = keystore.find(key_id);
|
||||
if (keyit != keystore.end())
|
||||
return keyit->second;
|
||||
// If this node does not contain this key, this is an error
|
||||
// (location was supplied as source for this key).
|
||||
HPX_THROW_EXCEPTION(
|
||||
hpx::no_success, "fetch_key",
|
||||
"Error: could not find key to be fetched on source location.");
|
||||
}
|
||||
|
||||
void register_key(LweKeyType *key, uint64_t key_id) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
auto keyit = keystore.find(key_id);
|
||||
if (keyit == keystore.end()) {
|
||||
keyit = keystore
|
||||
.insert(std::pair<uint64_t, KeyWrapper<LweKeyType>>(
|
||||
key_id, KeyWrapper<LweKeyType>(key)))
|
||||
.first;
|
||||
if (keyit == keystore.end()) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
|
||||
"Error: could not register new key.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void clear_keys() {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
keystore.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex keystore_guard;
|
||||
std::map<uint64_t, KeyWrapper<LweKeyType>> keystore;
|
||||
};
|
||||
|
||||
/// Returns the bootstrap key wrapper stored under `key_id` in the node-level
/// bootstrap key manager.
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
  auto &bskManager = *_dfr_node_level_bsk_manager;
  return bskManager.fetch_key(key_id);
}
|
||||
|
||||
/// Returns the keyswitch key wrapper stored under `key_id` in the node-level
/// keyswitch key manager.
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
  auto &kskManager = *_dfr_node_level_ksk_manager;
  return kskManager.fetch_key(key_id);
}
|
||||
|
||||
// The three enclosing namespaces are closed here so that the two
// HPX_PLAIN_ACTION invocations below sit at global scope, and are reopened
// right after. Each invocation pairs a fully qualified fetch function with the
// action type name (_dfr_fetch_ksk_action / _dfr_fetch_bsk_action) that the
// get_key specializations instantiate to call that function on a locality.
} // namespace dfr
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_ksk, _dfr_fetch_ksk_action)
|
||||
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_bsk, _dfr_fetch_bsk_action)
|
||||
|
||||
namespace mlir {
|
||||
namespace concretelang {
|
||||
namespace dfr {
|
||||
|
||||
/// Constructing the bootstrap key manager publishes the new instance through
/// the _dfr_node_level_bsk_manager pointer.
template <> KeyManager<LweBootstrapKey_u64>::KeyManager() {
  _dfr_node_level_bsk_manager = this;
}
|
||||
|
||||
template <>
|
||||
LweBootstrapKey_u64 *
|
||||
KeyManager<LweBootstrapKey_u64>::get_key(hpx::naming::id_type loc,
|
||||
const uint64_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_bsk_action fetch;
|
||||
KeyWrapper<LweBootstrapKey_u64> &&bskw = fetch(loc, key_id);
|
||||
if (bskw.key == nullptr) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: Bootstrap key not found on root node.");
|
||||
} else {
|
||||
_dfr_register_bsk(bskw.key, key_id);
|
||||
}
|
||||
return bskw.key;
|
||||
}
|
||||
return keyit->second.key;
|
||||
}
|
||||
|
||||
/// Constructing the keyswitch key manager publishes the new instance through
/// the _dfr_node_level_ksk_manager pointer.
template <> KeyManager<LweKeyswitchKey_u64>::KeyManager() {
  _dfr_node_level_ksk_manager = this;
}
|
||||
|
||||
template <>
|
||||
LweKeyswitchKey_u64 *
|
||||
KeyManager<LweKeyswitchKey_u64>::get_key(hpx::naming::id_type loc,
|
||||
const uint64_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_ksk_action fetch;
|
||||
KeyWrapper<LweKeyswitchKey_u64> &&kskw = fetch(loc, key_id);
|
||||
if (kskw.key == nullptr) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: Keyswitching key not found on root node.");
|
||||
} else {
|
||||
_dfr_register_ksk(kskw.key, key_id);
|
||||
}
|
||||
return kskw.key;
|
||||
}
|
||||
return keyit->second.key;
|
||||
}
|
||||
|
||||
/************************/
|
||||
/* Key management API. */
|
||||
/************************/
|
||||
|
||||
/// Registers bootstrap key `key` under `key_id` with the node-level
/// bootstrap key manager.
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id) {
  auto &bskManager = *_dfr_node_level_bsk_manager;
  bskManager.register_key(key, key_id);
}
|
||||
/// Registers keyswitch key `key` under `key_id` with the node-level
/// keyswitch key manager.
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id) {
  auto &kskManager = *_dfr_node_level_ksk_manager;
  kskManager.register_key(key, key_id);
}
|
||||
|
||||
/// Forwards to get_key(loc, key_id) on the node-level bootstrap key manager
/// and returns its result.
LweBootstrapKey_u64 *_dfr_get_bsk(hpx::naming::id_type loc, uint64_t key_id) {
  auto &bskManager = *_dfr_node_level_bsk_manager;
  return bskManager.get_key(loc, key_id);
}
|
||||
/// Forwards to get_key(loc, key_id) on the node-level keyswitch key manager
/// and returns its result.
LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
  auto &kskManager = *_dfr_node_level_ksk_manager;
  return kskManager.get_key(loc, key_id);
}
|
||||
|
||||
/************************/
|
||||
/* Context management. */
|
||||
/************************/
|
||||
@@ -230,58 +104,52 @@ struct RuntimeContextManager {
|
||||
// TODO: this is only ok so long as we don't change keys. Once we
|
||||
// use multiple keys, should have a map.
|
||||
RuntimeContext *context;
|
||||
std::mutex context_guard;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
|
||||
RuntimeContextManager() {
|
||||
ksk_id = 0;
|
||||
bsk_id = 0;
|
||||
context = nullptr;
|
||||
_dfr_node_level_runtime_context_manager = this;
|
||||
}
|
||||
|
||||
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
|
||||
hpx::naming::id_type source_locality) {
|
||||
std::cout << "GetContext on node " << hpx::get_locality_id()
|
||||
<< " with context " << context << " " << bsk_id << " " << ksk_id
|
||||
<< "\n"
|
||||
<< std::flush;
|
||||
if (context != nullptr) {
|
||||
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
|
||||
<< bsk << "\n"
|
||||
<< std::flush;
|
||||
assert(ksk == ksk_id && bsk == bsk_id &&
|
||||
"Context manager can only used with single keys for now.");
|
||||
void setContext(void *ctx) {
|
||||
assert(context == nullptr &&
|
||||
"Only one RuntimeContext can be used at a time.");
|
||||
|
||||
// Root node broadcasts the evaluation keys and each remote
|
||||
// instantiates a local RuntimeContext.
|
||||
if (_dfr_is_root_node()) {
|
||||
RuntimeContext *context = (RuntimeContext *)ctx;
|
||||
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
auto kskFut = hpx::collectives::broadcast_to(
|
||||
"ksk_keystore", KeyWrapper<LweKeyswitchKey_u64>(ksk));
|
||||
auto bskFut = hpx::collectives::broadcast_to(
|
||||
"bsk_keystore", KeyWrapper<LweBootstrapKey_u64>(bsk));
|
||||
} else {
|
||||
assert(ksk_id == 0 && bsk_id == 0 &&
|
||||
"Context empty but context manager has key ids.");
|
||||
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
|
||||
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
|
||||
std::lock_guard<std::mutex> guard(context_guard);
|
||||
if (context == nullptr) {
|
||||
auto ctx = new RuntimeContext();
|
||||
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
|
||||
ksk_id = ksk;
|
||||
bsk_id = bsk;
|
||||
context = ctx;
|
||||
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
|
||||
<< std::flush;
|
||||
} else {
|
||||
std::cout << " GOT context after LOCK on node "
|
||||
<< hpx::get_locality_id() << " with context " << context
|
||||
<< " " << bsk_id << " " << ksk_id << "\n"
|
||||
<< std::flush;
|
||||
}
|
||||
auto kskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweKeyswitchKey_u64>>(
|
||||
"ksk_keystore");
|
||||
auto bskFut =
|
||||
hpx::collectives::broadcast_from<KeyWrapper<LweBootstrapKey_u64>>(
|
||||
"bsk_keystore");
|
||||
|
||||
context = new mlir::concretelang::RuntimeContext();
|
||||
context->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(kskFut.get().key)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(
|
||||
bskFut.get().key)));
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
RuntimeContext **getContextAddress() { return &context; }
|
||||
RuntimeContext **getContext() { return &context; }
|
||||
|
||||
// Destroys the current RuntimeContext, if any, and resets the pointer to
// null.
void clearContext() {
  // delete on a null pointer is a no-op, so no explicit guard is required.
  delete context;
  context = nullptr;
}
|
||||
};
|
||||
|
||||
} // namespace dfr
|
||||
|
||||
@@ -26,6 +26,7 @@ void _dfr_deallocate_future(void *);
|
||||
void _dfr_deallocate_future_data(void *);
|
||||
|
||||
/* Initialisation & termination. */
|
||||
void _dfr_start_c(void *);
|
||||
void _dfr_start();
|
||||
void _dfr_stop();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user