diff --git a/.github/workflows/scripts/format_cpp.sh b/.github/workflows/scripts/format_cpp.sh index 2026e47c9..c8a4ad61d 100755 --- a/.github/workflows/scripts/format_cpp.sh +++ b/.github/workflows/scripts/format_cpp.sh @@ -2,7 +2,7 @@ set -o pipefail -find ./compiler/{include,lib,src} \( -iname "*.h" -o -iname "*.cpp" -o -iname "*.cc" \) | xargs clang-format -i -style='file' +find ./compiler/{include,lib,src} -iregex '^.*\.\(cpp\|cc\|h\|hpp\)$' | xargs clang-format -i -style='file' if [ $? -ne 0 ] then exit 1 diff --git a/compiler/include/zamalang/Runtime/DFRuntime.hpp b/compiler/include/zamalang/Runtime/DFRuntime.hpp index 55fe86cc3..a6866bf25 100644 --- a/compiler/include/zamalang/Runtime/DFRuntime.hpp +++ b/compiler/include/zamalang/Runtime/DFRuntime.hpp @@ -1,9 +1,9 @@ #ifndef ZAMALANG_DFR_DFRUNTIME_HPP #define ZAMALANG_DFR_DFRUNTIME_HPP -#include -#include #include +#include +#include #include "zamalang/Runtime/runtime_api.h" @@ -15,35 +15,27 @@ struct WorkFunctionRegistry; extern WorkFunctionRegistry *node_level_work_function_registry; // Recover the name of the work function -static inline const char * -_dfr_get_function_name_from_address(void *fn) -{ +static inline const char *_dfr_get_function_name_from_address(void *fn) { Dl_info info; if (!dladdr(fn, &info) || info.dli_sname == nullptr) - HPX_THROW_EXCEPTION(hpx::no_success, - "_dfr_get_function_name_from_address", - "Error recovering work function name from address."); + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_name_from_address", + "Error recovering work function name from address."); return info.dli_sname; } -static inline wfnptr -_dfr_get_function_pointer_from_name(const char *fn_name) -{ +static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) { auto ptr = dlsym(dl_handle, fn_name); if (ptr == nullptr) - HPX_THROW_EXCEPTION(hpx::no_success, - "_dfr_get_function_pointer_from_name", - "Error recovering work function pointer from name."); - return (wfnptr) ptr; + 
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_pointer_from_name", + "Error recovering work function pointer from name."); + return (wfnptr)ptr; } // Determine where new task should run. For now just round-robin // distribution - TODO: optimise. -static inline size_t -_dfr_find_next_execution_locality() -{ +static inline size_t _dfr_find_next_execution_locality() { static size_t num_nodes = hpx::get_num_localities().get(); static std::atomic next_locality{0}; @@ -52,39 +44,33 @@ _dfr_find_next_execution_locality() return next_loc % num_nodes; } -static inline bool -_dfr_is_root_node() -{ +static inline bool _dfr_is_root_node() { return hpx::find_here() == hpx::find_root_locality(); } -struct WorkFunctionRegistry -{ - WorkFunctionRegistry() - { - node_level_work_function_registry = this; - } +struct WorkFunctionRegistry { + WorkFunctionRegistry() { node_level_work_function_registry = this; } - wfnptr getWorkFunctionPointer(const std::string &name) - { + wfnptr getWorkFunctionPointer(const std::string &name) { std::lock_guard guard(registry_guard); auto fnptrit = name_to_ptr_registry.find(name); if (fnptrit != name_to_ptr_registry.end()) - return (wfnptr) fnptrit->second; + return (wfnptr)fnptrit->second; auto ptr = dlsym(dl_handle, name.c_str()); if (ptr == nullptr) HPX_THROW_EXCEPTION(hpx::no_success, - "WorkFunctionRegistry::getWorkFunctionPointer", - "Error recovering work function pointer from name."); - ptr_to_name_registry.insert(std::pair(ptr, name)); - name_to_ptr_registry.insert(std::pair(name, ptr)); - return (wfnptr) ptr; + "WorkFunctionRegistry::getWorkFunctionPointer", + "Error recovering work function pointer from name."); + ptr_to_name_registry.insert( + std::pair(ptr, name)); + name_to_ptr_registry.insert( + std::pair(name, ptr)); + return (wfnptr)ptr; } - std::string getWorkFunctionName(const void *fn) - { + std::string getWorkFunctionName(const void *fn) { std::lock_guard guard(registry_guard); auto fnnameit = 
ptr_to_name_registry.find(fn); @@ -96,17 +82,17 @@ struct WorkFunctionRegistry // Assume that if we can't find the name, there is no dynamic // library to find it in. TODO: fix this to distinguish JIT/binary // and in case of distributed exec. - if (!dladdr(fn, &info) || info.dli_sname == nullptr) - { - static std::atomic fnid{0}; - ret = "_dfr_jit_wfnname_" + std::to_string(fnid++); - } else { + if (!dladdr(fn, &info) || info.dli_sname == nullptr) { + static std::atomic fnid{0}; + ret = "_dfr_jit_wfnname_" + std::to_string(fnid++); + } else { ret = info.dli_sname; } ptr_to_name_registry.insert(std::pair(fn, ret)); name_to_ptr_registry.insert(std::pair(ret, fn)); return ret; } + private: std::mutex registry_guard; std::map ptr_to_name_registry; diff --git a/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp b/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp index 6d47970b6..8dcd0c883 100644 --- a/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp +++ b/compiler/include/zamalang/Runtime/distributed_generic_task_server.hpp @@ -2,9 +2,10 @@ #define ZAMALANG_DFR_DISTRIBUTED_GENERIC_TASK_SERVER_HPP #include -#include #include +#include +#include #include #include #include @@ -14,15 +15,14 @@ #include #include #include -#include #include #include #include #include -#include "zamalang/Runtime/key_manager.hpp" #include "zamalang/Runtime/DFRuntime.hpp" +#include "zamalang/Runtime/key_manager.hpp" extern WorkFunctionRegistry *node_level_work_function_registry; @@ -30,57 +30,43 @@ using namespace hpx::naming; using namespace hpx::components; using namespace hpx::collectives; - -struct OpaqueInputData -{ +struct OpaqueInputData { OpaqueInputData() = default; - OpaqueInputData(std::string wfn_name, - std::vector params, - std::vector param_sizes, - std::vector output_sizes, - bool alloc_p = false) : - wfn_name(wfn_name), - params(std::move(params)), - param_sizes(std::move(param_sizes)), - 
output_sizes(std::move(output_sizes)), - alloc_p(alloc_p) - {} + OpaqueInputData(std::string wfn_name, std::vector params, + std::vector param_sizes, + std::vector output_sizes, bool alloc_p = false) + : wfn_name(wfn_name), params(std::move(params)), + param_sizes(std::move(param_sizes)), + output_sizes(std::move(output_sizes)), alloc_p(alloc_p) {} - OpaqueInputData(const OpaqueInputData &oid) : - wfn_name(std::move(oid.wfn_name)), - params(std::move(oid.params)), - param_sizes(std::move(oid.param_sizes)), - output_sizes(std::move(oid.output_sizes)), - alloc_p(oid.alloc_p) - {} + OpaqueInputData(const OpaqueInputData &oid) + : wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)), + param_sizes(std::move(oid.param_sizes)), + output_sizes(std::move(oid.output_sizes)), alloc_p(oid.alloc_p) {} friend class hpx::serialization::access; - template - void load(Archive &ar, const unsigned int version) - { - ar & wfn_name; - ar & param_sizes; - ar & output_sizes; - for (auto p : param_sizes) - { - char *param = new char[p]; - // TODO: Optimise these serialisation operations - for (size_t i = 0; i < p; ++i) - ar & param[i]; - params.push_back((void *)param); - } + template void load(Archive &ar, const unsigned int version) { + ar &wfn_name; + ar &param_sizes; + ar &output_sizes; + for (auto p : param_sizes) { + char *param = new char[p]; + // TODO: Optimise these serialisation operations + for (size_t i = 0; i < p; ++i) + ar &param[i]; + params.push_back((void *)param); + } alloc_p = true; } template - void save(Archive &ar, const unsigned int version) const - { - ar & wfn_name; - ar & param_sizes; - ar & output_sizes; + void save(Archive &ar, const unsigned int version) const { + ar &wfn_name; + ar &param_sizes; + ar &output_sizes; for (size_t p = 0; p < params.size(); ++p) for (size_t i = 0; i < param_sizes[p]; ++i) - ar & static_cast(params[p])[i]; + ar &static_cast(params[p])[i]; } HPX_SERIALIZATION_SPLIT_MEMBER() @@ -91,49 +77,38 @@ struct OpaqueInputData bool alloc_p = 
false; }; -struct OpaqueOutputData -{ +struct OpaqueOutputData { OpaqueOutputData() = default; OpaqueOutputData(std::vector outputs, - std::vector output_sizes, - bool alloc_p = false) : - outputs(std::move(outputs)), - output_sizes(std::move(output_sizes)), - alloc_p(alloc_p) - {} - OpaqueOutputData(const OpaqueOutputData &ood) : - outputs(std::move(ood.outputs)), - output_sizes(std::move(ood.output_sizes)), - alloc_p(ood.alloc_p) - {} + std::vector output_sizes, bool alloc_p = false) + : outputs(std::move(outputs)), output_sizes(std::move(output_sizes)), + alloc_p(alloc_p) {} + OpaqueOutputData(const OpaqueOutputData &ood) + : outputs(std::move(ood.outputs)), + output_sizes(std::move(ood.output_sizes)), alloc_p(ood.alloc_p) {} friend class hpx::serialization::access; - template - void load(Archive &ar, const unsigned int version) - { - ar & output_sizes; - for (auto p : output_sizes) - { - char *output = new char[p]; - for (size_t i = 0; i < p; ++i) - ar & output[i]; - outputs.push_back((void *)output); - } + template void load(Archive &ar, const unsigned int version) { + ar &output_sizes; + for (auto p : output_sizes) { + char *output = new char[p]; + for (size_t i = 0; i < p; ++i) + ar &output[i]; + outputs.push_back((void *)output); + } alloc_p = true; } template - void save(Archive &ar, const unsigned int version) const - { - ar & output_sizes; - for (size_t p = 0; p < outputs.size(); ++p) - { - for (size_t i = 0; i < output_sizes[p]; ++i) - ar & static_cast(outputs[p])[i]; - // TODO: investigate if HPX is automatically deallocating - //these. Here it could be safely assumed that these would no - //longer be live. - //delete ((char*)outputs[p]); - } + void save(Archive &ar, const unsigned int version) const { + ar &output_sizes; + for (size_t p = 0; p < outputs.size(); ++p) { + for (size_t i = 0; i < output_sizes[p]; ++i) + ar &static_cast(outputs[p])[i]; + // TODO: investigate if HPX is automatically deallocating + // these. 
Here it could be safely assumed that these would no + // longer be live. + // delete (char*)outputs[p]; + } } HPX_SERIALIZATION_SPLIT_MEMBER() @@ -142,129 +117,124 @@ struct OpaqueOutputData bool alloc_p = false; }; -struct GenericComputeServer : component_base -{ - GenericComputeServer () = default; +struct GenericComputeServer : component_base { + GenericComputeServer() = default; // Component actions exposed - OpaqueOutputData execute_task (const OpaqueInputData &inputs) - { - auto wfn = node_level_work_function_registry->getWorkFunctionPointer(inputs.wfn_name); + OpaqueOutputData execute_task(const OpaqueInputData &inputs) { + auto wfn = node_level_work_function_registry->getWorkFunctionPointer( + inputs.wfn_name); std::vector outputs; switch (inputs.output_sizes.size()) { - case 1: - { - void *output = (void *)(new char[inputs.output_sizes[0]]); - switch (inputs.params.size()) { - case 0: - wfn(output); - break; - case 1: - wfn(inputs.params[0], output); - break; - case 2: - wfn(inputs.params[0], inputs.params[1], output); - break; - case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output); - break; - default: - HPX_THROW_EXCEPTION(hpx::no_success, - "GenericComputeServer::execute_task", - "Error: number of task parameters not supported."); - } - outputs = {output}; - break; + case 1: { + void *output = (void *)(new char[inputs.output_sizes[0]]); + switch (inputs.params.size()) { + case 0: + wfn(output); + break; + case 1: + wfn(inputs.params[0], output); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); } - case 2: - { - void *output1 = (void *)(new char[inputs.output_sizes[0]]); - void *output2 = (void *)(new char[inputs.output_sizes[1]]); - switch (inputs.params.size()) { - case 0: 
- wfn(output1, output2); - break; - case 1: - wfn(inputs.params[0], output1, output2); - break; - case 2: - wfn(inputs.params[0], inputs.params[1], output1, output2); - break; - case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2); - break; - default: - HPX_THROW_EXCEPTION(hpx::no_success, - "GenericComputeServer::execute_task", - "Error: number of task parameters not supported."); - } - outputs = {output1, output2}; - break; + outputs = {output}; + break; + } + case 2: { + void *output1 = (void *)(new char[inputs.output_sizes[0]]); + void *output2 = (void *)(new char[inputs.output_sizes[1]]); + switch (inputs.params.size()) { + case 0: + wfn(output1, output2); + break; + case 1: + wfn(inputs.params[0], output1, output2); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output1, output2); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, + output2); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); } - case 3: - { - void *output1 = (void *)(new char[inputs.output_sizes[0]]); - void *output2 = (void *)(new char[inputs.output_sizes[1]]); - void *output3 = (void *)(new char[inputs.output_sizes[2]]); - switch (inputs.params.size()) { - case 0: - wfn(output1, output2, output3); - break; - case 1: - wfn(inputs.params[0], output1, output2, output3); - break; - case 2: - wfn(inputs.params[0], inputs.params[1], output1, output2, output3); - break; - case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, output2, output3); - break; - default: - HPX_THROW_EXCEPTION(hpx::no_success, - "GenericComputeServer::execute_task", - "Error: number of task parameters not supported."); - } - outputs = {output1, output2, output3}; - break; + outputs = {output1, output2}; + break; + } + case 3: { + void *output1 = (void *)(new char[inputs.output_sizes[0]]); + void *output2 = (void 
*)(new char[inputs.output_sizes[1]]); + void *output3 = (void *)(new char[inputs.output_sizes[2]]); + switch (inputs.params.size()) { + case 0: + wfn(output1, output2, output3); + break; + case 1: + wfn(inputs.params[0], output1, output2, output3); + break; + case 2: + wfn(inputs.params[0], inputs.params[1], output1, output2, output3); + break; + case 3: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, + output2, output3); + break; + default: + HPX_THROW_EXCEPTION(hpx::no_success, + "GenericComputeServer::execute_task", + "Error: number of task parameters not supported."); } + outputs = {output1, output2, output3}; + break; + } default: - HPX_THROW_EXCEPTION(hpx::no_success, - "GenericComputeServer::execute_task", - "Error: number of task outputs not supported."); + HPX_THROW_EXCEPTION(hpx::no_success, "GenericComputeServer::execute_task", + "Error: number of task outputs not supported."); } if (inputs.alloc_p) for (auto p : inputs.params) - delete((char*)p); + delete ((char *)p); - return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), inputs.alloc_p); + return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), + inputs.alloc_p); } HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task); }; HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action, - GenericComputeServer_execute_task_action) + GenericComputeServer_execute_task_action) HPX_REGISTER_COMPONENT_MODULE() HPX_REGISTER_COMPONENT(hpx::components::component, - GenericComputeServer) + GenericComputeServer) HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action, - GenericComputeServer_execute_task_action) + GenericComputeServer_execute_task_action) - -struct GenericComputeClient : client_base -{ +struct GenericComputeClient + : client_base { typedef client_base base_type; GenericComputeClient() = default; GenericComputeClient(id_type id) : base_type(std::move(id)) {} - hpx::future - execute_task(const OpaqueInputData 
&inputs) - { + hpx::future execute_task(const OpaqueInputData &inputs) { typedef GenericComputeServer::execute_task_action action_type; return hpx::async(this->get_id(), inputs); } diff --git a/compiler/include/zamalang/Runtime/key_manager.hpp b/compiler/include/zamalang/Runtime/key_manager.hpp index 4ce73cc88..a1e855a57 100644 --- a/compiler/include/zamalang/Runtime/key_manager.hpp +++ b/compiler/include/zamalang/Runtime/key_manager.hpp @@ -1,8 +1,8 @@ #ifndef ZAMALANG_DFR_KEY_MANAGER_HPP #define ZAMALANG_DFR_KEY_MANAGER_HPP -#include #include +#include #include #include @@ -12,97 +12,80 @@ struct PbsKeyManager; extern PbsKeyManager *node_level_key_manager; - -struct PbsKeyWrapper -{ +struct PbsKeyWrapper { std::shared_ptr key; size_t key_id; size_t size; PbsKeyWrapper() {} - PbsKeyWrapper(void *key, size_t key_id, size_t size) : - key(std::make_shared(key)), key_id(key_id), size(size) {} + PbsKeyWrapper(void *key, size_t key_id, size_t size) + : key(std::make_shared(key)), key_id(key_id), size(size) {} - PbsKeyWrapper(std::shared_ptr key, size_t key_id, size_t size) : - key(key), key_id(key_id), size(size) {} + PbsKeyWrapper(std::shared_ptr key, size_t key_id, size_t size) + : key(key), key_id(key_id), size(size) {} - PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept : - key(moved.key), key_id(moved.key_id), size(moved.size) {} + PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept + : key(moved.key), key_id(moved.key_id), size(moved.size) {} - PbsKeyWrapper(const PbsKeyWrapper &pbsk) : - key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {} + PbsKeyWrapper(const PbsKeyWrapper &pbsk) + : key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {} friend class hpx::serialization::access; template - void save(Archive &ar, const unsigned int version) const - { + void save(Archive &ar, const unsigned int version) const { char *_key_ = static_cast(*key); - ar & key_id & size; + ar &key_id &size; for (size_t i = 0; i < size; ++i) - ar & _key_[i]; + ar &_key_[i]; } - template - void 
load(Archive &ar, const unsigned int version) - { - ar & key_id & size; - char *_key_ = (char *) malloc(size); + template void load(Archive &ar, const unsigned int version) { + ar &key_id &size; + char *_key_ = (char *)malloc(size); for (size_t i = 0; i < size; ++i) - ar & _key_[i]; + ar &_key_[i]; key = std::make_shared(_key_); } HPX_SERIALIZATION_SPLIT_MEMBER() }; -inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) -{ +inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) { return lhs.key_id == rhs.key_id; } - PbsKeyWrapper _dfr_fetch_key(size_t); HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action) -struct PbsKeyManager -{ +struct PbsKeyManager { // The initial keys registered on the root node and whether to push // them is TBD. - PbsKeyManager() - { - node_level_key_manager = this; - } + PbsKeyManager() { node_level_key_manager = this; } - PbsKeyWrapper get_key(const size_t key_id) - { + PbsKeyWrapper get_key(const size_t key_id) { keystore_guard.lock(); auto keyit = keystore.find(key_id); keystore_guard.unlock(); - if (keyit == keystore.end()) - { - _dfr_fetch_key_action fet; - PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id); - if (pkw.size == 0) - { - // Maybe retry or try other nodes... but for now it's an error. - HPX_THROW_EXCEPTION(hpx::no_success, - "_dfr_get_key", - "Error: key not found on remote node."); - } - else - { - std::lock_guard guard(keystore_guard); - keyit = keystore.insert(std::pair(key_id, pkw)).first; - } + if (keyit == keystore.end()) { + _dfr_fetch_key_action fet; + PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id); + if (pkw.size == 0) { + // Maybe retry or try other nodes... but for now it's an error. 
+ HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key", + "Error: key not found on remote node."); + } else { + std::lock_guard guard(keystore_guard); + keyit = keystore.insert(std::pair(key_id, pkw)) + .first; } + } return keyit->second; } // To be used only for remote requests - PbsKeyWrapper fetch_key(const size_t key_id) - { + PbsKeyWrapper fetch_key(const size_t key_id) { std::lock_guard guard(keystore_guard); auto keyit = keystore.find(key_id); @@ -112,30 +95,27 @@ struct PbsKeyManager return PbsKeyWrapper(nullptr, 0, 0); } - void register_key(void *key, size_t key_id, size_t size) - { + void register_key(void *key, size_t key_id, size_t size) { std::lock_guard guard(keystore_guard); - auto keyit = - keystore.insert( - std::pair(key_id, - PbsKeyWrapper(key, key_id, size))).first; - if (keyit == keystore.end()) - { - HPX_THROW_EXCEPTION(hpx::no_success, - "_dfr_register_key", - "Error: could not register new key."); - } + auto keyit = keystore + .insert(std::pair( + key_id, PbsKeyWrapper(key, key_id, size))) + .first; + if (keyit == keystore.end()) { + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key", + "Error: could not register new key."); + } } - void broadcast_keys() - { + void broadcast_keys() { std::lock_guard guard(keystore_guard); if (_dfr_is_root_node()) hpx::collectives::broadcast_to("keystore", this->keystore).get(); else keystore = std::move( - hpx::collectives::broadcast_from> - ("keystore").get()); + hpx::collectives::broadcast_from>( + "keystore") + .get()); } private: @@ -143,10 +123,7 @@ private: std::map keystore; }; - -PbsKeyWrapper -_dfr_fetch_key(size_t key_id) -{ +PbsKeyWrapper _dfr_fetch_key(size_t key_id) { return node_level_key_manager->fetch_key(key_id); }