diff --git a/compiler/include/concretelang/Runtime/DFRuntime.hpp b/compiler/include/concretelang/Runtime/DFRuntime.hpp index 0ed66ad3f..9f97c2523 100644 --- a/compiler/include/concretelang/Runtime/DFRuntime.hpp +++ b/compiler/include/concretelang/Runtime/DFRuntime.hpp @@ -14,9 +14,10 @@ #include "concretelang/Runtime/runtime_api.h" -bool _dfr_is_root_node(); -void _dfr_is_jit(bool); +bool _dfr_set_required(bool); +void _dfr_set_jit(bool); bool _dfr_is_jit(); +bool _dfr_is_root_node(); void _dfr_terminate(); typedef enum _dfr_task_arg_type { diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index d1ee12cdf..3cedf77d5 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -680,8 +680,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, /***************************/ /* JIT execution support. */ /***************************/ +void _dfr_try_initialize(); +namespace { +static bool dfr_required_p = false; +static bool is_jit_p = false; +} // namespace +bool _dfr_set_required(bool is_required) { + dfr_required_p = is_required; + if (dfr_required_p) + _dfr_try_initialize(); + return true; +} +void _dfr_set_jit(bool is_jit) { is_jit_p = is_jit; } +bool _dfr_is_jit() { return is_jit_p; } + static inline bool _dfr_is_root_node_impl() { - static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality()); + static bool is_root_node_p = + (!dfr_required_p || (hpx::find_here() == hpx::find_root_locality())); return is_root_node_p; } @@ -692,31 +707,15 @@ void _dfr_register_work_function(wfnptr wfn) { (void *)wfn); } -/********************************/ -/* Distributed key management. */ -/********************************/ - /************************************/ /* Initialization & Finalization. */ /************************************/ -// TODO: need to set a flag for when executing in JIT to allow remote -// nodes to execute generated function that registers work functions. -// This also means that compute nodes need to go through all the -// phases of computation synchronized with the root node. -static inline bool _dfr_is_jit_impl(bool is_jit = false) { - static bool is_jit_p = is_jit; - if (is_jit && !is_jit_p) - is_jit_p = true; - return is_jit_p; -} - -void _dfr_is_jit(bool is_jit) { _dfr_is_jit_impl(is_jit); } -bool _dfr_is_jit() { return _dfr_is_jit_impl(); } - static inline void _dfr_stop_impl() { - if (_dfr_is_root_node_impl()) + if (_dfr_is_root_node()) hpx::apply([]() { hpx::finalize(); }); hpx::stop(); + if (!_dfr_is_root_node()) + exit(EXIT_SUCCESS); } static inline void _dfr_start_impl(int argc, char *argv[]) { @@ -744,6 +743,10 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { std::string numOMPThreads = std::to_string(nOMPThreads); setenv("OMP_NUM_THREADS", numOMPThreads.c_str(), 0); nHPXThreads = nCores + 1 - nOMPThreads; + env = getenv("DFR_NUM_THREADS"); + if (env != nullptr) + nHPXThreads = strtoul(env, NULL, 10); + nHPXThreads = (nHPXThreads) ? nHPXThreads : 1; std::string numHPXThreads = std::to_string(nHPXThreads); char *_argv[3] = {const_cast("__dummy_dfr_HPX_program_name__"), const_cast("--hpx:threads"), @@ -764,7 +767,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { _dfr_jit_phase_barrier = new hpx::lcos::barrier( "phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id()); - if (_dfr_is_root_node_impl()) { + if (_dfr_is_root_node()) { // Create compute server components on each node - from the root // node only - and the corresponding compute client on the root // node. @@ -779,33 +782,68 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { JIT invocation). These serve to pause/resume the runtime scheduler and to clean up used resources. */ void _dfr_start() { - hpx::resume(); + // The first invocation will initialise the runtime. As each call to + // _dfr_start is matched with _dfr_stop, if this is not hte first, + // we need to resume the HPX runtime. + uint64_t uninitialised = 0; + uint64_t active = 1; + uint64_t suspended = 2; + if (init_guard.compare_exchange_strong(uninitialised, active)) + _dfr_start_impl(0, nullptr); + else if (init_guard.compare_exchange_strong(suspended, active)) + hpx::resume(); - if (!_dfr_is_jit()) + // If this is not the root node in a non-JIT execution, then this + // node should only run the scheduler for any incoming work until + // termination is flagged. If this is JIT, we need to run the + // cancelled function which registers the work functions. + if (!_dfr_is_root_node() && !_dfr_is_jit()) _dfr_stop_impl(); // TODO: conditional -- If this is the root node, and this is JIT // execution, we need to wait for the compute nodes to compile and // register work functions - if (_dfr_is_root_node_impl() && _dfr_is_jit()) { + if (_dfr_is_root_node() && _dfr_is_jit()) { _dfr_jit_workfunction_registration_barrier->wait(); } } +// This function cannot be used to terminate the runtime as it is +// non-decidable if another computation phase will follow. Instead the +// _dfr_terminate function provides this facility and is normally +// called on exit from "main" when not using the main wrapper library. void _dfr_stop() { - if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) { + // Non-root nodes synchronize here with the root to mark the point + // where the root is free to send work out. + // TODO: optimize this by moving synchro to local remote nodes + // waiting in the scheduler for registration. + if (!_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) { _dfr_jit_workfunction_registration_barrier->wait(); } // The barrier is only needed to synchronize the different // computation phases when the compute nodes need to generate and // register new work functions in each phase. + + // TODO: this barrier may be removed based on how work function + // registration is handled - but it is unlikely to result in much + // gain as the root node would be waiting for the end of computation + // on all remote nodes before reaching here anyway (dataflow + // dependences). if (_dfr_is_jit()) { _dfr_jit_phase_barrier->wait(); } - hpx::suspend(); + // TODO: this can be removed along with the matching hpx::resume if + // their overhead is larger than the benefit of pausing worker + // threads outside of parallel regions - to be tested. + uint64_t active = 1; + uint64_t suspended = 2; + if (init_guard.compare_exchange_strong(active, suspended)) + hpx::suspend(); + // TODO: until we have better unique identifiers for keys it is + // safer to drop them in-between phases. _dfr_node_level_bsk_manager->clear_keys(); _dfr_node_level_ksk_manager->clear_keys(); @@ -823,14 +861,26 @@ void _dfr_stop() { } } -void _dfr_terminate() { - uint64_t initialised = 1; - if (init_guard.compare_exchange_strong(initialised, 2)) { - hpx::resume(); - _dfr_stop_impl(); +void _dfr_try_initialize() { + // Initialize and immediately suspend the HPX runtime if not yet done. + uint64_t uninitialised = 0; + uint64_t suspended = 2; + if (init_guard.compare_exchange_strong(uninitialised, suspended)) { + _dfr_start_impl(0, nullptr); + hpx::suspend(); } } +void _dfr_terminate() { + uint64_t active = 1; + uint64_t suspended = 2; + uint64_t terminated = 3; + if (init_guard.compare_exchange_strong(suspended, active)) + hpx::resume(); + if (init_guard.compare_exchange_strong(active, terminated)) + _dfr_stop_impl(); +} + /*******************/ /* Main wrapper. */ /*******************/ @@ -839,22 +889,12 @@ extern int main(int argc, char *argv[]) __attribute__((weak)); extern int __real_main(int argc, char *argv[]) __attribute__((weak)); int __wrap_main(int argc, char *argv[]) { int r; - // Initialize and immediately suspend the HPX runtime if not yet done. - uint64_t uninitialised = 0; - if (init_guard.compare_exchange_strong(uninitialised, 1)) { - _dfr_start_impl(0, nullptr); - hpx::suspend(); - } + + _dfr_try_initialize(); // Run the actual main function. Within there should be a call to // _dfr_start to resume execution of the HPX scheduler if needed. r = __real_main(argc, argv); - // By default all _dfr_start should be matched to a _dfr_stop, so we - // need to resume before being able to finalize. - uint64_t initialised = 1; - if (init_guard.compare_exchange_strong(initialised, 2)) { - hpx::resume(); - _dfr_stop_impl(); - } + _dfr_terminate(); return r; } @@ -869,7 +909,7 @@ size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); } void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) { // clang-format off - hpx::cout << "Task \"" << name << "\"" + hpx::cout << "Task \"" << name << "\t\"" << " [" << inputs << " inputs, " << outputs << " outputs]" << " Executing on Node/Worker: " << _dfr_debug_get_node_id() << " / " << _dfr_debug_get_worker_id() << "\n" << std::flush; @@ -885,8 +925,8 @@ void _dfr_print_debug(size_t val) { #include "concretelang/Runtime/DFRuntime.hpp" +bool _dfr_set_required(bool is_required) { return !is_required; } +void _dfr_set_jit(bool) {} bool _dfr_is_root_node() { return true; } -void _dfr_is_jit(bool) {} void _dfr_terminate() {} - #endif diff --git a/compiler/lib/Support/CompilerEngine.cpp b/compiler/lib/Support/CompilerEngine.cpp index 27e6071cb..67a5d1b92 100644 --- a/compiler/lib/Support/CompilerEngine.cpp +++ b/compiler/lib/Support/CompilerEngine.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include diff --git a/compiler/lib/Support/Jit.cpp b/compiler/lib/Support/Jit.cpp index 41bbffe29..296b6e1d7 100644 --- a/compiler/lib/Support/Jit.cpp +++ b/compiler/lib/Support/Jit.cpp @@ -104,6 +104,7 @@ JITLambda::call(clientlib::PublicArguments &args, return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers); } #endif + // invokeRaw needs to have pointers on arguments and a pointers on the result // as last argument. // Prepare the outputs vector to store the output value of the lambda.