From 954b2098c6834177cb51035def47f4449e5636e3 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Tue, 15 Mar 2022 14:46:47 +0000 Subject: [PATCH] feat(runtime): enable distributed execution. --- .../include/concretelang/ClientLib/KeySet.h | 1 + .../concretelang/Dialect/RT/IR/RTOps.td | 8 +- .../concretelang/Runtime/DFRuntime.hpp | 107 ++------ .../Runtime/dfr_debug_interface.h | 2 +- .../distributed_generic_task_server.hpp | 251 ++++++++++++++---- .../concretelang/Runtime/key_manager.hpp | 243 +++++++++++------ .../concretelang/Runtime/runtime_api.h | 8 +- .../Runtime/workfunction_registry.hpp | 92 +++++++ .../lib/Dialect/RT/Analysis/CMakeLists.txt | 2 +- .../RT/Analysis/LowerDataflowTasksToRT.cpp | 127 +++++++-- ...owerRTToLLVMDFRCallsConversionPatterns.cpp | 21 ++ compiler/lib/Runtime/DFRuntime.cpp | 164 +++++++++--- compiler/lib/Support/JITSupport.cpp | 8 +- compiler/lib/Support/Jit.cpp | 14 + compiler/src/main.cpp | 2 +- 15 files changed, 765 insertions(+), 285 deletions(-) create mode 100644 compiler/include/concretelang/Runtime/workfunction_registry.hpp diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index 207606375..3924d90f5 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -19,6 +19,7 @@ extern "C" { #include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Common/Error.h" +#include namespace concretelang { namespace clientlib { diff --git a/compiler/include/concretelang/Dialect/RT/IR/RTOps.td b/compiler/include/concretelang/Dialect/RT/IR/RTOps.td index 84cdc50e1..3ae30f408 100644 --- a/compiler/include/concretelang/Dialect/RT/IR/RTOps.td +++ b/compiler/include/concretelang/Dialect/RT/IR/RTOps.td @@ -115,7 +115,13 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> { let summary = "Create a dataflow task."; } -def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> { +def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> { + let arguments = (ins Variadic:$list); + let results = (outs ); + let summary = "Register the task work-function with the runtime system."; +} + +def DeallocateFutureOp : RT_Op<"deallocate_future"> { let arguments = (ins RT_Future: $input); let results = (outs ); } diff --git a/compiler/include/concretelang/Runtime/DFRuntime.hpp b/compiler/include/concretelang/Runtime/DFRuntime.hpp index 8fa8d9267..0ed66ad3f 100644 --- a/compiler/include/concretelang/Runtime/DFRuntime.hpp +++ b/compiler/include/concretelang/Runtime/DFRuntime.hpp @@ -6,102 +6,39 @@ #ifndef CONCRETELANG_DFR_DFRUNTIME_HPP #define CONCRETELANG_DFR_DFRUNTIME_HPP +#include +#include #include #include #include #include "concretelang/Runtime/runtime_api.h" -/* Debug interface. */ -#include "concretelang/Runtime/dfr_debug_interface.h" +bool _dfr_is_root_node(); +void _dfr_is_jit(bool); +bool _dfr_is_jit(); +void _dfr_terminate(); -extern void *dl_handle; -struct WorkFunctionRegistry; -extern WorkFunctionRegistry *node_level_work_function_registry; +typedef enum _dfr_task_arg_type { + _DFR_TASK_ARG_BASE = 0, + _DFR_TASK_ARG_MEMREF = 1, + _DFR_TASK_ARG_UNRANKED_MEMREF = 2, + _DFR_TASK_ARG_CONTEXT = 3 +} _dfr_task_arg_type; -/// Recover the name of the work function -static inline const char *_dfr_get_function_name_from_address(void *fn) { - Dl_info info; - - if (!dladdr(fn, &info) || info.dli_sname == nullptr) - HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_name_from_address", - "Error recovering work function name from address."); - return info.dli_sname; +static inline _dfr_task_arg_type _dfr_get_arg_type(uint64_t val) { + return (_dfr_task_arg_type)(val & 0xFF); } - -static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) { - auto ptr = dlsym(dl_handle, fn_name); - - if (ptr == nullptr) - HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_pointer_from_name", - "Error recovering work function pointer from name."); - return (wfnptr)ptr; +static inline uint64_t _dfr_get_memref_element_size(uint64_t val) { + return val >> 8; } - -/// Determine where new task should run. For now just round-robin -/// distribution - TODO: optimise. -static inline size_t _dfr_find_next_execution_locality() { - static size_t num_nodes = hpx::get_num_localities().get(); - static std::atomic next_locality{0}; - - size_t next_loc = ++next_locality; - - return next_loc % num_nodes; +static inline uint64_t _dfr_set_arg_type(uint64_t val, + _dfr_task_arg_type type) { + return (val & ~(0xFF)) | type; } - -static inline bool _dfr_is_root_node() { - return hpx::find_here() == hpx::find_root_locality(); +static inline uint64_t _dfr_set_memref_element_size(uint64_t val, size_t size) { + assert(size < (((uint64_t)1) << 48)); + return (val & 0xFF) | (((uint64_t)size) << 8); } -struct WorkFunctionRegistry { - WorkFunctionRegistry() { node_level_work_function_registry = this; } - - wfnptr getWorkFunctionPointer(const std::string &name) { - std::lock_guard guard(registry_guard); - - auto fnptrit = name_to_ptr_registry.find(name); - if (fnptrit != name_to_ptr_registry.end()) - return (wfnptr)fnptrit->second; - - auto ptr = dlsym(dl_handle, name.c_str()); - if (ptr == nullptr) - HPX_THROW_EXCEPTION(hpx::no_success, - "WorkFunctionRegistry::getWorkFunctionPointer", - "Error recovering work function pointer from name."); - ptr_to_name_registry.insert( - std::pair(ptr, name)); - name_to_ptr_registry.insert( - std::pair(name, ptr)); - return (wfnptr)ptr; - } - - std::string getWorkFunctionName(const void *fn) { - std::lock_guard guard(registry_guard); - - auto fnnameit = ptr_to_name_registry.find(fn); - if (fnnameit != ptr_to_name_registry.end()) - return fnnameit->second; - - Dl_info info; - std::string ret; - // Assume that if we can't find the name, there is no dynamic - // library to find it in. TODO: fix this to distinguish JIT/binary - // and in case of distributed exec. - if (!dladdr(fn, &info) || info.dli_sname == nullptr) { - static std::atomic fnid{0}; - ret = "_dfr_jit_wfnname_" + std::to_string(fnid++); - } else { - ret = info.dli_sname; - } - ptr_to_name_registry.insert(std::pair(fn, ret)); - name_to_ptr_registry.insert(std::pair(ret, fn)); - return ret; - } - -private: - std::mutex registry_guard; - std::map ptr_to_name_registry; - std::map name_to_ptr_registry; -}; - #endif diff --git a/compiler/include/concretelang/Runtime/dfr_debug_interface.h b/compiler/include/concretelang/Runtime/dfr_debug_interface.h index de4d1111b..452296102 100644 --- a/compiler/include/concretelang/Runtime/dfr_debug_interface.h +++ b/compiler/include/concretelang/Runtime/dfr_debug_interface.h @@ -12,7 +12,7 @@ extern "C" { size_t _dfr_debug_get_node_id(); size_t _dfr_debug_get_worker_id(); -void _dfr_debug_print_task(const char *name, int inputs, int outputs); +void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs); void _dfr_print_debug(size_t val); } #endif diff --git a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp index a7a6e50e7..0ece268e3 100644 --- a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp +++ b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp @@ -26,101 +26,250 @@ #include #include -#include "concretelang/Runtime/DFRuntime.hpp" -#include "concretelang/Runtime/key_manager.hpp" +#include -extern WorkFunctionRegistry *node_level_work_function_registry; +#include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/Runtime/context.h" +#include "concretelang/Runtime/dfr_debug_interface.h" +#include "concretelang/Runtime/key_manager.hpp" +#include "concretelang/Runtime/runtime_api.h" +#include "concretelang/Runtime/workfunction_registry.hpp" + +extern WorkFunctionRegistry *_dfr_node_level_work_function_registry; extern std::list new_allocated; using namespace hpx::naming; using namespace hpx::components; using namespace hpx::collectives; +static inline size_t _dfr_get_memref_rank(size_t size) { + return (size - 2 * sizeof(char *) /*allocated_ptr & aligned_ptr*/ + - sizeof(int64_t) /*offset*/) / + (2 * sizeof(int64_t) /*size&stride/rank*/); +} + struct OpaqueInputData { OpaqueInputData() = default; - OpaqueInputData(std::string wfn_name, std::vector params, - std::vector param_sizes, - std::vector output_sizes, bool alloc_p = false) - : wfn_name(wfn_name), params(std::move(params)), - param_sizes(std::move(param_sizes)), - output_sizes(std::move(output_sizes)), alloc_p(alloc_p) {} + OpaqueInputData(std::string _wfn_name, std::vector _params, + std::vector _param_sizes, + std::vector _param_types, + std::vector _output_sizes, + std::vector _output_types, bool _alloc_p = false) + : wfn_name(_wfn_name), params(std::move(_params)), + param_sizes(std::move(_param_sizes)), + param_types(std::move(_param_types)), + output_sizes(std::move(_output_sizes)), + output_types(std::move(_output_types)), alloc_p(_alloc_p), + source_locality(hpx::find_here()) {} OpaqueInputData(const OpaqueInputData &oid) : wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)), param_sizes(std::move(oid.param_sizes)), - output_sizes(std::move(oid.output_sizes)), alloc_p(oid.alloc_p) {} + param_types(std::move(oid.param_types)), + output_sizes(std::move(oid.output_sizes)), + output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p), + source_locality(oid.source_locality) {} friend class hpx::serialization::access; template void load(Archive &ar, const unsigned int version) { - ar &wfn_name; - ar ¶m_sizes; - ar &output_sizes; - for (auto p : param_sizes) { - char *param = new char[p]; - // TODO: Optimise these serialisation operations - for (size_t i = 0; i < p; ++i) - ar ¶m[i]; + ar >> wfn_name; + ar >> param_sizes >> param_types; + ar >> output_sizes >> output_types; + for (size_t p = 0; p < param_sizes.size(); ++p) { + char *param = new char[param_sizes[p]]; + new_allocated.push_back((void *)param); + ar >> hpx::serialization::make_array(param, param_sizes[p]); params.push_back((void *)param); + + switch (_dfr_get_arg_type(param_types[p])) { + case _DFR_TASK_ARG_BASE: + break; + case _DFR_TASK_ARG_MEMREF: { + size_t rank = _dfr_get_memref_rank(param_sizes[p]); + UnrankedMemRefType umref = {rank, params[p]}; + DynamicMemRefType mref(umref); + size_t elementSize = _dfr_get_memref_element_size(param_types[p]); + size_t size = 1; + for (size_t r = 0; r < rank; ++r) + size *= mref.sizes[r]; + size_t alloc_size = (size + mref.offset) * elementSize; + char *data = new char[alloc_size]; + new_allocated.push_back((void *)data); + ar >> hpx::serialization::make_array(data + mref.offset * elementSize, + size * elementSize); + static_cast *>(params[p])->basePtr = nullptr; + static_cast *>(params[p])->data = data; + } break; + case _DFR_TASK_ARG_CONTEXT: { + uint64_t bsk_id, ksk_id; + ar >> bsk_id >> ksk_id >> source_locality; + + mlir::concretelang::RuntimeContext *context = + new mlir::concretelang::RuntimeContext; + new_allocated.push_back((void *)context); + mlir::concretelang::RuntimeContext **_context = + new mlir::concretelang::RuntimeContext *[1]; + new_allocated.push_back((void *)_context); + _context[0] = context; + + context->bsk = (LweBootstrapKey_u64 *)bsk_id; + context->ksk = (LweKeyswitchKey_u64 *)ksk_id; + params[p] = (void *)_context; + } break; + case _DFR_TASK_ARG_UNRANKED_MEMREF: + default: + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", + "Error: invalid task argument type."); + } } alloc_p = true; } template void save(Archive &ar, const unsigned int version) const { - ar &wfn_name; - ar ¶m_sizes; - ar &output_sizes; - for (size_t p = 0; p < params.size(); ++p) - for (size_t i = 0; i < param_sizes[p]; ++i) - ar &static_cast(params[p])[i]; + ar << wfn_name; + ar << param_sizes << param_types; + ar << output_sizes << output_types; + for (size_t p = 0; p < params.size(); ++p) { + // Save the first level of the data structure - if the parameter + // is a tensor/memref, there is a second level. + ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]); + switch (_dfr_get_arg_type(param_types[p])) { + case _DFR_TASK_ARG_BASE: + break; + case _DFR_TASK_ARG_MEMREF: { + size_t rank = _dfr_get_memref_rank(param_sizes[p]); + UnrankedMemRefType umref = {rank, params[p]}; + DynamicMemRefType mref(umref); + size_t elementSize = _dfr_get_memref_element_size(param_types[p]); + size_t size = 1; + for (size_t r = 0; r < rank; ++r) + size *= mref.sizes[r]; + ar << hpx::serialization::make_array( + mref.data + mref.offset * elementSize, size * elementSize); + } break; + case _DFR_TASK_ARG_CONTEXT: { + mlir::concretelang::RuntimeContext *context = + *static_cast(params[p]); + LweKeyswitchKey_u64 *ksk = context->ksk; + LweBootstrapKey_u64 *bsk = context->bsk; + + // TODO: find better unique identifiers. This is not a + // correctness issue, but performance. + uint64_t bsk_id = (uint64_t)bsk; + uint64_t ksk_id = (uint64_t)ksk; + + assert(bsk != nullptr && ksk != nullptr && "Missing context keys"); + _dfr_register_bsk(bsk, bsk_id); + _dfr_register_ksk(ksk, ksk_id); + ar << bsk_id << ksk_id << source_locality; + } break; + case _DFR_TASK_ARG_UNRANKED_MEMREF: + default: + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", + "Error: invalid task argument type."); + } + } } HPX_SERIALIZATION_SPLIT_MEMBER() std::string wfn_name; std::vector params; std::vector param_sizes; + std::vector param_types; std::vector output_sizes; + std::vector output_types; bool alloc_p = false; + hpx::naming::id_type source_locality; }; struct OpaqueOutputData { OpaqueOutputData() = default; OpaqueOutputData(std::vector outputs, - std::vector output_sizes, bool alloc_p = false) + std::vector output_sizes, + std::vector output_types, bool alloc_p = false) : outputs(std::move(outputs)), output_sizes(std::move(output_sizes)), - alloc_p(alloc_p) {} + output_types(std::move(output_types)), alloc_p(alloc_p) {} OpaqueOutputData(const OpaqueOutputData &ood) : outputs(std::move(ood.outputs)), - output_sizes(std::move(ood.output_sizes)), alloc_p(ood.alloc_p) {} + output_sizes(std::move(ood.output_sizes)), + output_types(std::move(ood.output_types)), alloc_p(ood.alloc_p) {} friend class hpx::serialization::access; template void load(Archive &ar, const unsigned int version) { - ar &output_sizes; - for (auto p : output_sizes) { - char *output = new char[p]; - for (size_t i = 0; i < p; ++i) - ar &output[i]; - outputs.push_back((void *)output); + ar >> output_sizes >> output_types; + for (size_t p = 0; p < output_sizes.size(); ++p) { + char *output = new char[output_sizes[p]]; new_allocated.push_back((void *)output); + ar >> hpx::serialization::make_array(output, output_sizes[p]); + outputs.push_back((void *)output); + + switch (_dfr_get_arg_type(output_types[p])) { + case _DFR_TASK_ARG_BASE: + break; + case _DFR_TASK_ARG_MEMREF: { + size_t rank = _dfr_get_memref_rank(output_sizes[p]); + UnrankedMemRefType umref = {rank, outputs[p]}; + DynamicMemRefType mref(umref); + size_t elementSize = _dfr_get_memref_element_size(output_types[p]); + size_t size = 1; + for (size_t r = 0; r < rank; ++r) + size *= mref.sizes[r]; + size_t alloc_size = (size + mref.offset) * elementSize; + char *data = new char[alloc_size]; + new_allocated.push_back((void *)data); + ar >> hpx::serialization::make_array(data + mref.offset * elementSize, + size * elementSize); + static_cast *>(outputs[p])->basePtr = + nullptr; + static_cast *>(outputs[p])->data = data; + } break; + case _DFR_TASK_ARG_CONTEXT: { + + } break; + case _DFR_TASK_ARG_UNRANKED_MEMREF: + default: + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", + "Error: invalid task argument type."); + } } alloc_p = true; } template void save(Archive &ar, const unsigned int version) const { - ar &output_sizes; + ar << output_sizes << output_types; for (size_t p = 0; p < outputs.size(); ++p) { - for (size_t i = 0; i < output_sizes[p]; ++i) - ar &static_cast(outputs[p])[i]; - // TODO: investigate if HPX is automatically deallocating - // these. Here it could be safely assumed that these would no - // longer be live. - // delete (char*)outputs[p]; + ar << hpx::serialization::make_array((char *)outputs[p], output_sizes[p]); + + switch (_dfr_get_arg_type(output_types[p])) { + case _DFR_TASK_ARG_BASE: + break; + case _DFR_TASK_ARG_MEMREF: { + size_t rank = _dfr_get_memref_rank(output_sizes[p]); + UnrankedMemRefType umref = {rank, outputs[p]}; + DynamicMemRefType mref(umref); + size_t elementSize = _dfr_get_memref_element_size(output_types[p]); + size_t size = 1; + for (size_t r = 0; r < rank; ++r) + size *= mref.sizes[r]; + ar << hpx::serialization::make_array( + mref.data + mref.offset * elementSize, size * elementSize); + } break; + case _DFR_TASK_ARG_CONTEXT: { + + } break; + case _DFR_TASK_ARG_UNRANKED_MEMREF: + default: + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", + "Error: invalid task argument type."); + } } } HPX_SERIALIZATION_SPLIT_MEMBER() std::vector outputs; std::vector output_sizes; + std::vector output_types; bool alloc_p = false; }; @@ -129,10 +278,24 @@ struct GenericComputeServer : component_base { // Component actions exposed OpaqueOutputData execute_task(const OpaqueInputData &inputs) { - auto wfn = node_level_work_function_registry->getWorkFunctionPointer( + auto wfn = _dfr_node_level_work_function_registry->getWorkFunctionPointer( inputs.wfn_name); std::vector outputs; + if (!_dfr_is_root_node()) { + for (size_t p = 0; p < inputs.params.size(); ++p) { + if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) { + mlir::concretelang::RuntimeContext *ctx = + (*(mlir::concretelang::RuntimeContext **)inputs.params[p]); + ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk); + ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk); + break; + } + } + } + _dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(), + inputs.output_sizes.size()); + hpx::cout << std::flush; switch (inputs.output_sizes.size()) { case 1: { void *output = (void *)(new char[inputs.output_sizes[0]]); @@ -449,12 +612,8 @@ struct GenericComputeServer : component_base { "Error: number of task outputs not supported."); } - if (inputs.alloc_p) - for (auto p : inputs.params) - delete ((char *)p); - return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), - inputs.alloc_p); + std::move(inputs.output_types), inputs.alloc_p); } HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task); diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index d83dce418..3ffb43816 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -7,129 +7,206 @@ #define CONCRETELANG_DFR_KEY_MANAGER_HPP #include +#include #include #include #include +#include #include "concretelang/Runtime/DFRuntime.hpp" -struct PbsKeyManager; -extern PbsKeyManager *node_level_key_manager; +extern "C" { +#include "concrete-ffi.h" +} -struct PbsKeyWrapper { - std::shared_ptr key; - size_t key_id; - size_t size; +extern std::list new_allocated; - PbsKeyWrapper() {} +template struct KeyManager; +extern KeyManager *_dfr_node_level_bsk_manager; +extern KeyManager *_dfr_node_level_ksk_manager; +void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id); +void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id); - PbsKeyWrapper(void *key, size_t key_id, size_t size) - : key(std::make_shared(key)), key_id(key_id), size(size) {} - - PbsKeyWrapper(std::shared_ptr key, size_t key_id, size_t size) - : key(key), key_id(key_id), size(size) {} - - PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept - : key(moved.key), key_id(moved.key_id), size(moved.size) {} - - PbsKeyWrapper(const PbsKeyWrapper &pbsk) - : key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {} +template struct KeyWrapper { + LweKeyType *key; + KeyWrapper() : key(nullptr) {} + KeyWrapper(LweKeyType *key) : key(key) {} + KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {} + KeyWrapper(const KeyWrapper &kw) : key(kw.key) {} friend class hpx::serialization::access; template - void save(Archive &ar, const unsigned int version) const { - char *_key_ = static_cast(*key); - ar &key_id &size; - for (size_t i = 0; i < size; ++i) - ar &_key_[i]; - } - - template void load(Archive &ar, const unsigned int version) { - ar &key_id &size; - char *_key_ = (char *)malloc(size); - for (size_t i = 0; i < size; ++i) - ar &_key_[i]; - key = std::make_shared(_key_); - } + void save(Archive &ar, const unsigned int version) const; + template void load(Archive &ar, const unsigned int version); HPX_SERIALIZATION_SPLIT_MEMBER() }; -inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) { - return lhs.key_id == rhs.key_id; +template <> +template +void KeyWrapper::save(Archive &ar, + const unsigned int version) const { + Buffer buffer = serialize_lwe_bootstrap_key_u64(key); + ar << buffer.length; + ar << hpx::serialization::make_array(buffer.pointer, buffer.length); +} +template <> +template +void KeyWrapper::load(Archive &ar, + const unsigned int version) { + size_t length; + ar >> length; + uint8_t *pointer = new uint8_t[length]; + new_allocated.push_back((void *)pointer); + ar >> hpx::serialization::make_array(pointer, length); + BufferView buffer = {(const uint8_t *)pointer, length}; + key = deserialize_lwe_bootstrap_key_u64(buffer); } -PbsKeyWrapper _dfr_fetch_key(size_t); -HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action) +template <> +template +void KeyWrapper::save(Archive &ar, + const unsigned int version) const { + Buffer buffer = serialize_lwe_keyswitching_key_u64(key); + ar << buffer.length; + ar << hpx::serialization::make_array(buffer.pointer, buffer.length); +} +template <> +template +void KeyWrapper::load(Archive &ar, + const unsigned int version) { + size_t length; + ar >> length; + uint8_t *pointer = new uint8_t[length]; + new_allocated.push_back((void *)pointer); + ar >> hpx::serialization::make_array(pointer, length); + BufferView buffer = {(const uint8_t *)pointer, length}; + key = deserialize_lwe_keyswitching_key_u64(buffer); +} -struct PbsKeyManager { - // The initial keys registered on the root node and whether to push - // them is TBD. +KeyWrapper _dfr_fetch_ksk(uint64_t); +HPX_PLAIN_ACTION(_dfr_fetch_ksk, _dfr_fetch_ksk_action) +KeyWrapper _dfr_fetch_bsk(uint64_t); +HPX_PLAIN_ACTION(_dfr_fetch_bsk, _dfr_fetch_bsk_action) - PbsKeyManager() { node_level_key_manager = this; } +template struct KeyManager { + KeyManager() {} + LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id); - PbsKeyWrapper get_key(const size_t key_id) { - keystore_guard.lock(); - auto keyit = keystore.find(key_id); - keystore_guard.unlock(); - - if (keyit == keystore.end()) { - _dfr_fetch_key_action fet; - PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id); - if (pkw.size == 0) { - // Maybe retry or try other nodes... but for now it's an error. - HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key", - "Error: key not found on remote node."); - } else { - std::lock_guard guard(keystore_guard); - keyit = keystore.insert(std::pair(key_id, pkw)) - .first; - } - } - return keyit->second; - } - - // To be used only for remote requests - PbsKeyWrapper fetch_key(const size_t key_id) { + KeyWrapper fetch_key(const uint64_t key_id) { std::lock_guard guard(keystore_guard); auto keyit = keystore.find(key_id); if (keyit != keystore.end()) return keyit->second; - // If this node does not contain this key, return an empty wrapper - return PbsKeyWrapper(nullptr, 0, 0); + // If this node does not contain this key, this is an error + // (location was supplied as source for this key). + HPX_THROW_EXCEPTION( + hpx::no_success, "fetch_key", + "Error: could not find key to be fetched on source location."); } - void register_key(void *key, size_t key_id, size_t size) { + void register_key(LweKeyType *key, uint64_t key_id) { std::lock_guard guard(keystore_guard); - auto keyit = keystore - .insert(std::pair( - key_id, PbsKeyWrapper(key, key_id, size))) - .first; + auto keyit = keystore.find(key_id); if (keyit == keystore.end()) { - HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key", - "Error: could not register new key."); + keyit = keystore + .insert(std::pair>( + key_id, KeyWrapper(key))) + .first; + if (keyit == keystore.end()) { + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key", + "Error: could not register new key."); + } } } - void broadcast_keys() { + void clear_keys() { std::lock_guard guard(keystore_guard); - if (_dfr_is_root_node()) - hpx::collectives::broadcast_to("keystore", this->keystore).get(); - else - keystore = std::move( - hpx::collectives::broadcast_from>( - "keystore") - .get()); + keystore.clear(); } private: std::mutex keystore_guard; - std::map keystore; + std::map> keystore; }; -PbsKeyWrapper _dfr_fetch_key(size_t key_id) { - return node_level_key_manager->fetch_key(key_id); +template <> KeyManager::KeyManager() { + _dfr_node_level_bsk_manager = this; +} + +template <> +LweBootstrapKey_u64 * +KeyManager::get_key(hpx::naming::id_type loc, + const uint64_t key_id) { + keystore_guard.lock(); + auto keyit = keystore.find(key_id); + keystore_guard.unlock(); + + if (keyit == keystore.end()) { + _dfr_fetch_bsk_action fetch; + KeyWrapper &&bskw = fetch(loc, key_id); + if (bskw.key == nullptr) { + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key", + "Error: Bootstrap key not found on root node."); + } else { + _dfr_register_bsk(bskw.key, key_id); + } + return bskw.key; + } + return keyit->second.key; +} + +template <> KeyManager::KeyManager() { + _dfr_node_level_ksk_manager = this; +} + +template <> +LweKeyswitchKey_u64 * +KeyManager::get_key(hpx::naming::id_type loc, + const uint64_t key_id) { + keystore_guard.lock(); + auto keyit = keystore.find(key_id); + keystore_guard.unlock(); + + if (keyit == keystore.end()) { + _dfr_fetch_ksk_action fetch; + KeyWrapper &&kskw = fetch(loc, key_id); + if (kskw.key == nullptr) { + HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key", + "Error: Keyswitching key not found on root node."); + } else { + _dfr_register_ksk(kskw.key, key_id); + } + return kskw.key; + } + return keyit->second.key; +} + +KeyWrapper _dfr_fetch_bsk(uint64_t key_id) { + return _dfr_node_level_bsk_manager->fetch_key(key_id); +} + +KeyWrapper _dfr_fetch_ksk(uint64_t key_id) { + return _dfr_node_level_ksk_manager->fetch_key(key_id); +} + +/************************/ +/* Key management API. */ +/************************/ + +void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id) { + _dfr_node_level_bsk_manager->register_key(key, key_id); +} +void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id) { + _dfr_node_level_ksk_manager->register_key(key, key_id); +} + +LweBootstrapKey_u64 *_dfr_get_bsk(hpx::naming::id_type loc, uint64_t key_id) { + return _dfr_node_level_bsk_manager->get_key(loc, key_id); +} +LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) { + return _dfr_node_level_ksk_manager->get_key(loc, key_id); } #endif diff --git a/compiler/include/concretelang/Runtime/runtime_api.h b/compiler/include/concretelang/Runtime/runtime_api.h index b3f05f6e0..4025f9aa4 100644 --- a/compiler/include/concretelang/Runtime/runtime_api.h +++ b/compiler/include/concretelang/Runtime/runtime_api.h @@ -16,14 +16,9 @@ typedef void (*wfnptr)(...); void *_dfr_make_ready_future(void *); void _dfr_create_async_task(wfnptr, size_t, size_t, ...); +void _dfr_register_work_function(wfnptr); void *_dfr_await_future(void *); -/* Keys can have node-local copies which can be retrieved. This - should only be called on the node where the key is required. */ -void _dfr_register_key(void *, size_t, size_t); -void _dfr_broadcast_keys(); -void *_dfr_get_key(size_t); - /* Memory management: _dfr_make_ready_future allocates the future, not the underlying storage. _dfr_create_async_task allocates both future and storage for outputs. */ @@ -36,4 +31,5 @@ void _dfr_stop(); void _dfr_terminate(); } + #endif diff --git a/compiler/include/concretelang/Runtime/workfunction_registry.hpp b/compiler/include/concretelang/Runtime/workfunction_registry.hpp new file mode 100644 index 000000000..2435b8e5e --- /dev/null +++ b/compiler/include/concretelang/Runtime/workfunction_registry.hpp @@ -0,0 +1,92 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt +// for license information. + +#ifndef CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP +#define CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP + +#include +#include +#include + +#include +#include +#include + +#include "concretelang/Runtime/DFRuntime.hpp" + +extern void *dl_handle; +struct WorkFunctionRegistry; +extern WorkFunctionRegistry *_dfr_node_level_work_function_registry; + +struct WorkFunctionRegistry { + WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; } + + wfnptr getWorkFunctionPointer(const std::string &name) { + std::lock_guard guard(registry_guard); + + auto fnptrit = name_to_ptr_registry.find(name); + if (fnptrit != name_to_ptr_registry.end()) + return (wfnptr)fnptrit->second; + + auto ptr = dlsym(dl_handle, name.c_str()); + if (ptr == nullptr) { + HPX_THROW_EXCEPTION(hpx::no_success, + "WorkFunctionRegistry::getWorkFunctionPointer", + "Error recovering work function pointer from name."); + } + registerWorkFunction(ptr, name); + return (wfnptr)ptr; + } + + std::string getWorkFunctionName(const void *fn) { + std::lock_guard guard(registry_guard); + + auto fnnameit = ptr_to_name_registry.find(fn); + if (fnnameit != ptr_to_name_registry.end()) + return fnnameit->second; + + Dl_info info; + std::string ret; + // Assume that if we can't find the name, there is no dynamic + // library to find it in. TODO: fix this to distinguish JIT/binary + // and in case of distributed exec. + if (!dladdr(fn, &info) || info.dli_sname == nullptr) { + ret = registerAnonymousWorkFunction(fn); + } else { + ret = info.dli_sname; + registerWorkFunction(fn, ret); + } + return ret; + } + + void registerWorkFunction(const void *fn, std::string name) { + std::lock_guard guard(registry_guard); + + auto fnnameit = ptr_to_name_registry.find(fn); + if (fnnameit == ptr_to_name_registry.end()) + ptr_to_name_registry.insert( + std::pair(fn, name)); + + auto fnptrit = name_to_ptr_registry.find(name); + if (fnptrit == name_to_ptr_registry.end()) + name_to_ptr_registry.insert( + std::pair(name, fn)); + } + + std::string registerAnonymousWorkFunction(const void *fn) { + std::lock_guard guard(registry_guard); + static std::atomic fnid{0}; + std::string name = "_dfr_jit_wfnname_" + std::to_string(fnid++); + registerWorkFunction(fn, name); + return name; + } + +private: + std::recursive_mutex registry_guard; + std::map ptr_to_name_registry; + std::map name_to_ptr_registry; +}; + +#endif diff --git a/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt b/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt index d2611ac1b..57beefa9f 100644 --- a/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt +++ b/compiler/lib/Dialect/RT/Analysis/CMakeLists.txt @@ -14,5 +14,5 @@ add_mlir_library(RTDialectAnalysis LINK_LIBS PUBLIC MLIRIR RTDialect + ConcretelangRuntime ) - diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index 0ab7e1c8d..fb78a01c2 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include @@ -132,7 +133,8 @@ static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement, } // TODO: Fix type sizes. For now we're using some default values. -static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) { +static std::pair +getSizeInBytes(Value val, Location loc, OpBuilder builder) { DataLayout dataLayout = DataLayout::closest(val.getDefiningOp()); Type type = (val.getType().isa()) ? val.getType().dyn_cast().getElementType() @@ -153,26 +155,61 @@ static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) { Value rank = builder.create( loc, builder.getI64IntegerAttr(_rank)); Value sizes_shapes = builder.create(loc, rank, multiplier); - Value result = builder.create(loc, ptrs_offset, sizes_shapes); - return result; + Value typeSize = + builder.create(loc, ptrs_offset, sizes_shapes); + + Type elementType = type.dyn_cast().getElementType(); + // Assume here that the base type is a simple scalar-type or at + // least its size can be determined. + // size_t elementAttr = dataLayout.getTypeSize(elementType); + // Make room for a byte to store the type of this argument/output + // elementAttr <<= 8; + // elementAttr |= _DFR_TASK_ARG_MEMREF; + uint64_t elementAttr = 0; + size_t element_size = dataLayout.getTypeSize(elementType); + elementAttr = _dfr_set_arg_type(elementAttr, _DFR_TASK_ARG_MEMREF); + elementAttr = _dfr_set_memref_element_size(elementAttr, element_size); + Value arg_type = builder.create( + loc, builder.getI64IntegerAttr(elementAttr)); + return std::pair(typeSize, arg_type); } // Unranked memrefs should be lowered to just pointer + size, so we need 16 // bytes. - if (type.isa()) - return builder.create(loc, - builder.getI64IntegerAttr(16)); + if (type.isa()) { + Value arg_type = builder.create( + loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_UNRANKED_MEMREF)); + Value result = + builder.create(loc, builder.getI64IntegerAttr(16)); + return std::pair(result, arg_type); + } + + Value arg_type = builder.create( + loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_BASE)); // FHE types are converted to pointers, so we take their size as 8 // bytes until we can get the actual size of the actual types. - if (type.isa() || - type.isa() || - type.isa()) - return builder.create(loc, builder.getI64IntegerAttr(8)); + if (type.isa() || + type.isa() || + type.isa() || + type.isa() || + type.isa() || + type.isa()) { + Value result = + builder.create(loc, builder.getI64IntegerAttr(8)); + return std::pair(result, arg_type); + } else if (type.isa()) { + Value arg_type = builder.create( + loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_CONTEXT)); + Value result = + builder.create(loc, builder.getI64IntegerAttr(8)); + return std::pair(result, arg_type); + } // For all other types, get type size. - return builder.create( + Value result = builder.create( loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type))); + return std::pair(result, arg_type); } static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, @@ -189,9 +226,22 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, builder.setInsertionPoint(DFTOp); for (Value val : DFTOp.getOperands()) { if (!val.getType().isa()) { - Type futType = RT::FutureType::get(val.getType()); - auto mrf = - builder.create(DFTOp.getLoc(), futType, val); + Value newval; + // If we are building a future on a MemRef, then we need to clone + // the memref in order to allow the deallocation pass which does + // not synchronize with task execution. + if (val.getType().isa() && !val.isa()) { + newval = builder + .create( + DFTOp.getLoc(), val.getType().cast()) + .getResult(); + builder.create(DFTOp.getLoc(), val, newval); + } else { + newval = val; + } + Type futType = RT::FutureType::get(newval.getType()); + auto mrf = builder.create(DFTOp.getLoc(), futType, + newval); map.map(mrf->getResult(0), val); replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody); } @@ -201,7 +251,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, // DataflowTaskOp. This also includes the necessary handling of // operands and results (conversion to/from futures and propagation). SmallVector catOperands; - int size = 3 + DFTOp.getNumResults() * 2 + DFTOp.getNumOperands() * 2; + int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3; catOperands.reserve(size); auto fnptr = builder.create( DFTOp.getLoc(), workFunction.getFunctionType(), @@ -214,8 +264,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, catOperands.push_back(numIns.getResult()); catOperands.push_back(numOuts.getResult()); for (auto operand : DFTOp.getOperands()) { + auto op_size = getSizeInBytes(operand, DFTOp.getLoc(), builder); catOperands.push_back(operand); - catOperands.push_back(getSizeInBytes(operand, DFTOp.getLoc(), builder)); + catOperands.push_back(op_size.first); + catOperands.push_back(op_size.second); } // We need to adjust the results for the CreateAsyncTaskOp which @@ -228,9 +280,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, Type futType = RT::PointerType::get(RT::FutureType::get(result.getType())); auto brpp = builder.create(DFTOp.getLoc(), futType); + auto op_size = getSizeInBytes(result, DFTOp.getLoc(), builder); map.map(result, brpp->getResult(0)); catOperands.push_back(brpp->getResult(0)); - catOperands.push_back(getSizeInBytes(result, DFTOp.getLoc(), builder)); + catOperands.push_back(op_size.first); + catOperands.push_back(op_size.second); } builder.create( DFTOp.getLoc(), @@ -272,6 +326,18 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, DFTOp.erase(); } +static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) { + OpBuilder builder(parentFunc.body()); + builder.setInsertionPointToStart(&parentFunc.body().front()); + + auto fnptr = builder.create( + parentFunc.getLoc(), workFunction.getType(), + SymbolRefAttr::get(builder.getContext(), workFunction.getName())); + + builder.create(parentFunc.getLoc(), + fnptr.getResult()); +} + /// For documentation see Autopar.td struct LowerDataflowTasksPass : public LowerDataflowTasksBase { @@ -306,6 +372,33 @@ struct LowerDataflowTasksPass for (auto mapping : outliningMap) lowerDataflowTaskOp(mapping.first, mapping.second); + // If this is a JIT invocation and we're not on the root node, + // we do not need to do any computation, only register all work + // functions with the runtime system + if (!outliningMap.empty()) { + if (!_dfr_is_root_node()) { + // auto regFunc = builder.create(func.getLoc(), + // func.getName(), func.getType()); + + func.eraseBody(); + Block *b = new Block; + b->addArguments(func.getType().getInputs()); + func.body().push_front(b); + for (int i = func.getType().getNumInputs() - 1; i >= 0; --i) + func.eraseArgument(i); + for (int i = func.getType().getNumResults() - 1; i >= 0; --i) + func.eraseResult(i); + OpBuilder builder(func.body()); + builder.setInsertionPointToEnd(&func.body().front()); + builder.create(func.getLoc()); + } + } + + // Generate code to register all work-functions with the + // runtime. + for (auto mapping : outliningMap) + registerWorkFunction(func, mapping.second); + // Issue _dfr_start/stop calls for this function if (!outliningMap.empty()) { OpBuilder builder(func.getBody()); diff --git a/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp index 15344771d..816678de6 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp @@ -171,6 +171,26 @@ struct CreateAsyncTaskOpInterfaceLowering return success(); } }; +struct RegisterTaskWorkFunctionOpInterfaceLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + RT::RegisterTaskWorkFunctionOp>::ConvertOpToLLVMPattern; + + mlir::LogicalResult + matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp, + ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + RT::RegisterTaskWorkFunctionOp::Adaptor transformed(operands); + + auto rtwfFuncType = + LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); + auto rtwfFuncOp = getOrInsertFuncOpDecl( + rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter); + rewriter.replaceOpWithNewOp(rtwfOp, rtwfFuncOp, + transformed.getOperands()); + return success(); + } +}; struct DeallocateFutureOpInterfaceLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -296,6 +316,7 @@ void mlir::concretelang::populateRTToLLVMConversionPatterns( DerefReturnPtrPlaceholderOpInterfaceLowering, DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering, CreateAsyncTaskOpInterfaceLowering, + RegisterTaskWorkFunctionOpInterfaceLowering, DeallocateFutureOpInterfaceLowering, DeallocateFutureDataOpInterfaceLowering, WorkFunctionReturnOpInterfaceLowering>(converter); diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index a3247bda4..d1ee12cdf 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -11,6 +11,7 @@ #ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED +#include #include #include #include @@ -22,11 +23,16 @@ std::vector gcc; void *dl_handle; -PbsKeyManager *node_level_key_manager; -WorkFunctionRegistry *node_level_work_function_registry; + +KeyManager *_dfr_node_level_bsk_manager; +KeyManager *_dfr_node_level_ksk_manager; + +WorkFunctionRegistry *_dfr_node_level_work_function_registry; std::list new_allocated; std::list fut_allocated; std::list m_allocated; +hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier; +hpx::lcos::barrier *_dfr_jit_phase_barrier; std::atomic init_guard = {0}; using namespace hpx; @@ -52,6 +58,17 @@ void _dfr_deallocate_future(void *in) { delete (static_cast *>(in)); } +// Determine where new task should run. For now just round-robin +// distribution - TODO: optimise. +static inline size_t _dfr_find_next_execution_locality() { + static size_t num_nodes = hpx::get_num_localities().get(); + static std::atomic next_locality{0}; + + size_t next_loc = ++next_locality; + + return next_loc % num_nodes; +} + /// Runtime generic async_task. Each first NUM_PARAMS pairs of /// arguments in the variadic list corresponds to a void* pointer on a /// hpx::future and the size of data within the future. After @@ -60,28 +77,39 @@ void _dfr_deallocate_future(void *in) { void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, ...) { std::vector params; - std::vector outputs; std::vector param_sizes; + std::vector param_types; + std::vector outputs; std::vector output_sizes; + std::vector output_types; va_list args; va_start(args, num_outputs); for (size_t i = 0; i < num_params; ++i) { params.push_back(va_arg(args, void *)); - param_sizes.push_back(va_arg(args, size_t)); + param_sizes.push_back(va_arg(args, uint64_t)); + param_types.push_back(va_arg(args, uint64_t)); } for (size_t i = 0; i < num_outputs; ++i) { outputs.push_back(va_arg(args, void *)); - output_sizes.push_back(va_arg(args, size_t)); + output_sizes.push_back(va_arg(args, uint64_t)); + output_types.push_back(va_arg(args, uint64_t)); } va_end(args); + for (size_t i = 0; i < num_params; ++i) { + if (_dfr_get_arg_type(param_types[i] == _DFR_TASK_ARG_MEMREF)) { + m_allocated.push_back( + (void *)static_cast *>(params[i])->data); + } + } + // We pass functions by name - which is not strictly necessary in // shared memory as pointers suffice, but is needed in the // distributed case where the functions need to be located/loaded on // the node. auto wfnname = - node_level_work_function_registry->getWorkFunctionName((void *)wfn); + _dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn); hpx::future> oodf; // In order to allow complete dataflow semantics for @@ -93,20 +121,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, switch (num_params) { case 0: oodf = std::move( - hpx::dataflow([wfnname, param_sizes, - output_sizes]() -> hpx::future { + hpx::dataflow([wfnname, param_sizes, param_types, output_sizes, + output_types]() -> hpx::future { std::vector params = {}; - OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + OpaqueInputData oid(wfnname, params, param_sizes, param_types, + output_sizes, output_types); return gcc[_dfr_find_next_execution_locality()].execute_task(oid); })); break; case 1: oodf = std::move(hpx::dataflow( - [wfnname, param_sizes, output_sizes](hpx::shared_future param0) + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0) -> hpx::future { std::vector params = {param0.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + OpaqueInputData oid(wfnname, params, param_sizes, param_types, + output_sizes, output_types); return gcc[_dfr_find_next_execution_locality()].execute_task(oid); }, *(hpx::shared_future *)params[0])); @@ -114,11 +145,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 2: oodf = std::move(hpx::dataflow( - [wfnname, param_sizes, output_sizes](hpx::shared_future param0, - hpx::shared_future param1) + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0, + hpx::shared_future param1) -> hpx::future { std::vector params = {param0.get(), param1.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + OpaqueInputData oid(wfnname, params, param_sizes, param_types, + output_sizes, output_types); return gcc[_dfr_find_next_execution_locality()].execute_task(oid); }, *(hpx::shared_future *)params[0], @@ -127,13 +160,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 3: oodf = std::move(hpx::dataflow( - [wfnname, param_sizes, output_sizes](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2) + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2) -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, output_sizes); + OpaqueInputData oid(wfnname, params, param_sizes, param_types, + output_sizes, output_types); return gcc[_dfr_find_next_execution_locality()].execute_task(oid); }, *(hpx::shared_future *)params[0], @@ -642,25 +677,44 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, } } +/***************************/ +/* JIT execution support. */ +/***************************/ +static inline bool _dfr_is_root_node_impl() { + static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality()); + return is_root_node_p; +} + +bool _dfr_is_root_node() { return _dfr_is_root_node_impl(); } + +void _dfr_register_work_function(wfnptr wfn) { + _dfr_node_level_work_function_registry->registerAnonymousWorkFunction( + (void *)wfn); +} + /********************************/ /* Distributed key management. */ /********************************/ -void _dfr_register_key(void *key, size_t key_id, size_t size) { - node_level_key_manager->register_key(key, key_id, size); -} - -void _dfr_broadcast_keys() { node_level_key_manager->broadcast_keys(); } - -void *_dfr_get_key(size_t key_id) { - return *node_level_key_manager->get_key(key_id).key.get(); -} /************************************/ /* Initialization & Finalization. */ /************************************/ -/* Runtime initialization and finalization. */ +// TODO: need to set a flag for when executing in JIT to allow remote +// nodes to execute generated function that registers work functions. +// This also means that compute nodes need to go through all the +// phases of computation synchronized with the root node. +static inline bool _dfr_is_jit_impl(bool is_jit = false) { + static bool is_jit_p = is_jit; + if (is_jit && !is_jit_p) + is_jit_p = true; + return is_jit_p; +} + +void _dfr_is_jit(bool is_jit) { _dfr_is_jit_impl(is_jit); } +bool _dfr_is_jit() { return _dfr_is_jit_impl(); } + static inline void _dfr_stop_impl() { - if (_dfr_is_root_node()) + if (_dfr_is_root_node_impl()) hpx::apply([]() { hpx::finalize(); }); hpx::stop(); } @@ -701,10 +755,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { } // Instantiate on each node - new PbsKeyManager(); + new KeyManager(); + new KeyManager(); new WorkFunctionRegistry(); + _dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier( + "wait_register_remote_work_functions", hpx::get_num_localities().get(), + hpx::get_locality_id()); + _dfr_jit_phase_barrier = new hpx::lcos::barrier( + "phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id()); - if (_dfr_is_root_node()) { + if (_dfr_is_root_node_impl()) { // Create compute server components on each node - from the root // node only - and the corresponding compute client on the root // node. @@ -712,9 +772,6 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { gcc = hpx::new_( hpx::default_layout(hpx::find_all_localities()), num_nodes) .get(); - } else { - hpx::stop(); - exit(EXIT_SUCCESS); } } @@ -722,16 +779,36 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { JIT invocation). These serve to pause/resume the runtime scheduler and to clean up used resources. */ void _dfr_start() { - uint64_t uninitialised = 0; - if (init_guard.compare_exchange_strong(uninitialised, 1)) - _dfr_start_impl(0, nullptr); - else - hpx::resume(); + hpx::resume(); + + if (!_dfr_is_jit()) + _dfr_stop_impl(); + + // TODO: conditional -- If this is the root node, and this is JIT + // execution, we need to wait for the compute nodes to compile and + // register work functions + if (_dfr_is_root_node_impl() && _dfr_is_jit()) { + _dfr_jit_workfunction_registration_barrier->wait(); + } } void _dfr_stop() { + if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) { + _dfr_jit_workfunction_registration_barrier->wait(); + } + + // The barrier is only needed to synchronize the different + // computation phases when the compute nodes need to generate and + // register new work functions in each phase. + if (_dfr_is_jit()) { + _dfr_jit_phase_barrier->wait(); + } + hpx::suspend(); + _dfr_node_level_bsk_manager->clear_keys(); + _dfr_node_level_ksk_manager->clear_keys(); + while (!new_allocated.empty()) { delete[] static_cast(new_allocated.front()); new_allocated.pop_front(); @@ -758,7 +835,7 @@ void _dfr_terminate() { /* Main wrapper. */ /*******************/ extern "C" { -extern int main(int argc, char *argv[]); // __attribute__((weak)); +extern int main(int argc, char *argv[]) __attribute__((weak)); extern int __real_main(int argc, char *argv[]) __attribute__((weak)); int __wrap_main(int argc, char *argv[]) { int r; @@ -790,7 +867,7 @@ size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); } size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); } -void _dfr_debug_print_task(const char *name, int inputs, int outputs) { +void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) { // clang-format off hpx::cout << "Task \"" << name << "\"" << " [" << inputs << " inputs, " << outputs << " outputs]" @@ -806,7 +883,10 @@ void _dfr_print_debug(size_t val) { #else // CONCRETELANG_PARALLEL_EXECUTION_ENABLED -#include +#include "concretelang/Runtime/DFRuntime.hpp" +bool _dfr_is_root_node() { return true; } +void _dfr_is_jit(bool) {} void _dfr_terminate() {} + #endif diff --git a/compiler/lib/Support/JITSupport.cpp b/compiler/lib/Support/JITSupport.cpp index f1f279427..9ab819917 100644 --- a/compiler/lib/Support/JITSupport.cpp +++ b/compiler/lib/Support/JITSupport.cpp @@ -3,6 +3,7 @@ // https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt // for license information. +#include #include #include #include @@ -48,8 +49,11 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) { // Mark the lambda as compiled using DF parallelization result->lambda->setUseDataflow(options.dataflowParallelize || options.autoParallelize); - result->clientParameters = - compilationResult.get().clientParameters.getValue(); + if (!mlir::concretelang::dfr::_dfr_is_root_node()) + result->clientParameters = clientlib::ClientParameters(); + else + result->clientParameters = + compilationResult.get().clientParameters.getValue(); return std::move(result); } diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 254de52c1..41bbffe29 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -13,6 +13,7 @@ #include #include "concretelang/Common/BitsSize.h" +#include #include #include #include @@ -89,6 +90,19 @@ JITLambda::call(clientlib::PublicArguments &args, "call: current runtime doesn't support dataflow execution, while " "compilation used dataflow parallelization"); } +#else + dfr::_dfr_set_jit(true); + // When using JIT on distributed systems, the compiler only + // generates work-functions and their registration calls. No results + // are returned and no inputs are needed. + if (!dfr::_dfr_is_root_node()) { + std::vector rawArgs; + if (auto err = invokeRaw(rawArgs)) { + return std::move(err); + } + std::vector buffers; + return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers); + } #endif // invokeRaw needs to have pointers on arguments and a pointers on the result // as last argument. diff --git a/compiler/src/main.cpp b/compiler/src/main.cpp index 5c8a3ebfa..4d41d0d15 100644 --- a/compiler/src/main.cpp +++ b/compiler/src/main.cpp @@ -31,7 +31,7 @@ #include "concretelang/Dialect/RT/IR/RTDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" -#include "concretelang/Runtime/runtime_api.h" +#include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/Error.h" #include "concretelang/Support/JITSupport.h"