diff --git a/compiler/include/concretelang/Runtime/DFRuntime.hpp b/compiler/include/concretelang/Runtime/DFRuntime.hpp index 9f97c2523..ea9fd3e7b 100644 --- a/compiler/include/concretelang/Runtime/DFRuntime.hpp +++ b/compiler/include/concretelang/Runtime/DFRuntime.hpp @@ -14,11 +14,14 @@ #include "concretelang/Runtime/runtime_api.h" -bool _dfr_set_required(bool); +namespace mlir { +namespace concretelang { +namespace dfr { + +void _dfr_set_required(bool); void _dfr_set_jit(bool); bool _dfr_is_jit(); bool _dfr_is_root_node(); -void _dfr_terminate(); typedef enum _dfr_task_arg_type { _DFR_TASK_ARG_BASE = 0, @@ -42,4 +45,7 @@ static inline uint64_t _dfr_set_memref_element_size(uint64_t val, size_t size) { return (val & 0xFF) | (((uint64_t)size) << 8); } +} // namespace dfr +} // namespace concretelang +} // namespace mlir #endif diff --git a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp index 0ece268e3..8908e3a18 100644 --- a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp +++ b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp @@ -35,13 +35,14 @@ #include "concretelang/Runtime/runtime_api.h" #include "concretelang/Runtime/workfunction_registry.hpp" -extern WorkFunctionRegistry *_dfr_node_level_work_function_registry; -extern std::list new_allocated; - using namespace hpx::naming; using namespace hpx::components; using namespace hpx::collectives; +namespace mlir { +namespace concretelang { +namespace dfr { + static inline size_t _dfr_get_memref_rank(size_t size) { return (size - 2 * sizeof(char *) /*allocated_ptr & aligned_ptr*/ - sizeof(int64_t) /*offset*/) / @@ -619,15 +620,26 @@ struct GenericComputeServer : component_base { HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task); }; -HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action, - GenericComputeServer_execute_task_action) +} // namespace dfr +} // namespace concretelang +} // namespace mlir + +HPX_REGISTER_ACTION_DECLARATION( + mlir::concretelang::dfr::GenericComputeServer::execute_task_action, + GenericComputeServer_execute_task_action) HPX_REGISTER_COMPONENT_MODULE() -HPX_REGISTER_COMPONENT(hpx::components::component, - GenericComputeServer) +HPX_REGISTER_COMPONENT( + hpx::components::component, + GenericComputeServer) -HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action, - GenericComputeServer_execute_task_action) +HPX_REGISTER_ACTION( + mlir::concretelang::dfr::GenericComputeServer::execute_task_action, + GenericComputeServer_execute_task_action) + +namespace mlir { +namespace concretelang { +namespace dfr { struct GenericComputeClient : client_base { @@ -642,4 +654,7 @@ struct GenericComputeClient } }; +} // namespace dfr +} // namespace concretelang +} // namespace mlir #endif diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index 3ffb43816..86d4942fb 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -20,11 +20,20 @@ extern "C" { #include "concrete-ffi.h" } -extern std::list new_allocated; +namespace mlir { +namespace concretelang { +namespace dfr { template struct KeyManager; -extern KeyManager *_dfr_node_level_bsk_manager; -extern KeyManager *_dfr_node_level_ksk_manager; +namespace { +static void *dl_handle; +static KeyManager *_dfr_node_level_bsk_manager; +static KeyManager *_dfr_node_level_ksk_manager; +static std::list new_allocated; +static std::list fut_allocated; +static std::list m_allocated; +} // namespace + void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id); void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id); @@ -84,11 +93,6 @@ void KeyWrapper::load(Archive &ar, key = deserialize_lwe_keyswitching_key_u64(buffer); } -KeyWrapper _dfr_fetch_ksk(uint64_t); -HPX_PLAIN_ACTION(_dfr_fetch_ksk, _dfr_fetch_ksk_action) -KeyWrapper _dfr_fetch_bsk(uint64_t); -HPX_PLAIN_ACTION(_dfr_fetch_bsk, _dfr_fetch_bsk_action) - template struct KeyManager { KeyManager() {} LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id); @@ -131,6 +135,25 @@ private: std::map> keystore; }; +KeyWrapper _dfr_fetch_bsk(uint64_t key_id) { + return _dfr_node_level_bsk_manager->fetch_key(key_id); +} + +KeyWrapper _dfr_fetch_ksk(uint64_t key_id) { + return _dfr_node_level_ksk_manager->fetch_key(key_id); +} + +} // namespace dfr +} // namespace concretelang +} // namespace mlir + +HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_ksk, _dfr_fetch_ksk_action) +HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_bsk, _dfr_fetch_bsk_action) + +namespace mlir { +namespace concretelang { +namespace dfr { + template <> KeyManager::KeyManager() { _dfr_node_level_bsk_manager = this; } @@ -183,14 +206,6 @@ KeyManager::get_key(hpx::naming::id_type loc, return keyit->second.key; } -KeyWrapper _dfr_fetch_bsk(uint64_t key_id) { - return _dfr_node_level_bsk_manager->fetch_key(key_id); -} - -KeyWrapper _dfr_fetch_ksk(uint64_t key_id) { - return _dfr_node_level_ksk_manager->fetch_key(key_id); -} - /************************/ /* Key management API. */ /************************/ @@ -209,4 +224,7 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) { return _dfr_node_level_ksk_manager->get_key(loc, key_id); } +} // namespace dfr +} // namespace concretelang +} // namespace mlir #endif diff --git a/compiler/include/concretelang/Runtime/workfunction_registry.hpp b/compiler/include/concretelang/Runtime/workfunction_registry.hpp index 2435b8e5e..5ddfb504d 100644 --- a/compiler/include/concretelang/Runtime/workfunction_registry.hpp +++ b/compiler/include/concretelang/Runtime/workfunction_registry.hpp @@ -16,9 +16,14 @@ #include "concretelang/Runtime/DFRuntime.hpp" -extern void *dl_handle; +namespace mlir { +namespace concretelang { +namespace dfr { + struct WorkFunctionRegistry; -extern WorkFunctionRegistry *_dfr_node_level_work_function_registry; +namespace { +static WorkFunctionRegistry *_dfr_node_level_work_function_registry; +} struct WorkFunctionRegistry { WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; } @@ -89,4 +94,7 @@ private: std::map name_to_ptr_registry; }; +} // namespace dfr +} // namespace concretelang +} // namespace mlir #endif diff --git a/compiler/lib/CAPI/Support/CompilerEngine.cpp b/compiler/lib/CAPI/Support/CompilerEngine.cpp index 5496e59b2..bd21ee53b 100644 --- a/compiler/lib/CAPI/Support/CompilerEngine.cpp +++ b/compiler/lib/CAPI/Support/CompilerEngine.cpp @@ -8,7 +8,7 @@ #include "concretelang-c/Support/CompilerEngine.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/ClientLib/Serializers.h" -#include "concretelang/Runtime/runtime_api.h" +#include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Support/CompilerEngine.h" #include "concretelang/Support/JITSupport.h" #include "concretelang/Support/Jit.h" diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index f0becbc9a..1886f18a4 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -158,8 +158,9 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) { // elementAttr |= _DFR_TASK_ARG_MEMREF; uint64_t elementAttr = 0; size_t element_size = dataLayout.getTypeSize(elementType); - elementAttr = _dfr_set_arg_type(elementAttr, _DFR_TASK_ARG_MEMREF); - elementAttr = _dfr_set_memref_element_size(elementAttr, element_size); + elementAttr = + dfr::_dfr_set_arg_type(elementAttr, dfr::_DFR_TASK_ARG_MEMREF); + elementAttr = dfr::_dfr_set_memref_element_size(elementAttr, element_size); Value arg_type = builder.create( loc, builder.getI64IntegerAttr(elementAttr)); return std::pair(typeSize, arg_type); @@ -169,14 +170,14 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) { // bytes. if (type.isa()) { Value arg_type = builder.create( - loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_UNRANKED_MEMREF)); + loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_UNRANKED_MEMREF)); Value result = builder.create(loc, builder.getI64IntegerAttr(16)); return std::pair(result, arg_type); } Value arg_type = builder.create( - loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_BASE)); + loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_BASE)); // FHE types are converted to pointers, so we take their size as 8 // bytes until we can get the actual size of the actual types. @@ -191,7 +192,7 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) { return std::pair(result, arg_type); } else if (type.isa()) { Value arg_type = builder.create( - loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_CONTEXT)); + loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT)); Value result = builder.create(loc, builder.getI64IntegerAttr(8)); return std::pair(result, arg_type); @@ -364,7 +365,7 @@ struct LowerDataflowTasksPass // we do not need to do any computation, only register all work // functions with the runtime system if (!outliningMap.empty()) { - if (!_dfr_is_root_node()) { + if (!dfr::_dfr_is_root_node()) { // auto regFunc = builder.create(func.getLoc(), // func.getName(), func.getType()); diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index 3cedf77d5..133f63280 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -21,27 +21,25 @@ #include "concretelang/Runtime/distributed_generic_task_server.hpp" #include "concretelang/Runtime/runtime_api.h" -std::vector gcc; -void *dl_handle; - -KeyManager *_dfr_node_level_bsk_manager; -KeyManager *_dfr_node_level_ksk_manager; - -WorkFunctionRegistry *_dfr_node_level_work_function_registry; -std::list new_allocated; -std::list fut_allocated; -std::list m_allocated; -hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier; -hpx::lcos::barrier *_dfr_jit_phase_barrier; -std::atomic init_guard = {0}; +namespace mlir { +namespace concretelang { +namespace dfr { +namespace { +static std::vector gcc; +static hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier; +static hpx::lcos::barrier *_dfr_jit_phase_barrier; +} // namespace +} // namespace dfr +} // namespace concretelang +} // namespace mlir using namespace hpx; void *_dfr_make_ready_future(void *in) { void *future = static_cast( new hpx::shared_future(hpx::make_ready_future(in))); - m_allocated.push_back(in); - fut_allocated.push_back(future); + mlir::concretelang::dfr::m_allocated.push_back(in); + mlir::concretelang::dfr::fut_allocated.push_back(future); return future; } @@ -98,8 +96,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, va_end(args); for (size_t i = 0; i < num_params; ++i) { - if (_dfr_get_arg_type(param_types[i] == _DFR_TASK_ARG_MEMREF)) { - m_allocated.push_back( + if (mlir::concretelang::dfr::_dfr_get_arg_type( + param_types[i] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF)) { + mlir::concretelang::dfr::m_allocated.push_back( (void *)static_cast *>(params[i])->data); } } @@ -108,9 +107,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, // shared memory as pointers suffice, but is needed in the // distributed case where the functions need to be located/loaded on // the node. - auto wfnname = - _dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn); - hpx::future> oodf; + auto wfnname = mlir::concretelang::dfr::_dfr_node_level_work_function_registry + ->getWorkFunctionName((void *)wfn); + hpx::future> oodf; // In order to allow complete dataflow semantics for // communication/synchronization, we split tasks in two parts: an @@ -120,13 +119,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, // individual synchronization for each return independently. switch (num_params) { case 0: - oodf = std::move( - hpx::dataflow([wfnname, param_sizes, param_types, output_sizes, - output_types]() -> hpx::future { + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, param_types, output_sizes, output_types]() + -> hpx::future { std::vector params = {}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); })); break; @@ -134,11 +136,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types](hpx::shared_future param0) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0])); break; @@ -148,11 +153,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [wfnname, param_sizes, param_types, output_sizes, output_types](hpx::shared_future param0, hpx::shared_future param1) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get(), param1.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1])); @@ -164,12 +172,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, output_types](hpx::shared_future param0, hpx::shared_future param1, hpx::shared_future param2) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -183,12 +194,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param1, hpx::shared_future param2, hpx::shared_future param3) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -204,13 +218,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param2, hpx::shared_future param3, hpx::shared_future param4) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get(), param4.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -228,13 +245,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param3, hpx::shared_future param4, hpx::shared_future param5) - -> hpx::future { + -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -254,13 +274,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param4, hpx::shared_future param5, hpx::shared_future param6) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -282,13 +305,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param5, hpx::shared_future param6, hpx::shared_future param7) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -312,14 +338,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param6, hpx::shared_future param7, hpx::shared_future param8) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -345,14 +374,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param7, hpx::shared_future param8, hpx::shared_future param9) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -380,14 +412,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param8, hpx::shared_future param9, hpx::shared_future param10) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -417,14 +452,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param9, hpx::shared_future param10, hpx::shared_future param11) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get(), param11.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -456,15 +494,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param10, hpx::shared_future param11, hpx::shared_future param12) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get(), param11.get(), param12.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -498,15 +539,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param11, hpx::shared_future param12, hpx::shared_future param13) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get(), param11.get(), param12.get(), param13.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -542,15 +586,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param12, hpx::shared_future param13, hpx::shared_future param14) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get(), param11.get(), param12.get(), param13.get(), param14.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -588,15 +635,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, hpx::shared_future param13, hpx::shared_future param14, hpx::shared_future param15) - -> hpx::future { + -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get(), param8.get(), param9.get(), param10.get(), param11.get(), param12.get(), param13.get(), param14.get(), param15.get()}; - OpaqueInputData oid(wfnname, params, param_sizes, param_types, - output_sizes, output_types); - return gcc[_dfr_find_next_execution_locality()].execute_task(oid); + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); }, *(hpx::shared_future *)params[0], *(hpx::shared_future *)params[1], @@ -624,16 +674,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, switch (num_outputs) { case 1: *((void **)outputs[0]) = new hpx::shared_future(hpx::dataflow( - [](hpx::future oodf_in) -> void * { - return oodf_in.get().outputs[0]; - }, + [](hpx::future oodf_in) + -> void * { return oodf_in.get().outputs[0]; }, oodf)); - fut_allocated.push_back(*((void **)outputs[0])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); break; case 2: { hpx::future> &&ft = hpx::dataflow( - [](hpx::future oodf_in) + [](hpx::future oodf_in) -> hpx::tuple { std::vector outputs = std::move(oodf_in.get().outputs); return hpx::make_tuple<>(outputs[0], outputs[1]); @@ -645,14 +694,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, (void *)new hpx::shared_future(std::move(hpx::get<0>(tf))); *((void **)outputs[1]) = (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); - fut_allocated.push_back(*((void **)outputs[0])); - fut_allocated.push_back(*((void **)outputs[1])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1])); break; } case 3: { hpx::future> &&ft = hpx::dataflow( - [](hpx::future oodf_in) + [](hpx::future oodf_in) -> hpx::tuple { std::vector outputs = std::move(oodf_in.get().outputs); return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]); @@ -666,9 +715,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); *((void **)outputs[2]) = (void *)new hpx::shared_future(std::move(hpx::get<2>(tf))); - fut_allocated.push_back(*((void **)outputs[0])); - fut_allocated.push_back(*((void **)outputs[1])); - fut_allocated.push_back(*((void **)outputs[2])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1])); + mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[2])); break; } default: @@ -681,45 +730,57 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, /* JIT execution support. */ /***************************/ void _dfr_try_initialize(); +namespace mlir { +namespace concretelang { +namespace dfr { namespace { static bool dfr_required_p = false; static bool is_jit_p = false; +static bool is_root_node_p = true; } // namespace -bool _dfr_set_required(bool is_required) { - dfr_required_p = is_required; - if (dfr_required_p) + +void _dfr_set_required(bool is_required) { + mlir::concretelang::dfr::dfr_required_p = is_required; + if (mlir::concretelang::dfr::dfr_required_p) { _dfr_try_initialize(); - return true; + mlir::concretelang::dfr::is_root_node_p = + (hpx::find_here() == hpx::find_root_locality()); + } } -void _dfr_set_jit(bool is_jit) { is_jit_p = is_jit; } -bool _dfr_is_jit() { return is_jit_p; } - -static inline bool _dfr_is_root_node_impl() { - static bool is_root_node_p = - (!dfr_required_p || (hpx::find_here() == hpx::find_root_locality())); - return is_root_node_p; -} - -bool _dfr_is_root_node() { return _dfr_is_root_node_impl(); } +void _dfr_set_jit(bool is_jit) { mlir::concretelang::dfr::is_jit_p = is_jit; } +bool _dfr_is_jit() { return mlir::concretelang::dfr::is_jit_p; } +bool _dfr_is_root_node() { return mlir::concretelang::dfr::is_root_node_p; } +} // namespace dfr +} // namespace concretelang +} // namespace mlir void _dfr_register_work_function(wfnptr wfn) { - _dfr_node_level_work_function_registry->registerAnonymousWorkFunction( - (void *)wfn); + mlir::concretelang::dfr::_dfr_node_level_work_function_registry + ->registerAnonymousWorkFunction((void *)wfn); } /************************************/ /* Initialization & Finalization. */ /************************************/ +namespace mlir { +namespace concretelang { +namespace dfr { +namespace { +static std::atomic init_guard = {0}; +} +} // namespace dfr +} // namespace concretelang +} // namespace mlir static inline void _dfr_stop_impl() { - if (_dfr_is_root_node()) + if (mlir::concretelang::dfr::_dfr_is_root_node()) hpx::apply([]() { hpx::finalize(); }); hpx::stop(); - if (!_dfr_is_root_node()) + if (!mlir::concretelang::dfr::_dfr_is_root_node()) exit(EXIT_SUCCESS); } static inline void _dfr_start_impl(int argc, char *argv[]) { - dl_handle = dlopen(nullptr, RTLD_NOW); + mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW); if (argc == 0) { unsigned long nCores, nOMPThreads, nHPXThreads; hwloc_topology_t topology; @@ -758,23 +819,25 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { } // Instantiate on each node - new KeyManager(); - new KeyManager(); - new WorkFunctionRegistry(); - _dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier( - "wait_register_remote_work_functions", hpx::get_num_localities().get(), - hpx::get_locality_id()); - _dfr_jit_phase_barrier = new hpx::lcos::barrier( + new mlir::concretelang::dfr::KeyManager(); + new mlir::concretelang::dfr::KeyManager(); + new mlir::concretelang::dfr::WorkFunctionRegistry(); + mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier = + new hpx::lcos::barrier("wait_register_remote_work_functions", + hpx::get_num_localities().get(), + hpx::get_locality_id()); + mlir::concretelang::dfr::_dfr_jit_phase_barrier = new hpx::lcos::barrier( "phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id()); - if (_dfr_is_root_node()) { + if (mlir::concretelang::dfr::_dfr_is_root_node()) { // Create compute server components on each node - from the root // node only - and the corresponding compute client on the root // node. auto num_nodes = hpx::get_num_localities().get(); - gcc = hpx::new_( - hpx::default_layout(hpx::find_all_localities()), num_nodes) - .get(); + mlir::concretelang::dfr::gcc = + hpx::new_( + hpx::default_layout(hpx::find_all_localities()), num_nodes) + .get(); } } @@ -788,23 +851,27 @@ void _dfr_start() { uint64_t uninitialised = 0; uint64_t active = 1; uint64_t suspended = 2; - if (init_guard.compare_exchange_strong(uninitialised, active)) + if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(uninitialised, + active)) _dfr_start_impl(0, nullptr); - else if (init_guard.compare_exchange_strong(suspended, active)) + else if (mlir::concretelang::dfr::init_guard.compare_exchange_strong( + suspended, active)) hpx::resume(); // If this is not the root node in a non-JIT execution, then this // node should only run the scheduler for any incoming work until // termination is flagged. If this is JIT, we need to run the // cancelled function which registers the work functions. - if (!_dfr_is_root_node() && !_dfr_is_jit()) + if (!mlir::concretelang::dfr::_dfr_is_root_node() && + !mlir::concretelang::dfr::_dfr_is_jit()) _dfr_stop_impl(); // TODO: conditional -- If this is the root node, and this is JIT // execution, we need to wait for the compute nodes to compile and // register work functions - if (_dfr_is_root_node() && _dfr_is_jit()) { - _dfr_jit_workfunction_registration_barrier->wait(); + if (mlir::concretelang::dfr::_dfr_is_root_node() && + mlir::concretelang::dfr::_dfr_is_jit()) { + mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier->wait(); } } @@ -817,8 +884,9 @@ void _dfr_stop() { // where the root is free to send work out. // TODO: optimize this by moving synchro to local remote nodes // waiting in the scheduler for registration. - if (!_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) { - _dfr_jit_workfunction_registration_barrier->wait(); + if (!mlir::concretelang::dfr:: + _dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) { + mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier->wait(); } // The barrier is only needed to synchronize the different @@ -830,8 +898,8 @@ void _dfr_stop() { // gain as the root node would be waiting for the end of computation // on all remote nodes before reaching here anyway (dataflow // dependences). - if (_dfr_is_jit()) { - _dfr_jit_phase_barrier->wait(); + if (mlir::concretelang::dfr::_dfr_is_jit()) { + mlir::concretelang::dfr::_dfr_jit_phase_barrier->wait(); } // TODO: this can be removed along with the matching hpx::resume if @@ -839,25 +907,28 @@ void _dfr_stop() { // threads outside of parallel regions - to be tested. uint64_t active = 1; uint64_t suspended = 2; - if (init_guard.compare_exchange_strong(active, suspended)) + if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(active, + suspended)) hpx::suspend(); // TODO: until we have better unique identifiers for keys it is // safer to drop them in-between phases. - _dfr_node_level_bsk_manager->clear_keys(); - _dfr_node_level_ksk_manager->clear_keys(); + mlir::concretelang::dfr::_dfr_node_level_bsk_manager->clear_keys(); + mlir::concretelang::dfr::_dfr_node_level_ksk_manager->clear_keys(); - while (!new_allocated.empty()) { - delete[] static_cast(new_allocated.front()); - new_allocated.pop_front(); + while (!mlir::concretelang::dfr::new_allocated.empty()) { + delete[] static_cast( + mlir::concretelang::dfr::new_allocated.front()); + mlir::concretelang::dfr::new_allocated.pop_front(); } - while (!fut_allocated.empty()) { - delete static_cast *>(fut_allocated.front()); - fut_allocated.pop_front(); + while (!mlir::concretelang::dfr::fut_allocated.empty()) { + delete static_cast *>( + mlir::concretelang::dfr::fut_allocated.front()); + mlir::concretelang::dfr::fut_allocated.pop_front(); } - while (!m_allocated.empty()) { - free(m_allocated.front()); - m_allocated.pop_front(); + while (!mlir::concretelang::dfr::m_allocated.empty()) { + free(mlir::concretelang::dfr::m_allocated.front()); + mlir::concretelang::dfr::m_allocated.pop_front(); } } @@ -865,7 +936,8 @@ void _dfr_try_initialize() { // Initialize and immediately suspend the HPX runtime if not yet done. uint64_t uninitialised = 0; uint64_t suspended = 2; - if (init_guard.compare_exchange_strong(uninitialised, suspended)) { + if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(uninitialised, + suspended)) { _dfr_start_impl(0, nullptr); hpx::suspend(); } @@ -875,9 +947,11 @@ void _dfr_terminate() { uint64_t active = 1; uint64_t suspended = 2; uint64_t terminated = 3; - if (init_guard.compare_exchange_strong(suspended, active)) + if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(suspended, + active)) hpx::resume(); - if (init_guard.compare_exchange_strong(active, terminated)) + if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(active, + terminated)) _dfr_stop_impl(); } @@ -925,8 +999,20 @@ void _dfr_print_debug(size_t val) { #include "concretelang/Runtime/DFRuntime.hpp" -bool _dfr_set_required(bool is_required) { return !is_required; } -void _dfr_set_jit(bool) {} +namespace mlir { +namespace concretelang { +namespace dfr { +namespace { +static bool is_jit_p = false; +} // namespace + +void _dfr_set_required(bool is_required) {} +void _dfr_set_jit(bool p) { is_jit_p = p; } +bool _dfr_is_jit() { return is_jit_p; } bool _dfr_is_root_node() { return true; } +} // namespace dfr +} // namespace concretelang +} // namespace mlir + void _dfr_terminate() {} #endif diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc index c48a323e1..79e2eeac8 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_auto_parallelization.cc @@ -59,7 +59,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, % )XXX", "main", false, true, false); - if (_dfr_is_root_node()) { + if (mlir::concretelang::dfr::_dfr_is_root_node()) { llvm::Expected res_1 = lambda(1_u64, 2_u64, 3_u64, 4_u64); llvm::Expected res_2 = lambda(4_u64, 5_u64, 6_u64, 7_u64); llvm::Expected res_3 = lambda(1_u64, 1_u64, 1_u64, 1_u64); @@ -112,7 +112,7 @@ TEST(ParallelizeAndRunFHE, nn_small_parallel) { mlir::concretelang::IntLambdaArgument> arg(input, shape2D); - if (_dfr_is_root_node()) { + if (mlir::concretelang::dfr::_dfr_is_root_node()) { llvm::Expected> res = lambda.operator()>({&arg}); ASSERT_EXPECTED_SUCCESS(res); @@ -154,7 +154,7 @@ TEST(ParallelizeAndRunFHE, nn_small_sequential) { arg(input, shape2D); // This is sequential: only execute on root node. - if (_dfr_is_root_node()) { + if (mlir::concretelang::dfr::_dfr_is_root_node()) { llvm::Expected> res = lambda.operator()>({&arg}); ASSERT_EXPECTED_SUCCESS(res);