// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.

/// This file implements the dataflow runtime. It encapsulates all of
/// the underlying communication, parallelism, etc. and only exposes a
/// simplified interface for code generation in runtime_api.h.
/// This hides the implementation details, including of the HPX
/// framework currently used, from the code generation side.

#ifdef CONCRETELANG_DATAFLOW_EXECUTION_ENABLED

#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdarg>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <dlfcn.h>
#include <hpx/hpx_start.hpp>
#include <hpx/include/components.hpp>
#include <hpx/include/lcos.hpp>
#include <hpx/include/runtime.hpp>
#include <hpx/iostream.hpp>
#include <hwloc.h>
#include <string>
#include <unistd.h>
#include <vector>

#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/distributed_generic_task_server.hpp"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Runtime/time_util.h"

namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static std::vector<GenericComputeClient> gcc;
static hpx::lcos::barrier *_dfr_jit_phase_barrier;
static hpx::lcos::barrier *_dfr_startup_barrier;
static size_t num_nodes = 0;
#if CONCRETELANG_TIMING_ENABLED
static struct timespec init_timer, broadcast_timer, compute_timer, whole_timer;
#endif
} // namespace
} // namespace dfr
} // namespace concretelang
} // namespace mlir

#include "concretelang/Runtime/dfr_tasks.hpp"

using namespace hpx;
using namespace mlir::concretelang::dfr;

// Ready futures are only used as inputs to tasks (never passed to
// await_future), so we only need to track the references in task
// creation.
void *_dfr_make_ready_future(void *in, size_t memref_clone_p) {
  hpx::future<void *> future = hpx::make_ready_future(in);
  return (void *)new dfr_refcounted_future_t(
      new hpx::shared_future<void *>(std::move(future)), 1, memref_clone_p);
}

void *_dfr_await_future(void *in) {
  return static_cast<dfr_refcounted_future_p>(in)->future->get();
}

void _dfr_deallocate_future(void *in) {
  auto drf = static_cast<dfr_refcounted_future_p>(in);
  size_t prev_count = drf->count.fetch_sub(1);
  if (prev_count == 1) {
    // If this was a memref for which a clone was needed, deallocate
    // the cloned data first.
    if (drf->cloned_memref_p)
      free((void *)(static_cast<StridedMemRefType<char, 1> *>(
                        drf->future->get())
                        ->data));
    free(drf->future->get());
    delete (drf->future);
    delete drf;
  }
}

/// Runtime generic async_task. The first NUM_OUTPUTS triples of
/// arguments in the variadic list correspond to a void* pointer to
/// the slot where each output future should be stored and two size_t
/// parameters for the size and type of that output. After that come
/// NUM_PARAMS triples that correspond to a void* pointer to the
/// hpx::future<void*> of each parameter and the same two size_t
/// parameters (size and type).
void _dfr_create_async_task(wfnptr wfn, void *ctx, size_t num_params,
                            size_t num_outputs, ...) {
  std::vector<void *> refcounted_futures;
  std::vector<size_t> param_sizes;
  std::vector<size_t> param_types;
  std::vector<void *> outputs;
  std::vector<size_t> output_sizes;
  std::vector<size_t> output_types;

  va_list args;
  va_start(args, num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    outputs.push_back(va_arg(args, void *));
    output_sizes.push_back(va_arg(args, uint64_t));
    output_types.push_back(va_arg(args, uint64_t));
  }
  for (size_t i = 0; i < num_params; ++i) {
    refcounted_futures.push_back(va_arg(args, void *));
    param_sizes.push_back(va_arg(args, uint64_t));
    param_types.push_back(va_arg(args, uint64_t));
  }
  va_end(args);

  dfr_create_async_task_impl(wfn, ctx, refcounted_futures, param_sizes,
                             param_types, outputs, output_sizes, output_types);
}
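// Illustrative sketch, not part of the runtime: one plausible call
// sequence for the interface above, roughly as code generation might
// emit it. The work function `wfn`, context `ctx` and `type_id` are
// hypothetical, and the exact output-slot semantics are defined by
// dfr_create_async_task_impl in dfr_tasks.hpp; this only shows the
// outputs-first (pointer, size, type) triple protocol.
//
//   void *in = _dfr_make_ready_future(input_val, /*memref_clone_p=*/0);
//   void *out = nullptr;
//   _dfr_create_async_task(wfn, ctx, /*num_params=*/1, /*num_outputs=*/1,
//                          /*output 0:*/ &out, sizeof(uint64_t), type_id,
//                          /*param 0:*/ in, sizeof(uint64_t), type_id);
//   uint64_t *res = (uint64_t *)_dfr_await_future(out);
//   _dfr_deallocate_future(out);
//   _dfr_deallocate_future(in);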
/// Runtime generic async_task with vector parameters. The first
/// NUM_OUTPUTS quadruplets of arguments in the variadic list
/// correspond to a size_t for the number of elements in the following
/// array, a void* pointer to an array of hpx::future<void*> and two
/// size_t parameters for the size and type of each output. After that
/// come NUM_PARAMS quadruplets of arguments in the variadic list that
/// correspond to a size_t for the number of elements in the following
/// array, a void* pointer to an array of hpx::future<void*> and the
/// same two size_t parameters (size and type).
void _dfr_create_async_task_vec(wfnptr wfn, void *ctx, size_t num_params,
                                size_t num_outputs, ...) {
  std::vector<void *> refcounted_futures;
  std::vector<size_t> param_sizes;
  std::vector<size_t> param_types;
  std::vector<void *> outputs;
  std::vector<size_t> output_sizes;
  std::vector<size_t> output_types;

  va_list args;
  va_start(args, num_outputs);
  for (size_t i = 0; i < num_outputs; ++i) {
    size_t count = va_arg(args, uint64_t);
    void **futures = va_arg(args, void **);
    size_t size = va_arg(args, uint64_t);
    size_t type = va_arg(args, uint64_t);
    for (size_t j = 0; j < count; ++j) {
      outputs.push_back(futures[j]);
      output_sizes.push_back(size);
      output_types.push_back(type);
    }
  }
  for (size_t i = 0; i < num_params; ++i) {
    size_t count = va_arg(args, uint64_t);
    void **futures = va_arg(args, void **);
    size_t size = va_arg(args, uint64_t);
    size_t type = va_arg(args, uint64_t);
    for (size_t j = 0; j < count; ++j) {
      refcounted_futures.push_back(futures[j]);
      param_sizes.push_back(size);
      param_types.push_back(type);
    }
  }
  va_end(args);

  dfr_create_async_task_impl(wfn, ctx, refcounted_futures, param_sizes,
                             param_types, outputs, output_sizes, output_types);
}
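// Illustrative sketch, not part of the runtime: the quadruplet
// protocol of the vector variant above, again with hypothetical
// handles (`wfn`, `ctx`, `fut_a`..`fut_c`, sizes and types). A single
// quadruplet covers a whole array of futures sharing one size and one
// type; note the explicit size_t casts, since va_arg reads the counts
// as 64-bit values.
//
//   void *outs[2] = {nullptr, nullptr};
//   void *ins[3] = {fut_a, fut_b, fut_c};
//   _dfr_create_async_task_vec(
//       wfn, ctx, /*num_params=*/1, /*num_outputs=*/1,
//       /*outputs:*/ (size_t)2, (void **)outs, out_size, out_type,
//       /*params:*/ (size_t)3, (void **)ins, in_size, in_type);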
/***************************/
/*  JIT execution support. */
/***************************/

void _dfr_try_initialize();

namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static bool dfr_required_p = false;
static bool is_jit_p = false;
static bool is_root_node_p = true;
static bool use_omp_p = false;
} // namespace

void _dfr_set_required(bool is_required) {
  dfr_required_p = is_required;
  if (dfr_required_p) {
    _dfr_try_initialize();
  }
}
void _dfr_set_jit(bool is_jit) { is_jit_p = is_jit; }
void _dfr_set_use_omp(bool use_omp) { use_omp_p = use_omp; }
bool _dfr_is_jit() { return is_jit_p; }
bool _dfr_is_root_node() { return is_root_node_p; }
bool _dfr_use_omp() { return use_omp_p; }
bool _dfr_is_distributed() { return num_nodes > 1; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir

void _dfr_register_work_function(wfnptr wfn) {
  _dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn);
}

/************************************/
/*  Initialization & Finalization.  */
/************************************/

namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static std::atomic<uint64_t> init_guard = {0};
static uint64_t uninitialised = 0;
static uint64_t active = 1;
static uint64_t terminated = 2;
} // namespace
} // namespace dfr
} // namespace concretelang
} // namespace mlir

static inline void _dfr_stop_impl() {
  if (_dfr_is_root_node())
    hpx::apply([]() { hpx::finalize(); });
  hpx::stop();
  if (!_dfr_is_root_node())
    exit(EXIT_SUCCESS);
}

static inline void _dfr_start_impl(int argc, char *argv[]) {
  BEGIN_TIME(&init_timer);
  dl_handle = dlopen(nullptr, RTLD_NOW);

  // If OpenMP is to be used, we need to force its initialization
  // before thread binding occurs. Otherwise OMP threads would be
  // bound to the core of the thread initializing the OMP runtime.
  if (_dfr_use_omp()) {
#pragma omp parallel shared(use_omp_p)
    {
#pragma omp critical
      use_omp_p = true;
    }
  }

  if (argc == 0) {
    int nCores, nOMPThreads, nHPXThreads;
    std::string hpxThreadNum;
    std::vector<char *> parameters;
    parameters.push_back(const_cast<char *>("__dummy_dfr_HPX_program_name__"));

    hwloc_topology_t topology;
    hwloc_topology_init(&topology);
    hwloc_topology_set_all_types_filter(topology, HWLOC_TYPE_FILTER_KEEP_NONE);
    hwloc_topology_set_type_filter(topology, HWLOC_OBJ_CORE,
                                   HWLOC_TYPE_FILTER_KEEP_ALL);
    hwloc_topology_load(topology);
    nCores = hwloc_get_nbobjs_by_type(topology, HWLOC_OBJ_CORE);
    if (nCores < 1)
      nCores = 1;

    // We do not directly handle this, but we should take into account
    // the choices made by the OpenMP runtime when mixing loop &
    // dataflow parallelism.
    char *env = getenv("OMP_NUM_THREADS");
    if (_dfr_use_omp() && env != nullptr)
      nOMPThreads = strtoul(env, NULL, 10);
    else if (_dfr_use_omp())
      nOMPThreads = nCores;
    else
      nOMPThreads = 1;

    // Unless specified, we consider that within each node loop
    // parallelism is the priority, so we allocate either
    // nCores/OMP_NUM_THREADS or nCores-OMP_NUM_THREADS+1 threads to
    // HPX. Both make sense depending on how regular the computation
    // is - the latter being more conservative in that we exploit all
    // cores, at the risk of oversubscribing. Ideally the distribution
    // of hardware resources to the runtime systems should be
    // explicitly defined by the user.
    env = getenv("DFR_NUM_THREADS");
    if (env != nullptr) {
      nHPXThreads = strtoul(env, NULL, 10);
      parameters.push_back(const_cast<char *>("--hpx:threads"));
      hpxThreadNum = std::to_string(nHPXThreads);
      parameters.push_back(const_cast<char *>(hpxThreadNum.c_str()));
    } else
      nHPXThreads = nCores + 1 - nOMPThreads;
    if (nHPXThreads < 1)
      nHPXThreads = 1;

    // If the user does not provide their own config file, one is by
    // default located at the root of the concrete-compiler directory.
    env = getenv("HPX_CONFIG_FILE");

    // If no file is provided, check whether the default is available
    // - otherwise fall back to a basic default configuration.
#ifdef HPX_DEFAULT_CONFIG_FILE
    if (env == nullptr)
      if (access(HPX_DEFAULT_CONFIG_FILE, F_OK) == 0)
        env = const_cast<char *>(HPX_DEFAULT_CONFIG_FILE);
#endif

    if (env != nullptr) {
      parameters.push_back(const_cast<char *>("--hpx:config"));
      parameters.push_back(const_cast<char *>(env));
      hpx::start(nullptr, parameters.size(), parameters.data());
    } else {
      // Last resort configuration in case no config file could be
      // identified: provide some default values that make (some)
      // sense for homomorphic computations (stacks need to reflect
      // the size of ciphertexts rather than simple cleartext
      // scalars).
      // Compare by contents rather than pointer identity to detect
      // whether the thread count was already set above.
      if (std::find_if(parameters.begin(), parameters.end(),
                       [](const char *s) {
                         return strcmp(s, "--hpx:threads") == 0;
                       }) == parameters.end()) {
        parameters.push_back(const_cast<char *>("--hpx:threads"));
        hpxThreadNum = std::to_string(nHPXThreads);
        parameters.push_back(const_cast<char *>(hpxThreadNum.c_str()));
      }
      parameters.push_back(
          const_cast<char *>("--hpx:ini=hpx.stacks.small_size=0x8000000"));
      parameters.push_back(
          const_cast<char *>("--hpx:ini=hpx.stacks.medium_size=0x10000000"));
      parameters.push_back(
          const_cast<char *>("--hpx:ini=hpx.stacks.large_size=0x20000000"));
      parameters.push_back(
          const_cast<char *>("--hpx:ini=hpx.stacks.huge_size=0x40000000"));
      hpx::start(nullptr, parameters.size(), parameters.data());
    }
  } else {
    hpx::start(nullptr, argc, argv);
  }

  // Instantiate and initialise on each node.
  is_root_node_p = (hpx::find_here() == hpx::find_root_locality());
  num_nodes = hpx::get_num_localities().get();
  new WorkFunctionRegistry();
  new RuntimeContextManager();
  _dfr_jit_phase_barrier = new hpx::lcos::barrier("phase_barrier", num_nodes,
                                                  hpx::get_locality_id());
  _dfr_startup_barrier = new hpx::lcos::barrier("startup_barrier", num_nodes,
                                                hpx::get_locality_id());

  if (_dfr_is_root_node()) {
    // Create compute server components on each node - from the root
    // node only - and the corresponding compute client on the root
    // node.
    gcc = hpx::new_<GenericComputeClient[]>(
              hpx::default_layout(hpx::find_all_localities()), num_nodes)
              .get();
  }
  END_TIME(&init_timer, "Initialization");
}
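// Worked example (assumed usage, not part of the runtime): on a
// 16-core node with loop parallelism enabled and neither variable
// set, the fallback above yields
// nHPXThreads = nCores + 1 - nOMPThreads = 16 + 1 - 16 = 1, while an
// explicit split can be requested through the environment, e.g. (with
// a hypothetical binary name):
//
//   OMP_NUM_THREADS=8 DFR_NUM_THREADS=9 ./compiled-circuit
//
// which gives 8 threads to the OpenMP runtime and 9 to HPX.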
/*  Start/stop functions to be called from within user code (or during
    JIT invocation). These serve to pause/resume the runtime scheduler
    and to clean up used resources.  */
void _dfr_start(int64_t use_dfr_p, void *ctx) {
  BEGIN_TIME(&whole_timer);
  if (use_dfr_p) {
    // The first invocation will initialise the runtime. As each call
    // to _dfr_start is matched with _dfr_stop, if this is not the
    // first, we need to resume the HPX runtime.
    assert(init_guard != terminated &&
           "DFR runtime: attempting to start runtime after it has been "
           "terminated");
    uint64_t expected = uninitialised;
    if (init_guard.compare_exchange_strong(expected, active))
      _dfr_start_impl(0, nullptr);
    assert(init_guard == active && "DFR runtime failed to initialise");

    // If this is not the root node in a non-JIT execution, then this
    // node should only run the scheduler for any incoming work until
    // termination is flagged. If this is JIT, this node still needs
    // to execute the code that registers the work functions.
    if (!_dfr_is_root_node() && !_dfr_is_jit())
      _dfr_stop_impl();

    // If DFR is used and a runtime context is needed, and execution
    // is distributed, then broadcast the context from root to all
    // compute nodes.
    if (num_nodes > 1 && (ctx || !_dfr_is_root_node())) {
      BEGIN_TIME(&broadcast_timer);
      _dfr_node_level_runtime_context_manager->setContext(ctx);
    }

    // If this is not JIT, then the remote nodes never reach
    // _dfr_stop, so root should not instantiate this barrier.
    if (_dfr_is_root_node() && _dfr_is_jit())
      _dfr_startup_barrier->wait();

    if (num_nodes > 1 && ctx) {
      END_TIME(&broadcast_timer, "Key broadcasting");
    }
  }
  BEGIN_TIME(&compute_timer);
}

// This function cannot be used to terminate the runtime as it is
// undecidable whether another computation phase will follow. Instead
// the _dfr_terminate function provides this facility and is normally
// called on exit from "main" when not using the main wrapper library.
void _dfr_stop(int64_t use_dfr_p) {
  if (use_dfr_p) {
    if (num_nodes > 1) {
      // Non-root nodes synchronize here with the root to mark the
      // point where the root is free to send work out (only needed in
      // JIT).
      if (!_dfr_is_root_node())
        _dfr_startup_barrier->wait();

      // The barrier is only needed to synchronize the different
      // computation phases when the compute nodes need to generate
      // and register new work functions in each phase.
      // TODO: this barrier may be removed based on how work function
      // registration is handled - but it is unlikely to result in
      // much gain as the root node would be waiting for the end of
      // computation on all remote nodes before reaching here anyway
      // (dataflow dependences).
      if (_dfr_is_jit())
        _dfr_jit_phase_barrier->wait();

      _dfr_node_level_runtime_context_manager->clearContext();
      _dfr_node_level_work_function_registry->clearRegistry();
    }
  }
  END_TIME(&compute_timer, "Compute");
  END_TIME(&whole_timer, "Total execution");
}
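// Illustrative call pattern (assumed, not taken from this file): a
// single compute phase as the compiler's driver might emit it, where
// `run_circuit` and `runtime_context` are hypothetical:
//
//   _dfr_start(/*use_dfr_p=*/1, runtime_context); // resume scheduler
//   run_circuit(...);                             // spawns dataflow tasks
//   _dfr_stop(/*use_dfr_p=*/1);                   // end of this phase
//
// Multiple such phases may follow; only _dfr_terminate (below) shuts
// the runtime down.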
void _dfr_try_initialize() {
  // Initialize and immediately suspend the HPX runtime if not yet
  // done.
  uint64_t expected = uninitialised;
  if (init_guard.compare_exchange_strong(expected, active)) {
    _dfr_start_impl(0, nullptr);
  }
  assert(init_guard == active && "DFR runtime failed to initialise");
}

void _dfr_terminate() {
  uint64_t expected = active;
  if (init_guard.compare_exchange_strong(expected, terminated))
    _dfr_stop_impl();
  assert((init_guard == terminated || init_guard == uninitialised) &&
         "DFR runtime failed to terminate");
}

/*******************/
/*  Main wrapper.  */
/*******************/
extern "C" {
extern int main(int argc, char *argv[]) __attribute__((weak));
extern int __real_main(int argc, char *argv[]) __attribute__((weak));

int __wrap_main(int argc, char *argv[]) {
  int r;

  _dfr_try_initialize();
  // Run the actual main function. Within it there should be a call to
  // _dfr_start to resume execution of the HPX scheduler if needed.
  r = __real_main(argc, argv);
  _dfr_terminate();

  return r;
}
}

/**********************/
/*  Debug interface.  */
/**********************/
size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); }

size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }

void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) {
  // clang-format off
  hpx::cout << "Task \"" << name << "\t\""
            << " [" << inputs << " inputs, " << outputs << " outputs]"
            << " Executing on Node/Worker: " << _dfr_debug_get_node_id()
            << " / " << _dfr_debug_get_worker_id() << "\n" << std::flush;
  // clang-format on
}

/// Generic utility function for printing debug info.
void _dfr_print_debug(size_t val) {
  hpx::cout << "_dfr_print_debug : " << val << "\n" << std::flush;
}

#else // CONCRETELANG_DATAFLOW_EXECUTION_ENABLED

#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/time_util.h"

namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static bool is_jit_p = false;
static bool use_omp_p = false;
static size_t num_nodes = 1;
#if CONCRETELANG_TIMING_ENABLED
static struct timespec compute_timer;
#endif
} // namespace

void _dfr_set_required(bool is_required) {}
void _dfr_set_jit(bool p) { is_jit_p = p; }
void _dfr_set_use_omp(bool use_omp) { use_omp_p = use_omp; }
bool _dfr_is_jit() { return is_jit_p; }
bool _dfr_is_root_node() { return true; }
bool _dfr_use_omp() { return use_omp_p; }
bool _dfr_is_distributed() { return num_nodes > 1; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir

using namespace mlir::concretelang::dfr;

void _dfr_start(int64_t use_dfr_p, void *ctx) { BEGIN_TIME(&compute_timer); }
void _dfr_stop(int64_t use_dfr_p) { END_TIME(&compute_timer, "Compute"); }
void _dfr_terminate() {}

#endif
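// Illustrative linking note (assumption, not taken from this file):
// the __wrap_main/__real_main pair defined above follows the GNU ld
// --wrap convention, so a program linked against the main wrapper
// library would typically pass something like
//
//   -Wl,--wrap=main
//
// at link time, making __wrap_main run in place of the user's main so
// it can initialize and terminate the DFR runtime around it.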