feat(runtime): deactivate main wrapping by default and add explicit initialization/termination.

This commit is contained in:
Antoniu Pop
2022-03-17 14:56:51 +00:00
committed by Antoniu Pop
parent bca85ea2b6
commit 460fbabbe0
4 changed files with 92 additions and 49 deletions

View File

@@ -14,9 +14,10 @@
#include "concretelang/Runtime/runtime_api.h"
bool _dfr_is_root_node();
void _dfr_is_jit(bool);
bool _dfr_set_required(bool);
void _dfr_set_jit(bool);
bool _dfr_is_jit();
bool _dfr_is_root_node();
void _dfr_terminate();
typedef enum _dfr_task_arg_type {

View File

@@ -680,8 +680,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
/***************************/
/* JIT execution support. */
/***************************/
void _dfr_try_initialize();
namespace {
static bool dfr_required_p = false;
static bool is_jit_p = false;
} // namespace
bool _dfr_set_required(bool is_required) {
dfr_required_p = is_required;
if (dfr_required_p)
_dfr_try_initialize();
return true;
}
void _dfr_set_jit(bool is_jit) { is_jit_p = is_jit; }
bool _dfr_is_jit() { return is_jit_p; }
static inline bool _dfr_is_root_node_impl() {
static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality());
static bool is_root_node_p =
(!dfr_required_p || (hpx::find_here() == hpx::find_root_locality()));
return is_root_node_p;
}
@@ -692,31 +707,15 @@ void _dfr_register_work_function(wfnptr wfn) {
(void *)wfn);
}
/********************************/
/* Distributed key management. */
/********************************/
/************************************/
/* Initialization & Finalization. */
/************************************/
// TODO: need to set a flag for when executing in JIT to allow remote
// nodes to execute generated function that registers work functions.
// This also means that compute nodes need to go through all the
// phases of computation synchronized with the root node.
static inline bool _dfr_is_jit_impl(bool is_jit = false) {
static bool is_jit_p = is_jit;
if (is_jit && !is_jit_p)
is_jit_p = true;
return is_jit_p;
}
void _dfr_is_jit(bool is_jit) { _dfr_is_jit_impl(is_jit); }
bool _dfr_is_jit() { return _dfr_is_jit_impl(); }
static inline void _dfr_stop_impl() {
if (_dfr_is_root_node_impl())
if (_dfr_is_root_node())
hpx::apply([]() { hpx::finalize(); });
hpx::stop();
if (!_dfr_is_root_node())
exit(EXIT_SUCCESS);
}
static inline void _dfr_start_impl(int argc, char *argv[]) {
@@ -744,6 +743,10 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
std::string numOMPThreads = std::to_string(nOMPThreads);
setenv("OMP_NUM_THREADS", numOMPThreads.c_str(), 0);
nHPXThreads = nCores + 1 - nOMPThreads;
env = getenv("DFR_NUM_THREADS");
if (env != nullptr)
nHPXThreads = strtoul(env, NULL, 10);
nHPXThreads = (nHPXThreads) ? nHPXThreads : 1;
std::string numHPXThreads = std::to_string(nHPXThreads);
char *_argv[3] = {const_cast<char *>("__dummy_dfr_HPX_program_name__"),
const_cast<char *>("--hpx:threads"),
@@ -764,7 +767,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
_dfr_jit_phase_barrier = new hpx::lcos::barrier(
"phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id());
if (_dfr_is_root_node_impl()) {
if (_dfr_is_root_node()) {
// Create compute server components on each node - from the root
// node only - and the corresponding compute client on the root
// node.
@@ -779,33 +782,68 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
JIT invocation). These serve to pause/resume the runtime
scheduler and to clean up used resources. */
void _dfr_start() {
hpx::resume();
// The first invocation will initialise the runtime. As each call to
// _dfr_start is matched with _dfr_stop, if this is not hte first,
// we need to resume the HPX runtime.
uint64_t uninitialised = 0;
uint64_t active = 1;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(uninitialised, active))
_dfr_start_impl(0, nullptr);
else if (init_guard.compare_exchange_strong(suspended, active))
hpx::resume();
if (!_dfr_is_jit())
// If this is not the root node in a non-JIT execution, then this
// node should only run the scheduler for any incoming work until
// termination is flagged. If this is JIT, we need to run the
// cancelled function which registers the work functions.
if (!_dfr_is_root_node() && !_dfr_is_jit())
_dfr_stop_impl();
// TODO: conditional -- If this is the root node, and this is JIT
// execution, we need to wait for the compute nodes to compile and
// register work functions
if (_dfr_is_root_node_impl() && _dfr_is_jit()) {
if (_dfr_is_root_node() && _dfr_is_jit()) {
_dfr_jit_workfunction_registration_barrier->wait();
}
}
// This function cannot be used to terminate the runtime as it is
// non-decidable if another computation phase will follow. Instead the
// _dfr_terminate function provides this facility and is normally
// called on exit from "main" when not using the main wrapper library.
void _dfr_stop() {
if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) {
// Non-root nodes synchronize here with the root to mark the point
// where the root is free to send work out.
// TODO: optimize this by moving synchro to local remote nodes
// waiting in the scheduler for registration.
if (!_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) {
_dfr_jit_workfunction_registration_barrier->wait();
}
// The barrier is only needed to synchronize the different
// computation phases when the compute nodes need to generate and
// register new work functions in each phase.
// TODO: this barrier may be removed based on how work function
// registration is handled - but it is unlikely to result in much
// gain as the root node would be waiting for the end of computation
// on all remote nodes before reaching here anyway (dataflow
// dependences).
if (_dfr_is_jit()) {
_dfr_jit_phase_barrier->wait();
}
hpx::suspend();
// TODO: this can be removed along with the matching hpx::resume if
// their overhead is larger than the benefit of pausing worker
// threads outside of parallel regions - to be tested.
uint64_t active = 1;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(active, suspended))
hpx::suspend();
// TODO: until we have better unique identifiers for keys it is
// safer to drop them in-between phases.
_dfr_node_level_bsk_manager->clear_keys();
_dfr_node_level_ksk_manager->clear_keys();
@@ -823,14 +861,26 @@ void _dfr_stop() {
}
}
void _dfr_terminate() {
uint64_t initialised = 1;
if (init_guard.compare_exchange_strong(initialised, 2)) {
hpx::resume();
_dfr_stop_impl();
void _dfr_try_initialize() {
// Initialize and immediately suspend the HPX runtime if not yet done.
uint64_t uninitialised = 0;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(uninitialised, suspended)) {
_dfr_start_impl(0, nullptr);
hpx::suspend();
}
}
void _dfr_terminate() {
uint64_t active = 1;
uint64_t suspended = 2;
uint64_t terminated = 3;
if (init_guard.compare_exchange_strong(suspended, active))
hpx::resume();
if (init_guard.compare_exchange_strong(active, terminated))
_dfr_stop_impl();
}
/*******************/
/* Main wrapper. */
/*******************/
@@ -839,22 +889,12 @@ extern int main(int argc, char *argv[]) __attribute__((weak));
extern int __real_main(int argc, char *argv[]) __attribute__((weak));
int __wrap_main(int argc, char *argv[]) {
int r;
// Initialize and immediately suspend the HPX runtime if not yet done.
uint64_t uninitialised = 0;
if (init_guard.compare_exchange_strong(uninitialised, 1)) {
_dfr_start_impl(0, nullptr);
hpx::suspend();
}
_dfr_try_initialize();
// Run the actual main function. Within there should be a call to
// _dfr_start to resume execution of the HPX scheduler if needed.
r = __real_main(argc, argv);
// By default all _dfr_start should be matched to a _dfr_stop, so we
// need to resume before being able to finalize.
uint64_t initialised = 1;
if (init_guard.compare_exchange_strong(initialised, 2)) {
hpx::resume();
_dfr_stop_impl();
}
_dfr_terminate();
return r;
}
@@ -869,7 +909,7 @@ size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) {
// clang-format off
hpx::cout << "Task \"" << name << "\""
hpx::cout << "Task \"" << name << "\t\""
<< " [" << inputs << " inputs, " << outputs << " outputs]"
<< " Executing on Node/Worker: " << _dfr_debug_get_node_id()
<< " / " << _dfr_debug_get_worker_id() << "\n" << std::flush;
@@ -885,8 +925,8 @@ void _dfr_print_debug(size_t val) {
#include "concretelang/Runtime/DFRuntime.hpp"
bool _dfr_set_required(bool is_required) { return !is_required; }
void _dfr_set_jit(bool) {}
bool _dfr_is_root_node() { return true; }
void _dfr_is_jit(bool) {}
void _dfr_terminate() {}
#endif

View File

@@ -34,6 +34,7 @@
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/CompilerEngine.h>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>

View File

@@ -104,6 +104,7 @@ JITLambda::call(clientlib::PublicArguments &args,
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
}
#endif
// invokeRaw needs to have pointers on arguments and a pointers on the result
// as last argument.
// Prepare the outputs vector to store the output value of the lambda.