diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index 133f63280..569ab2500 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -743,8 +743,6 @@ void _dfr_set_required(bool is_required) { mlir::concretelang::dfr::dfr_required_p = is_required; if (mlir::concretelang::dfr::dfr_required_p) { _dfr_try_initialize(); - mlir::concretelang::dfr::is_root_node_p = - (hpx::find_here() == hpx::find_root_locality()); } } void _dfr_set_jit(bool is_jit) { mlir::concretelang::dfr::is_jit_p = is_jit; } @@ -782,7 +780,7 @@ static inline void _dfr_stop_impl() { static inline void _dfr_start_impl(int argc, char *argv[]) { mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW); if (argc == 0) { - unsigned long nCores, nOMPThreads, nHPXThreads; + int nCores, nOMPThreads, nHPXThreads; hwloc_topology_t topology; hwloc_topology_init(&topology); hwloc_topology_set_all_types_filter(topology, HWLOC_TYPE_FILTER_KEEP_NONE); @@ -792,33 +790,48 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { nCores = hwloc_get_nbobjs_by_type(topology, HWLOC_OBJ_CORE); if (nCores < 1) nCores = 1; - nOMPThreads = 1; + char *env = getenv("OMP_NUM_THREADS"); - if (env != nullptr) { + if (env != nullptr) nOMPThreads = strtoul(env, NULL, 10); - if (nOMPThreads == 0) - nOMPThreads = 1; - if (nOMPThreads >= nCores) - nOMPThreads = nCores; - } - std::string numOMPThreads = std::to_string(nOMPThreads); - setenv("OMP_NUM_THREADS", numOMPThreads.c_str(), 0); - nHPXThreads = nCores + 1 - nOMPThreads; + else + nOMPThreads = nCores; + env = getenv("DFR_NUM_THREADS"); if (env != nullptr) nHPXThreads = strtoul(env, NULL, 10); - nHPXThreads = (nHPXThreads) ? nHPXThreads : 1; - std::string numHPXThreads = std::to_string(nHPXThreads); - char *_argv[3] = {const_cast("__dummy_dfr_HPX_program_name__"), - const_cast("--hpx:threads"), - const_cast(numHPXThreads.c_str())}; - int _argc = 3; - hpx::start(nullptr, _argc, _argv); + else + nHPXThreads = nCores + 1 - nOMPThreads; + if (nHPXThreads < 1) + nHPXThreads = 1; + + env = getenv("HPX_CONFIG_FILE"); + if (env != nullptr) { + int _argc = 3; + char *_argv[3] = {const_cast("__dummy_dfr_HPX_program_name__"), + const_cast("--hpx:config"), + const_cast(env)}; + hpx::start(nullptr, _argc, _argv); + } else { + std::string numHPXThreads = std::to_string(nHPXThreads); + int _argc = 7; + char *_argv[7] = { + const_cast("__dummy_dfr_HPX_program_name__"), + const_cast("--hpx:threads"), + const_cast(numHPXThreads.c_str()), + const_cast("--hpx:ini=hpx.stacks.small_size=0x8000000"), + const_cast("--hpx:ini=hpx.stacks.medium_size=0x10000000"), + const_cast("--hpx:ini=hpx.stacks.large_size=0x20000000"), + const_cast("--hpx:ini=hpx.stacks.huge_size=0x40000000")}; + hpx::start(nullptr, _argc, _argv); + } } else { hpx::start(nullptr, argc, argv); } - // Instantiate on each node + // Instantiate and initialise on each node + mlir::concretelang::dfr::is_root_node_p = + (hpx::find_here() == hpx::find_root_locality()); new mlir::concretelang::dfr::KeyManager(); new mlir::concretelang::dfr::KeyManager(); new mlir::concretelang::dfr::WorkFunctionRegistry();