mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(runtime): deactivate main wrapping by default and add explicit initialization/termination.
This commit is contained in:
@@ -14,9 +14,10 @@
|
||||
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
|
||||
bool _dfr_is_root_node();
|
||||
void _dfr_is_jit(bool);
|
||||
bool _dfr_set_required(bool);
|
||||
void _dfr_set_jit(bool);
|
||||
bool _dfr_is_jit();
|
||||
bool _dfr_is_root_node();
|
||||
void _dfr_terminate();
|
||||
|
||||
typedef enum _dfr_task_arg_type {
|
||||
|
||||
@@ -680,8 +680,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
/***************************/
|
||||
/* JIT execution support. */
|
||||
/***************************/
|
||||
void _dfr_try_initialize();
|
||||
namespace {
|
||||
static bool dfr_required_p = false;
|
||||
static bool is_jit_p = false;
|
||||
} // namespace
|
||||
bool _dfr_set_required(bool is_required) {
|
||||
dfr_required_p = is_required;
|
||||
if (dfr_required_p)
|
||||
_dfr_try_initialize();
|
||||
return true;
|
||||
}
|
||||
// Store the JIT-execution flag in the file-local is_jit_p variable.
void _dfr_set_jit(bool is_jit) {
  is_jit_p = is_jit;
}
|
||||
// Report the current value of the file-local is_jit_p flag.
bool _dfr_is_jit() {
  const bool jit_enabled = is_jit_p;
  return jit_enabled;
}
|
||||
|
||||
static inline bool _dfr_is_root_node_impl() {
|
||||
static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality());
|
||||
static bool is_root_node_p =
|
||||
(!dfr_required_p || (hpx::find_here() == hpx::find_root_locality()));
|
||||
return is_root_node_p;
|
||||
}
|
||||
|
||||
@@ -692,31 +707,15 @@ void _dfr_register_work_function(wfnptr wfn) {
|
||||
(void *)wfn);
|
||||
}
|
||||
|
||||
/********************************/
|
||||
/* Distributed key management. */
|
||||
/********************************/
|
||||
|
||||
/************************************/
|
||||
/* Initialization & Finalization. */
|
||||
/************************************/
|
||||
// TODO: need to set a flag for when executing in JIT to allow remote
|
||||
// nodes to execute generated function that registers work functions.
|
||||
// This also means that compute nodes need to go through all the
|
||||
// phases of computation synchronized with the root node.
|
||||
// Sticky JIT flag. The flag is seeded with the argument of the very first
// call; afterwards a call with is_jit == true switches it on, and nothing
// ever switches it back off. Returns the flag's value after this call.
static inline bool _dfr_is_jit_impl(bool is_jit = false) {
  static bool jit_latched = is_jit;
  if (!jit_latched && is_jit) {
    jit_latched = true;
  }
  return jit_latched;
}
|
||||
|
||||
// Forward the flag to _dfr_is_jit_impl(); its return value is not used.
void _dfr_is_jit(bool is_jit) {
  (void)_dfr_is_jit_impl(is_jit);
}
|
||||
// Query the JIT flag by calling _dfr_is_jit_impl() with its default argument.
bool _dfr_is_jit() {
  const bool jit_state = _dfr_is_jit_impl();
  return jit_state;
}
|
||||
|
||||
static inline void _dfr_stop_impl() {
|
||||
if (_dfr_is_root_node_impl())
|
||||
if (_dfr_is_root_node())
|
||||
hpx::apply([]() { hpx::finalize(); });
|
||||
hpx::stop();
|
||||
if (!_dfr_is_root_node())
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
|
||||
static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
@@ -744,6 +743,10 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
std::string numOMPThreads = std::to_string(nOMPThreads);
|
||||
setenv("OMP_NUM_THREADS", numOMPThreads.c_str(), 0);
|
||||
nHPXThreads = nCores + 1 - nOMPThreads;
|
||||
env = getenv("DFR_NUM_THREADS");
|
||||
if (env != nullptr)
|
||||
nHPXThreads = strtoul(env, NULL, 10);
|
||||
nHPXThreads = (nHPXThreads) ? nHPXThreads : 1;
|
||||
std::string numHPXThreads = std::to_string(nHPXThreads);
|
||||
char *_argv[3] = {const_cast<char *>("__dummy_dfr_HPX_program_name__"),
|
||||
const_cast<char *>("--hpx:threads"),
|
||||
@@ -764,7 +767,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
_dfr_jit_phase_barrier = new hpx::lcos::barrier(
|
||||
"phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id());
|
||||
|
||||
if (_dfr_is_root_node_impl()) {
|
||||
if (_dfr_is_root_node()) {
|
||||
// Create compute server components on each node - from the root
|
||||
// node only - and the corresponding compute client on the root
|
||||
// node.
|
||||
@@ -779,33 +782,68 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
JIT invocation). These serve to pause/resume the runtime
|
||||
scheduler and to clean up used resources. */
|
||||
void _dfr_start() {
|
||||
hpx::resume();
|
||||
// The first invocation will initialise the runtime. As each call to
|
||||
// _dfr_start is matched with _dfr_stop, if this is not the first,
|
||||
// we need to resume the HPX runtime.
|
||||
uint64_t uninitialised = 0;
|
||||
uint64_t active = 1;
|
||||
uint64_t suspended = 2;
|
||||
if (init_guard.compare_exchange_strong(uninitialised, active))
|
||||
_dfr_start_impl(0, nullptr);
|
||||
else if (init_guard.compare_exchange_strong(suspended, active))
|
||||
hpx::resume();
|
||||
|
||||
if (!_dfr_is_jit())
|
||||
// If this is not the root node in a non-JIT execution, then this
|
||||
// node should only run the scheduler for any incoming work until
|
||||
// termination is flagged. If this is JIT, we need to run the
|
||||
// cancelled function which registers the work functions.
|
||||
if (!_dfr_is_root_node() && !_dfr_is_jit())
|
||||
_dfr_stop_impl();
|
||||
|
||||
// TODO: conditional -- If this is the root node, and this is JIT
|
||||
// execution, we need to wait for the compute nodes to compile and
|
||||
// register work functions
|
||||
if (_dfr_is_root_node_impl() && _dfr_is_jit()) {
|
||||
if (_dfr_is_root_node() && _dfr_is_jit()) {
|
||||
_dfr_jit_workfunction_registration_barrier->wait();
|
||||
}
|
||||
}
|
||||
|
||||
// This function cannot be used to terminate the runtime as it is
|
||||
// non-decidable if another computation phase will follow. Instead the
|
||||
// _dfr_terminate function provides this facility and is normally
|
||||
// called on exit from "main" when not using the main wrapper library.
|
||||
void _dfr_stop() {
|
||||
if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) {
|
||||
// Non-root nodes synchronize here with the root to mark the point
|
||||
// where the root is free to send work out.
|
||||
// TODO: optimize this by moving synchro to local remote nodes
|
||||
// waiting in the scheduler for registration.
|
||||
if (!_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) {
|
||||
_dfr_jit_workfunction_registration_barrier->wait();
|
||||
}
|
||||
|
||||
// The barrier is only needed to synchronize the different
|
||||
// computation phases when the compute nodes need to generate and
|
||||
// register new work functions in each phase.
|
||||
|
||||
// TODO: this barrier may be removed based on how work function
|
||||
// registration is handled - but it is unlikely to result in much
|
||||
// gain as the root node would be waiting for the end of computation
|
||||
// on all remote nodes before reaching here anyway (dataflow
|
||||
// dependences).
|
||||
if (_dfr_is_jit()) {
|
||||
_dfr_jit_phase_barrier->wait();
|
||||
}
|
||||
|
||||
hpx::suspend();
|
||||
// TODO: this can be removed along with the matching hpx::resume if
|
||||
// their overhead is larger than the benefit of pausing worker
|
||||
// threads outside of parallel regions - to be tested.
|
||||
uint64_t active = 1;
|
||||
uint64_t suspended = 2;
|
||||
if (init_guard.compare_exchange_strong(active, suspended))
|
||||
hpx::suspend();
|
||||
|
||||
// TODO: until we have better unique identifiers for keys it is
|
||||
// safer to drop them in-between phases.
|
||||
_dfr_node_level_bsk_manager->clear_keys();
|
||||
_dfr_node_level_ksk_manager->clear_keys();
|
||||
|
||||
@@ -823,14 +861,26 @@ void _dfr_stop() {
|
||||
}
|
||||
}
|
||||
|
||||
void _dfr_terminate() {
|
||||
uint64_t initialised = 1;
|
||||
if (init_guard.compare_exchange_strong(initialised, 2)) {
|
||||
hpx::resume();
|
||||
_dfr_stop_impl();
|
||||
void _dfr_try_initialize() {
|
||||
// Initialize and immediately suspend the HPX runtime if not yet done.
|
||||
uint64_t uninitialised = 0;
|
||||
uint64_t suspended = 2;
|
||||
if (init_guard.compare_exchange_strong(uninitialised, suspended)) {
|
||||
_dfr_start_impl(0, nullptr);
|
||||
hpx::suspend();
|
||||
}
|
||||
}
|
||||
|
||||
void _dfr_terminate() {
|
||||
uint64_t active = 1;
|
||||
uint64_t suspended = 2;
|
||||
uint64_t terminated = 3;
|
||||
if (init_guard.compare_exchange_strong(suspended, active))
|
||||
hpx::resume();
|
||||
if (init_guard.compare_exchange_strong(active, terminated))
|
||||
_dfr_stop_impl();
|
||||
}
|
||||
|
||||
/*******************/
|
||||
/* Main wrapper. */
|
||||
/*******************/
|
||||
@@ -839,22 +889,12 @@ extern int main(int argc, char *argv[]) __attribute__((weak));
|
||||
extern int __real_main(int argc, char *argv[]) __attribute__((weak));
|
||||
int __wrap_main(int argc, char *argv[]) {
|
||||
int r;
|
||||
// Initialize and immediately suspend the HPX runtime if not yet done.
|
||||
uint64_t uninitialised = 0;
|
||||
if (init_guard.compare_exchange_strong(uninitialised, 1)) {
|
||||
_dfr_start_impl(0, nullptr);
|
||||
hpx::suspend();
|
||||
}
|
||||
|
||||
_dfr_try_initialize();
|
||||
// Run the actual main function. Within there should be a call to
|
||||
// _dfr_start to resume execution of the HPX scheduler if needed.
|
||||
r = __real_main(argc, argv);
|
||||
// By default all _dfr_start should be matched to a _dfr_stop, so we
|
||||
// need to resume before being able to finalize.
|
||||
uint64_t initialised = 1;
|
||||
if (init_guard.compare_exchange_strong(initialised, 2)) {
|
||||
hpx::resume();
|
||||
_dfr_stop_impl();
|
||||
}
|
||||
_dfr_terminate();
|
||||
|
||||
return r;
|
||||
}
|
||||
@@ -869,7 +909,7 @@ size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }
|
||||
|
||||
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) {
|
||||
// clang-format off
|
||||
hpx::cout << "Task \"" << name << "\""
|
||||
hpx::cout << "Task \"" << name << "\t\""
|
||||
<< " [" << inputs << " inputs, " << outputs << " outputs]"
|
||||
<< " Executing on Node/Worker: " << _dfr_debug_get_node_id()
|
||||
<< " / " << _dfr_debug_get_worker_id() << "\n" << std::flush;
|
||||
@@ -885,8 +925,8 @@ void _dfr_print_debug(size_t val) {
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
// Stub variant: returns true exactly when is_required is false.
// NOTE(review): this appears to sit in the fallback branch of a preprocessor
// conditional whose opening is outside this view — confirm.
bool _dfr_set_required(bool is_required) {
  if (is_required) {
    return false;
  }
  return true;
}
|
||||
// Stub variant: accepts the JIT flag and ignores it.
void _dfr_set_jit(bool /*is_jit*/) {
}
|
||||
// Stub variant: always reports that this is the root node.
bool _dfr_is_root_node() {
  const bool is_root = true;
  return is_root;
}
|
||||
// Stub variant: accepts the JIT flag and ignores it.
void _dfr_is_jit(bool /*is_jit*/) {
}
|
||||
// Stub variant: there is nothing to shut down, so this does nothing.
void _dfr_terminate() {
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -34,6 +34,7 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <concretelang/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.h>
|
||||
#include <concretelang/Dialect/TFHE/IR/TFHEDialect.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/CompilerEngine.h>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
|
||||
@@ -104,6 +104,7 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
|
||||
}
|
||||
#endif
|
||||
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
// Prepare the outputs vector to store the output value of the lambda.
|
||||
|
||||
Reference in New Issue
Block a user