mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 12:15:09 -05:00
formatting(dfr): add .hpp to the formatting script and format the relevant files.
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
#ifndef ZAMALANG_DFR_DFRUNTIME_HPP
|
||||
#define ZAMALANG_DFR_DFRUNTIME_HPP
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <dlfcn.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "zamalang/Runtime/runtime_api.h"
|
||||
|
||||
@@ -15,35 +15,27 @@ struct WorkFunctionRegistry;
|
||||
extern WorkFunctionRegistry *node_level_work_function_registry;
|
||||
|
||||
// Recover the name of the work function
|
||||
static inline const char *
|
||||
_dfr_get_function_name_from_address(void *fn)
|
||||
{
|
||||
static inline const char *_dfr_get_function_name_from_address(void *fn) {
|
||||
Dl_info info;
|
||||
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"_dfr_get_function_name_from_address",
|
||||
"Error recovering work function name from address.");
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_name_from_address",
|
||||
"Error recovering work function name from address.");
|
||||
return info.dli_sname;
|
||||
}
|
||||
|
||||
static inline wfnptr
|
||||
_dfr_get_function_pointer_from_name(const char *fn_name)
|
||||
{
|
||||
static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) {
|
||||
auto ptr = dlsym(dl_handle, fn_name);
|
||||
|
||||
if (ptr == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"_dfr_get_function_pointer_from_name",
|
||||
"Error recovering work function pointer from name.");
|
||||
return (wfnptr) ptr;
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_pointer_from_name",
|
||||
"Error recovering work function pointer from name.");
|
||||
return (wfnptr)ptr;
|
||||
}
|
||||
|
||||
// Determine where new task should run. For now just round-robin
|
||||
// distribution - TODO: optimise.
|
||||
static inline size_t
|
||||
_dfr_find_next_execution_locality()
|
||||
{
|
||||
static inline size_t _dfr_find_next_execution_locality() {
|
||||
static size_t num_nodes = hpx::get_num_localities().get();
|
||||
static std::atomic<std::size_t> next_locality{0};
|
||||
|
||||
@@ -52,39 +44,33 @@ _dfr_find_next_execution_locality()
|
||||
return next_loc % num_nodes;
|
||||
}
|
||||
|
||||
static inline bool
|
||||
_dfr_is_root_node()
|
||||
{
|
||||
static inline bool _dfr_is_root_node() {
|
||||
return hpx::find_here() == hpx::find_root_locality();
|
||||
}
|
||||
|
||||
struct WorkFunctionRegistry
|
||||
{
|
||||
WorkFunctionRegistry()
|
||||
{
|
||||
node_level_work_function_registry = this;
|
||||
}
|
||||
struct WorkFunctionRegistry {
|
||||
WorkFunctionRegistry() { node_level_work_function_registry = this; }
|
||||
|
||||
wfnptr getWorkFunctionPointer(const std::string &name)
|
||||
{
|
||||
wfnptr getWorkFunctionPointer(const std::string &name) {
|
||||
std::lock_guard<std::mutex> guard(registry_guard);
|
||||
|
||||
auto fnptrit = name_to_ptr_registry.find(name);
|
||||
if (fnptrit != name_to_ptr_registry.end())
|
||||
return (wfnptr) fnptrit->second;
|
||||
return (wfnptr)fnptrit->second;
|
||||
|
||||
auto ptr = dlsym(dl_handle, name.c_str());
|
||||
if (ptr == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"WorkFunctionRegistry::getWorkFunctionPointer",
|
||||
"Error recovering work function pointer from name.");
|
||||
ptr_to_name_registry.insert(std::pair<const void *, std::string>(ptr, name));
|
||||
name_to_ptr_registry.insert(std::pair<std::string, const void *>(name, ptr));
|
||||
return (wfnptr) ptr;
|
||||
"WorkFunctionRegistry::getWorkFunctionPointer",
|
||||
"Error recovering work function pointer from name.");
|
||||
ptr_to_name_registry.insert(
|
||||
std::pair<const void *, std::string>(ptr, name));
|
||||
name_to_ptr_registry.insert(
|
||||
std::pair<std::string, const void *>(name, ptr));
|
||||
return (wfnptr)ptr;
|
||||
}
|
||||
|
||||
std::string getWorkFunctionName(const void *fn)
|
||||
{
|
||||
std::string getWorkFunctionName(const void *fn) {
|
||||
std::lock_guard<std::mutex> guard(registry_guard);
|
||||
|
||||
auto fnnameit = ptr_to_name_registry.find(fn);
|
||||
@@ -96,17 +82,17 @@ struct WorkFunctionRegistry
|
||||
// Assume that if we can't find the name, there is no dynamic
|
||||
// library to find it in. TODO: fix this to distinguish JIT/binary
|
||||
// and in case of distributed exec.
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr)
|
||||
{
|
||||
static std::atomic<unsigned int> fnid{0};
|
||||
ret = "_dfr_jit_wfnname_" + std::to_string(fnid++);
|
||||
} else {
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr) {
|
||||
static std::atomic<unsigned int> fnid{0};
|
||||
ret = "_dfr_jit_wfnname_" + std::to_string(fnid++);
|
||||
} else {
|
||||
ret = info.dli_sname;
|
||||
}
|
||||
ptr_to_name_registry.insert(std::pair<const void *, std::string>(fn, ret));
|
||||
name_to_ptr_registry.insert(std::pair<std::string, const void *>(ret, fn));
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex registry_guard;
|
||||
std::map<const void *, std::string> ptr_to_name_registry;
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
#define ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP
|
||||
|
||||
#include <cstdarg>
|
||||
#include <string>
|
||||
#include <cstdlib>
|
||||
#include <string>
|
||||
|
||||
#include <hpx/async_colocated/get_colocation_id.hpp>
|
||||
#include <hpx/include/actions.hpp>
|
||||
#include <hpx/include/lcos.hpp>
|
||||
#include <hpx/include/parallel_algorithm.hpp>
|
||||
@@ -14,15 +15,14 @@
|
||||
#include <hpx/serialization/detail/serialize_collection.hpp>
|
||||
#include <hpx/serialization/serialization_fwd.hpp>
|
||||
#include <hpx/serialization/serialize.hpp>
|
||||
#include <hpx/async_colocated/get_colocation_id.hpp>
|
||||
|
||||
#include <hpx/async_colocated/get_colocation_id.hpp>
|
||||
#include <hpx/include/client.hpp>
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
|
||||
#include "zamalang/Runtime/key_manager.hpp"
|
||||
#include "zamalang/Runtime/DFRuntime.hpp"
|
||||
#include "zamalang/Runtime/key_manager.hpp"
|
||||
|
||||
extern WorkFunctionRegistry *node_level_work_function_registry;
|
||||
|
||||
@@ -30,57 +30,43 @@ using namespace hpx::naming;
|
||||
using namespace hpx::components;
|
||||
using namespace hpx::collectives;
|
||||
|
||||
|
||||
struct OpaqueInputData
|
||||
{
|
||||
struct OpaqueInputData {
|
||||
OpaqueInputData() = default;
|
||||
|
||||
OpaqueInputData(std::string wfn_name,
|
||||
std::vector<void *> params,
|
||||
std::vector<size_t> param_sizes,
|
||||
std::vector<size_t> output_sizes,
|
||||
bool alloc_p = false) :
|
||||
wfn_name(wfn_name),
|
||||
params(std::move(params)),
|
||||
param_sizes(std::move(param_sizes)),
|
||||
output_sizes(std::move(output_sizes)),
|
||||
alloc_p(alloc_p)
|
||||
{}
|
||||
OpaqueInputData(std::string wfn_name, std::vector<void *> params,
|
||||
std::vector<size_t> param_sizes,
|
||||
std::vector<size_t> output_sizes, bool alloc_p = false)
|
||||
: wfn_name(wfn_name), params(std::move(params)),
|
||||
param_sizes(std::move(param_sizes)),
|
||||
output_sizes(std::move(output_sizes)), alloc_p(alloc_p) {}
|
||||
|
||||
OpaqueInputData(const OpaqueInputData &oid) :
|
||||
wfn_name(std::move(oid.wfn_name)),
|
||||
params(std::move(oid.params)),
|
||||
param_sizes(std::move(oid.param_sizes)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
alloc_p(oid.alloc_p)
|
||||
{}
|
||||
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
param_sizes(std::move(oid.param_sizes)),
|
||||
output_sizes(std::move(oid.output_sizes)), alloc_p(oid.alloc_p) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive>
|
||||
void load(Archive &ar, const unsigned int version)
|
||||
{
|
||||
ar & wfn_name;
|
||||
ar & param_sizes;
|
||||
ar & output_sizes;
|
||||
for (auto p : param_sizes)
|
||||
{
|
||||
char *param = new char[p];
|
||||
// TODO: Optimise these serialisation operations
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar & param[i];
|
||||
params.push_back((void *)param);
|
||||
}
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &wfn_name;
|
||||
ar ¶m_sizes;
|
||||
ar &output_sizes;
|
||||
for (auto p : param_sizes) {
|
||||
char *param = new char[p];
|
||||
// TODO: Optimise these serialisation operations
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar ¶m[i];
|
||||
params.push_back((void *)param);
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const
|
||||
{
|
||||
ar & wfn_name;
|
||||
ar & param_sizes;
|
||||
ar & output_sizes;
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar &wfn_name;
|
||||
ar ¶m_sizes;
|
||||
ar &output_sizes;
|
||||
for (size_t p = 0; p < params.size(); ++p)
|
||||
for (size_t i = 0; i < param_sizes[p]; ++i)
|
||||
ar & static_cast<char *>(params[p])[i];
|
||||
ar &static_cast<char *>(params[p])[i];
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
|
||||
@@ -91,49 +77,38 @@ struct OpaqueInputData
|
||||
bool alloc_p = false;
|
||||
};
|
||||
|
||||
struct OpaqueOutputData
|
||||
{
|
||||
struct OpaqueOutputData {
|
||||
OpaqueOutputData() = default;
|
||||
OpaqueOutputData(std::vector<void *> outputs,
|
||||
std::vector<size_t> output_sizes,
|
||||
bool alloc_p = false) :
|
||||
outputs(std::move(outputs)),
|
||||
output_sizes(std::move(output_sizes)),
|
||||
alloc_p(alloc_p)
|
||||
{}
|
||||
OpaqueOutputData(const OpaqueOutputData &ood) :
|
||||
outputs(std::move(ood.outputs)),
|
||||
output_sizes(std::move(ood.output_sizes)),
|
||||
alloc_p(ood.alloc_p)
|
||||
{}
|
||||
std::vector<size_t> output_sizes, bool alloc_p = false)
|
||||
: outputs(std::move(outputs)), output_sizes(std::move(output_sizes)),
|
||||
alloc_p(alloc_p) {}
|
||||
OpaqueOutputData(const OpaqueOutputData &ood)
|
||||
: outputs(std::move(ood.outputs)),
|
||||
output_sizes(std::move(ood.output_sizes)), alloc_p(ood.alloc_p) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive>
|
||||
void load(Archive &ar, const unsigned int version)
|
||||
{
|
||||
ar & output_sizes;
|
||||
for (auto p : output_sizes)
|
||||
{
|
||||
char *output = new char[p];
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar & output[i];
|
||||
outputs.push_back((void *)output);
|
||||
}
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &output_sizes;
|
||||
for (auto p : output_sizes) {
|
||||
char *output = new char[p];
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar &output[i];
|
||||
outputs.push_back((void *)output);
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const
|
||||
{
|
||||
ar & output_sizes;
|
||||
for (size_t p = 0; p < outputs.size(); ++p)
|
||||
{
|
||||
for (size_t i = 0; i < output_sizes[p]; ++i)
|
||||
ar & static_cast<char *>(outputs[p])[i];
|
||||
// TODO: investigate if HPX is automatically deallocating
|
||||
//these. Here it could be safely assumed that these would no
|
||||
//longer be live.
|
||||
//delete ((char*)outputs[p]);
|
||||
}
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar &output_sizes;
|
||||
for (size_t p = 0; p < outputs.size(); ++p) {
|
||||
for (size_t i = 0; i < output_sizes[p]; ++i)
|
||||
ar &static_cast<char *>(outputs[p])[i];
|
||||
// TODO: investigate if HPX is automatically deallocating
|
||||
// these. Here it could be safely assumed that these would no
|
||||
// longer be live.
|
||||
// delete (char*)outputs[p];
|
||||
}
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
|
||||
@@ -142,129 +117,124 @@ struct OpaqueOutputData
|
||||
bool alloc_p = false;
|
||||
};
|
||||
|
||||
struct GenericComputeServer : component_base<GenericComputeServer>
|
||||
{
|
||||
GenericComputeServer () = default;
|
||||
struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
GenericComputeServer() = default;
|
||||
|
||||
// Component actions exposed
|
||||
OpaqueOutputData execute_task (const OpaqueInputData &inputs)
|
||||
{
|
||||
auto wfn = node_level_work_function_registry->getWorkFunctionPointer(inputs.wfn_name);
|
||||
OpaqueOutputData execute_task(const OpaqueInputData &inputs) {
|
||||
auto wfn = node_level_work_function_registry->getWorkFunctionPointer(
|
||||
inputs.wfn_name);
|
||||
std::vector<void *> outputs;
|
||||
|
||||
switch (inputs.output_sizes.size()) {
|
||||
case 1:
|
||||
{
|
||||
void *output = (void *)(new char[inputs.output_sizes[0]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output};
|
||||
break;
|
||||
case 1: {
|
||||
void *output = (void *)(new char[inputs.output_sizes[0]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
case 2:
|
||||
{
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2};
|
||||
break;
|
||||
outputs = {output};
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
|
||||
output2);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
case 3:
|
||||
{
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2, output3);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2, output3);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2, output3);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2, output3);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2, output3};
|
||||
break;
|
||||
outputs = {output1, output2};
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2, output3);
|
||||
break;
|
||||
case 1:
|
||||
wfn(inputs.params[0], output1, output2, output3);
|
||||
break;
|
||||
case 2:
|
||||
wfn(inputs.params[0], inputs.params[1], output1, output2, output3);
|
||||
break;
|
||||
case 3:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1,
|
||||
output2, output3);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2, output3};
|
||||
break;
|
||||
}
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task outputs not supported.");
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "GenericComputeServer::execute_task",
|
||||
"Error: number of task outputs not supported.");
|
||||
}
|
||||
|
||||
if (inputs.alloc_p)
|
||||
for (auto p : inputs.params)
|
||||
delete((char*)p);
|
||||
delete ((char *)p);
|
||||
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), inputs.alloc_p);
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
|
||||
inputs.alloc_p);
|
||||
}
|
||||
|
||||
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);
|
||||
};
|
||||
|
||||
HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action,
|
||||
GenericComputeServer_execute_task_action)
|
||||
GenericComputeServer_execute_task_action)
|
||||
|
||||
HPX_REGISTER_COMPONENT_MODULE()
|
||||
HPX_REGISTER_COMPONENT(hpx::components::component<GenericComputeServer>,
|
||||
GenericComputeServer)
|
||||
GenericComputeServer)
|
||||
|
||||
HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action,
|
||||
GenericComputeServer_execute_task_action)
|
||||
GenericComputeServer_execute_task_action)
|
||||
|
||||
|
||||
struct GenericComputeClient : client_base<GenericComputeClient, GenericComputeServer>
|
||||
{
|
||||
struct GenericComputeClient
|
||||
: client_base<GenericComputeClient, GenericComputeServer> {
|
||||
typedef client_base<GenericComputeClient, GenericComputeServer> base_type;
|
||||
|
||||
GenericComputeClient() = default;
|
||||
GenericComputeClient(id_type id) : base_type(std::move(id)) {}
|
||||
|
||||
hpx::future<OpaqueOutputData>
|
||||
execute_task(const OpaqueInputData &inputs)
|
||||
{
|
||||
hpx::future<OpaqueOutputData> execute_task(const OpaqueInputData &inputs) {
|
||||
typedef GenericComputeServer::execute_task_action action_type;
|
||||
return hpx::async<action_type>(this->get_id(), inputs);
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
#ifndef ZAMALANG_DFR_KEY_MANAGER_HPP
|
||||
#define ZAMALANG_DFR_KEY_MANAGER_HPP
|
||||
|
||||
#include <utility>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
@@ -12,97 +12,80 @@
|
||||
struct PbsKeyManager;
|
||||
extern PbsKeyManager *node_level_key_manager;
|
||||
|
||||
|
||||
struct PbsKeyWrapper
|
||||
{
|
||||
struct PbsKeyWrapper {
|
||||
std::shared_ptr<void *> key;
|
||||
size_t key_id;
|
||||
size_t size;
|
||||
|
||||
PbsKeyWrapper() {}
|
||||
|
||||
PbsKeyWrapper(void *key, size_t key_id, size_t size) :
|
||||
key(std::make_shared<void *>(key)), key_id(key_id), size(size) {}
|
||||
PbsKeyWrapper(void *key, size_t key_id, size_t size)
|
||||
: key(std::make_shared<void *>(key)), key_id(key_id), size(size) {}
|
||||
|
||||
PbsKeyWrapper(std::shared_ptr<void *> key, size_t key_id, size_t size) :
|
||||
key(key), key_id(key_id), size(size) {}
|
||||
PbsKeyWrapper(std::shared_ptr<void *> key, size_t key_id, size_t size)
|
||||
: key(key), key_id(key_id), size(size) {}
|
||||
|
||||
PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept :
|
||||
key(moved.key), key_id(moved.key_id), size(moved.size) {}
|
||||
PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept
|
||||
: key(moved.key), key_id(moved.key_id), size(moved.size) {}
|
||||
|
||||
PbsKeyWrapper(const PbsKeyWrapper &pbsk) :
|
||||
key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {}
|
||||
PbsKeyWrapper(const PbsKeyWrapper &pbsk)
|
||||
: key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const
|
||||
{
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
char *_key_ = static_cast<char *>(*key);
|
||||
ar & key_id & size;
|
||||
ar &key_id &size;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
ar & _key_[i];
|
||||
ar &_key_[i];
|
||||
}
|
||||
|
||||
template <class Archive>
|
||||
void load(Archive &ar, const unsigned int version)
|
||||
{
|
||||
ar & key_id & size;
|
||||
char *_key_ = (char *) malloc(size);
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &key_id &size;
|
||||
char *_key_ = (char *)malloc(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
ar & _key_[i];
|
||||
ar &_key_[i];
|
||||
key = std::make_shared<void *>(_key_);
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs)
|
||||
{
|
||||
inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) {
|
||||
return lhs.key_id == rhs.key_id;
|
||||
}
|
||||
|
||||
|
||||
PbsKeyWrapper _dfr_fetch_key(size_t);
|
||||
HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action)
|
||||
|
||||
struct PbsKeyManager
|
||||
{
|
||||
struct PbsKeyManager {
|
||||
// The initial keys registered on the root node and whether to push
|
||||
// them is TBD.
|
||||
|
||||
PbsKeyManager()
|
||||
{
|
||||
node_level_key_manager = this;
|
||||
}
|
||||
PbsKeyManager() { node_level_key_manager = this; }
|
||||
|
||||
PbsKeyWrapper get_key(const size_t key_id)
|
||||
{
|
||||
PbsKeyWrapper get_key(const size_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end())
|
||||
{
|
||||
_dfr_fetch_key_action fet;
|
||||
PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id);
|
||||
if (pkw.size == 0)
|
||||
{
|
||||
// Maybe retry or try other nodes... but for now it's an error.
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"_dfr_get_key",
|
||||
"Error: key not found on remote node.");
|
||||
}
|
||||
else
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
keyit = keystore.insert(std::pair<size_t, PbsKeyWrapper>(key_id, pkw)).first;
|
||||
}
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_key_action fet;
|
||||
PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id);
|
||||
if (pkw.size == 0) {
|
||||
// Maybe retry or try other nodes... but for now it's an error.
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: key not found on remote node.");
|
||||
} else {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
keyit = keystore.insert(std::pair<size_t, PbsKeyWrapper>(key_id, pkw))
|
||||
.first;
|
||||
}
|
||||
}
|
||||
return keyit->second;
|
||||
}
|
||||
|
||||
// To be used only for remote requests
|
||||
PbsKeyWrapper fetch_key(const size_t key_id)
|
||||
{
|
||||
PbsKeyWrapper fetch_key(const size_t key_id) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
|
||||
auto keyit = keystore.find(key_id);
|
||||
@@ -112,30 +95,27 @@ struct PbsKeyManager
|
||||
return PbsKeyWrapper(nullptr, 0, 0);
|
||||
}
|
||||
|
||||
void register_key(void *key, size_t key_id, size_t size)
|
||||
{
|
||||
void register_key(void *key, size_t key_id, size_t size) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
auto keyit =
|
||||
keystore.insert(
|
||||
std::pair<size_t, PbsKeyWrapper>(key_id,
|
||||
PbsKeyWrapper(key, key_id, size))).first;
|
||||
if (keyit == keystore.end())
|
||||
{
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"_dfr_register_key",
|
||||
"Error: could not register new key.");
|
||||
}
|
||||
auto keyit = keystore
|
||||
.insert(std::pair<size_t, PbsKeyWrapper>(
|
||||
key_id, PbsKeyWrapper(key, key_id, size)))
|
||||
.first;
|
||||
if (keyit == keystore.end()) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
|
||||
"Error: could not register new key.");
|
||||
}
|
||||
}
|
||||
|
||||
void broadcast_keys()
|
||||
{
|
||||
void broadcast_keys() {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
if (_dfr_is_root_node())
|
||||
hpx::collectives::broadcast_to("keystore", this->keystore).get();
|
||||
else
|
||||
keystore = std::move(
|
||||
hpx::collectives::broadcast_from<std::map<size_t, PbsKeyWrapper>>
|
||||
("keystore").get());
|
||||
hpx::collectives::broadcast_from<std::map<size_t, PbsKeyWrapper>>(
|
||||
"keystore")
|
||||
.get());
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -143,10 +123,7 @@ private:
|
||||
std::map<size_t, PbsKeyWrapper> keystore;
|
||||
};
|
||||
|
||||
|
||||
PbsKeyWrapper
|
||||
_dfr_fetch_key(size_t key_id)
|
||||
{
|
||||
PbsKeyWrapper _dfr_fetch_key(size_t key_id) {
|
||||
return node_level_key_manager->fetch_key(key_id);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user