fix(dfr): simplify initialization and add namespaces where possible.

This commit is contained in:
Antoniu Pop
2022-04-07 13:20:52 +01:00
committed by Antoniu Pop
parent c79cedf557
commit 615109d432
8 changed files with 330 additions and 196 deletions

View File

@@ -14,11 +14,14 @@
#include "concretelang/Runtime/runtime_api.h"
bool _dfr_set_required(bool);
namespace mlir {
namespace concretelang {
namespace dfr {
void _dfr_set_required(bool);
void _dfr_set_jit(bool);
bool _dfr_is_jit();
bool _dfr_is_root_node();
void _dfr_terminate();
typedef enum _dfr_task_arg_type {
_DFR_TASK_ARG_BASE = 0,
@@ -42,4 +45,7 @@ static inline uint64_t _dfr_set_memref_element_size(uint64_t val, size_t size) {
return (val & 0xFF) | (((uint64_t)size) << 8);
}
} // namespace dfr
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -35,13 +35,14 @@
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Runtime/workfunction_registry.hpp"
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
extern std::list<void *> new_allocated;
using namespace hpx::naming;
using namespace hpx::components;
using namespace hpx::collectives;
namespace mlir {
namespace concretelang {
namespace dfr {
static inline size_t _dfr_get_memref_rank(size_t size) {
return (size - 2 * sizeof(char *) /*allocated_ptr & aligned_ptr*/
- sizeof(int64_t) /*offset*/) /
@@ -619,15 +620,26 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);
};
HPX_REGISTER_ACTION_DECLARATION(GenericComputeServer::execute_task_action,
GenericComputeServer_execute_task_action)
} // namespace dfr
} // namespace concretelang
} // namespace mlir
HPX_REGISTER_ACTION_DECLARATION(
mlir::concretelang::dfr::GenericComputeServer::execute_task_action,
GenericComputeServer_execute_task_action)
HPX_REGISTER_COMPONENT_MODULE()
HPX_REGISTER_COMPONENT(hpx::components::component<GenericComputeServer>,
GenericComputeServer)
HPX_REGISTER_COMPONENT(
hpx::components::component<mlir::concretelang::dfr::GenericComputeServer>,
GenericComputeServer)
HPX_REGISTER_ACTION(GenericComputeServer::execute_task_action,
GenericComputeServer_execute_task_action)
HPX_REGISTER_ACTION(
mlir::concretelang::dfr::GenericComputeServer::execute_task_action,
GenericComputeServer_execute_task_action)
namespace mlir {
namespace concretelang {
namespace dfr {
struct GenericComputeClient
: client_base<GenericComputeClient, GenericComputeServer> {
@@ -642,4 +654,7 @@ struct GenericComputeClient
}
};
} // namespace dfr
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -20,11 +20,20 @@ extern "C" {
#include "concrete-ffi.h"
}
extern std::list<void *> new_allocated;
namespace mlir {
namespace concretelang {
namespace dfr {
template <typename T> struct KeyManager;
extern KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
extern KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
namespace {
static void *dl_handle;
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
static std::list<void *> new_allocated;
static std::list<void *> fut_allocated;
static std::list<void *> m_allocated;
} // namespace
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id);
@@ -84,11 +93,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
key = deserialize_lwe_keyswitching_key_u64(buffer);
}
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t);
HPX_PLAIN_ACTION(_dfr_fetch_ksk, _dfr_fetch_ksk_action)
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t);
HPX_PLAIN_ACTION(_dfr_fetch_bsk, _dfr_fetch_bsk_action)
template <typename LweKeyType> struct KeyManager {
KeyManager() {}
LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id);
@@ -131,6 +135,25 @@ private:
std::map<uint64_t, KeyWrapper<LweKeyType>> keystore;
};
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
return _dfr_node_level_bsk_manager->fetch_key(key_id);
}
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
return _dfr_node_level_ksk_manager->fetch_key(key_id);
}
} // namespace dfr
} // namespace concretelang
} // namespace mlir
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_ksk, _dfr_fetch_ksk_action)
HPX_PLAIN_ACTION(mlir::concretelang::dfr::_dfr_fetch_bsk, _dfr_fetch_bsk_action)
namespace mlir {
namespace concretelang {
namespace dfr {
template <> KeyManager<LweBootstrapKey_u64>::KeyManager() {
_dfr_node_level_bsk_manager = this;
}
@@ -183,14 +206,6 @@ KeyManager<LweKeyswitchKey_u64>::get_key(hpx::naming::id_type loc,
return keyit->second.key;
}
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
return _dfr_node_level_bsk_manager->fetch_key(key_id);
}
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
return _dfr_node_level_ksk_manager->fetch_key(key_id);
}
/************************/
/* Key management API. */
/************************/
@@ -209,4 +224,7 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
}
} // namespace dfr
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -16,9 +16,14 @@
#include "concretelang/Runtime/DFRuntime.hpp"
extern void *dl_handle;
namespace mlir {
namespace concretelang {
namespace dfr {
struct WorkFunctionRegistry;
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
namespace {
static WorkFunctionRegistry *_dfr_node_level_work_function_registry;
}
struct WorkFunctionRegistry {
WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; }
@@ -89,4 +94,7 @@ private:
std::map<std::string, const void *> name_to_ptr_registry;
};
} // namespace dfr
} // namespace concretelang
} // namespace mlir
#endif

View File

@@ -8,7 +8,7 @@
#include "concretelang-c/Support/CompilerEngine.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/ClientLib/Serializers.h"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/JITSupport.h"
#include "concretelang/Support/Jit.h"

View File

@@ -158,8 +158,9 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
// elementAttr |= _DFR_TASK_ARG_MEMREF;
uint64_t elementAttr = 0;
size_t element_size = dataLayout.getTypeSize(elementType);
elementAttr = _dfr_set_arg_type(elementAttr, _DFR_TASK_ARG_MEMREF);
elementAttr = _dfr_set_memref_element_size(elementAttr, element_size);
elementAttr =
dfr::_dfr_set_arg_type(elementAttr, dfr::_DFR_TASK_ARG_MEMREF);
elementAttr = dfr::_dfr_set_memref_element_size(elementAttr, element_size);
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(elementAttr));
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
@@ -169,14 +170,14 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
// bytes.
if (type.isa<mlir::UnrankedMemRefType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_UNRANKED_MEMREF));
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_UNRANKED_MEMREF));
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_BASE));
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_BASE));
// FHE types are converted to pointers, so we take their size as 8
// bytes until we can get the actual size of the actual types.
@@ -191,7 +192,7 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
} else if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_CONTEXT));
loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT));
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
@@ -364,7 +365,7 @@ struct LowerDataflowTasksPass
// we do not need to do any computation, only register all work
// functions with the runtime system
if (!outliningMap.empty()) {
if (!_dfr_is_root_node()) {
if (!dfr::_dfr_is_root_node()) {
// auto regFunc = builder.create<FuncOp>(func.getLoc(),
// func.getName(), func.getType());

View File

@@ -21,27 +21,25 @@
#include "concretelang/Runtime/distributed_generic_task_server.hpp"
#include "concretelang/Runtime/runtime_api.h"
std::vector<GenericComputeClient> gcc;
void *dl_handle;
KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
WorkFunctionRegistry *_dfr_node_level_work_function_registry;
std::list<void *> new_allocated;
std::list<void *> fut_allocated;
std::list<void *> m_allocated;
hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier;
hpx::lcos::barrier *_dfr_jit_phase_barrier;
std::atomic<uint64_t> init_guard = {0};
namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static std::vector<GenericComputeClient> gcc;
static hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier;
static hpx::lcos::barrier *_dfr_jit_phase_barrier;
} // namespace
} // namespace dfr
} // namespace concretelang
} // namespace mlir
using namespace hpx;
void *_dfr_make_ready_future(void *in) {
void *future = static_cast<void *>(
new hpx::shared_future<void *>(hpx::make_ready_future(in)));
m_allocated.push_back(in);
fut_allocated.push_back(future);
mlir::concretelang::dfr::m_allocated.push_back(in);
mlir::concretelang::dfr::fut_allocated.push_back(future);
return future;
}
@@ -98,8 +96,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
va_end(args);
for (size_t i = 0; i < num_params; ++i) {
if (_dfr_get_arg_type(param_types[i] == _DFR_TASK_ARG_MEMREF)) {
m_allocated.push_back(
if (mlir::concretelang::dfr::_dfr_get_arg_type(
param_types[i] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF)) {
mlir::concretelang::dfr::m_allocated.push_back(
(void *)static_cast<StridedMemRefType<char, 1> *>(params[i])->data);
}
}
@@ -108,9 +107,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
// shared memory as pointers suffice, but is needed in the
// distributed case where the functions need to be located/loaded on
// the node.
auto wfnname =
_dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn);
hpx::future<hpx::future<OpaqueOutputData>> oodf;
auto wfnname = mlir::concretelang::dfr::_dfr_node_level_work_function_registry
->getWorkFunctionName((void *)wfn);
hpx::future<hpx::future<mlir::concretelang::dfr::OpaqueOutputData>> oodf;
// In order to allow complete dataflow semantics for
// communication/synchronization, we split tasks in two parts: an
@@ -120,13 +119,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
// individual synchronization for each return independently.
switch (num_params) {
case 0:
oodf = std::move(
hpx::dataflow([wfnname, param_sizes, param_types, output_sizes,
output_types]() -> hpx::future<OpaqueOutputData> {
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes, output_types]()
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
}));
break;
@@ -134,11 +136,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0]));
break;
@@ -148,11 +153,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1]));
@@ -164,12 +172,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -183,12 +194,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -204,13 +218,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get(),
param4.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -228,13 +245,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get(), param3.get(),
param4.get(), param5.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -254,13 +274,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -282,13 +305,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -312,14 +338,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(),
param3.get(), param4.get(), param5.get(),
param6.get(), param7.get(), param8.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -345,14 +374,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -380,14 +412,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -417,14 +452,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -456,15 +494,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -498,15 +539,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -542,15 +586,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -588,15 +635,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15)
-> hpx::future<OpaqueOutputData> {
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get(), param15.get()};
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
@@ -624,16 +674,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
switch (num_outputs) {
case 1:
*((void **)outputs[0]) = new hpx::shared_future<void *>(hpx::dataflow(
[](hpx::future<OpaqueOutputData> oodf_in) -> void * {
return oodf_in.get().outputs[0];
},
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> void * { return oodf_in.get().outputs[0]; },
oodf));
fut_allocated.push_back(*((void **)outputs[0]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
break;
case 2: {
hpx::future<hpx::tuple<void *, void *>> &&ft = hpx::dataflow(
[](hpx::future<OpaqueOutputData> oodf_in)
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> hpx::tuple<void *, void *> {
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
return hpx::make_tuple<>(outputs[0], outputs[1]);
@@ -645,14 +694,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
(void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
*((void **)outputs[1]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
fut_allocated.push_back(*((void **)outputs[0]));
fut_allocated.push_back(*((void **)outputs[1]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
break;
}
case 3: {
hpx::future<hpx::tuple<void *, void *, void *>> &&ft = hpx::dataflow(
[](hpx::future<OpaqueOutputData> oodf_in)
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> hpx::tuple<void *, void *, void *> {
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]);
@@ -666,9 +715,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
*((void **)outputs[2]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<2>(tf)));
fut_allocated.push_back(*((void **)outputs[0]));
fut_allocated.push_back(*((void **)outputs[1]));
fut_allocated.push_back(*((void **)outputs[2]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[2]));
break;
}
default:
@@ -681,45 +730,57 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
/* JIT execution support. */
/***************************/
void _dfr_try_initialize();
namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static bool dfr_required_p = false;
static bool is_jit_p = false;
static bool is_root_node_p = true;
} // namespace
bool _dfr_set_required(bool is_required) {
dfr_required_p = is_required;
if (dfr_required_p)
void _dfr_set_required(bool is_required) {
mlir::concretelang::dfr::dfr_required_p = is_required;
if (mlir::concretelang::dfr::dfr_required_p) {
_dfr_try_initialize();
return true;
mlir::concretelang::dfr::is_root_node_p =
(hpx::find_here() == hpx::find_root_locality());
}
}
void _dfr_set_jit(bool is_jit) { is_jit_p = is_jit; }
bool _dfr_is_jit() { return is_jit_p; }
static inline bool _dfr_is_root_node_impl() {
static bool is_root_node_p =
(!dfr_required_p || (hpx::find_here() == hpx::find_root_locality()));
return is_root_node_p;
}
bool _dfr_is_root_node() { return _dfr_is_root_node_impl(); }
void _dfr_set_jit(bool is_jit) { mlir::concretelang::dfr::is_jit_p = is_jit; }
bool _dfr_is_jit() { return mlir::concretelang::dfr::is_jit_p; }
bool _dfr_is_root_node() { return mlir::concretelang::dfr::is_root_node_p; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir
void _dfr_register_work_function(wfnptr wfn) {
_dfr_node_level_work_function_registry->registerAnonymousWorkFunction(
(void *)wfn);
mlir::concretelang::dfr::_dfr_node_level_work_function_registry
->registerAnonymousWorkFunction((void *)wfn);
}
/************************************/
/* Initialization & Finalization. */
/************************************/
namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static std::atomic<uint64_t> init_guard = {0};
}
} // namespace dfr
} // namespace concretelang
} // namespace mlir
static inline void _dfr_stop_impl() {
if (_dfr_is_root_node())
if (mlir::concretelang::dfr::_dfr_is_root_node())
hpx::apply([]() { hpx::finalize(); });
hpx::stop();
if (!_dfr_is_root_node())
if (!mlir::concretelang::dfr::_dfr_is_root_node())
exit(EXIT_SUCCESS);
}
static inline void _dfr_start_impl(int argc, char *argv[]) {
dl_handle = dlopen(nullptr, RTLD_NOW);
mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW);
if (argc == 0) {
unsigned long nCores, nOMPThreads, nHPXThreads;
hwloc_topology_t topology;
@@ -758,23 +819,25 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
}
// Instantiate on each node
new KeyManager<LweBootstrapKey_u64>();
new KeyManager<LweKeyswitchKey_u64>();
new WorkFunctionRegistry();
_dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier(
"wait_register_remote_work_functions", hpx::get_num_localities().get(),
hpx::get_locality_id());
_dfr_jit_phase_barrier = new hpx::lcos::barrier(
new mlir::concretelang::dfr::KeyManager<LweBootstrapKey_u64>();
new mlir::concretelang::dfr::KeyManager<LweKeyswitchKey_u64>();
new mlir::concretelang::dfr::WorkFunctionRegistry();
mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier =
new hpx::lcos::barrier("wait_register_remote_work_functions",
hpx::get_num_localities().get(),
hpx::get_locality_id());
mlir::concretelang::dfr::_dfr_jit_phase_barrier = new hpx::lcos::barrier(
"phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id());
if (_dfr_is_root_node()) {
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
// Create compute server components on each node - from the root
// node only - and the corresponding compute client on the root
// node.
auto num_nodes = hpx::get_num_localities().get();
gcc = hpx::new_<GenericComputeClient[]>(
hpx::default_layout(hpx::find_all_localities()), num_nodes)
.get();
mlir::concretelang::dfr::gcc =
hpx::new_<mlir::concretelang::dfr::GenericComputeClient[]>(
hpx::default_layout(hpx::find_all_localities()), num_nodes)
.get();
}
}
@@ -788,23 +851,27 @@ void _dfr_start() {
uint64_t uninitialised = 0;
uint64_t active = 1;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(uninitialised, active))
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(uninitialised,
active))
_dfr_start_impl(0, nullptr);
else if (init_guard.compare_exchange_strong(suspended, active))
else if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(
suspended, active))
hpx::resume();
// If this is not the root node in a non-JIT execution, then this
// node should only run the scheduler for any incoming work until
// termination is flagged. If this is JIT, we need to run the
// cancelled function which registers the work functions.
if (!_dfr_is_root_node() && !_dfr_is_jit())
if (!mlir::concretelang::dfr::_dfr_is_root_node() &&
!mlir::concretelang::dfr::_dfr_is_jit())
_dfr_stop_impl();
// TODO: conditional -- If this is the root node, and this is JIT
// execution, we need to wait for the compute nodes to compile and
// register work functions
if (_dfr_is_root_node() && _dfr_is_jit()) {
_dfr_jit_workfunction_registration_barrier->wait();
if (mlir::concretelang::dfr::_dfr_is_root_node() &&
mlir::concretelang::dfr::_dfr_is_jit()) {
mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier->wait();
}
}
@@ -817,8 +884,9 @@ void _dfr_stop() {
// where the root is free to send work out.
// TODO: optimize this by moving synchro to local remote nodes
// waiting in the scheduler for registration.
if (!_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) {
_dfr_jit_workfunction_registration_barrier->wait();
if (!mlir::concretelang::dfr::
_dfr_is_root_node() /*&& _dfr_is_jit() /** implicitly true*/) {
mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier->wait();
}
// The barrier is only needed to synchronize the different
@@ -830,8 +898,8 @@ void _dfr_stop() {
// gain as the root node would be waiting for the end of computation
// on all remote nodes before reaching here anyway (dataflow
// dependences).
if (_dfr_is_jit()) {
_dfr_jit_phase_barrier->wait();
if (mlir::concretelang::dfr::_dfr_is_jit()) {
mlir::concretelang::dfr::_dfr_jit_phase_barrier->wait();
}
// TODO: this can be removed along with the matching hpx::resume if
@@ -839,25 +907,28 @@ void _dfr_stop() {
// threads outside of parallel regions - to be tested.
uint64_t active = 1;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(active, suspended))
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(active,
suspended))
hpx::suspend();
// TODO: until we have better unique identifiers for keys it is
// safer to drop them in-between phases.
_dfr_node_level_bsk_manager->clear_keys();
_dfr_node_level_ksk_manager->clear_keys();
mlir::concretelang::dfr::_dfr_node_level_bsk_manager->clear_keys();
mlir::concretelang::dfr::_dfr_node_level_ksk_manager->clear_keys();
while (!new_allocated.empty()) {
delete[] static_cast<char *>(new_allocated.front());
new_allocated.pop_front();
while (!mlir::concretelang::dfr::new_allocated.empty()) {
delete[] static_cast<char *>(
mlir::concretelang::dfr::new_allocated.front());
mlir::concretelang::dfr::new_allocated.pop_front();
}
while (!fut_allocated.empty()) {
delete static_cast<hpx::shared_future<void *> *>(fut_allocated.front());
fut_allocated.pop_front();
while (!mlir::concretelang::dfr::fut_allocated.empty()) {
delete static_cast<hpx::shared_future<void *> *>(
mlir::concretelang::dfr::fut_allocated.front());
mlir::concretelang::dfr::fut_allocated.pop_front();
}
while (!m_allocated.empty()) {
free(m_allocated.front());
m_allocated.pop_front();
while (!mlir::concretelang::dfr::m_allocated.empty()) {
free(mlir::concretelang::dfr::m_allocated.front());
mlir::concretelang::dfr::m_allocated.pop_front();
}
}
@@ -865,7 +936,8 @@ void _dfr_try_initialize() {
// Initialize and immediately suspend the HPX runtime if not yet done.
uint64_t uninitialised = 0;
uint64_t suspended = 2;
if (init_guard.compare_exchange_strong(uninitialised, suspended)) {
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(uninitialised,
suspended)) {
_dfr_start_impl(0, nullptr);
hpx::suspend();
}
@@ -875,9 +947,11 @@ void _dfr_terminate() {
uint64_t active = 1;
uint64_t suspended = 2;
uint64_t terminated = 3;
if (init_guard.compare_exchange_strong(suspended, active))
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(suspended,
active))
hpx::resume();
if (init_guard.compare_exchange_strong(active, terminated))
if (mlir::concretelang::dfr::init_guard.compare_exchange_strong(active,
terminated))
_dfr_stop_impl();
}
@@ -925,8 +999,20 @@ void _dfr_print_debug(size_t val) {
#include "concretelang/Runtime/DFRuntime.hpp"
bool _dfr_set_required(bool is_required) { return !is_required; }
void _dfr_set_jit(bool) {}
namespace mlir {
namespace concretelang {
namespace dfr {
namespace {
static bool is_jit_p = false;
} // namespace
void _dfr_set_required(bool is_required) {}
void _dfr_set_jit(bool p) { is_jit_p = p; }
bool _dfr_is_jit() { return is_jit_p; }
bool _dfr_is_root_node() { return true; }
} // namespace dfr
} // namespace concretelang
} // namespace mlir
void _dfr_terminate() {}
#endif

View File

@@ -59,7 +59,7 @@ func.func @main(%arg0: !FHE.eint<7>, %arg1: !FHE.eint<7>, %arg2: !FHE.eint<7>, %
)XXX",
"main", false, true, false);
if (_dfr_is_root_node()) {
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
llvm::Expected<uint64_t> res_1 = lambda(1_u64, 2_u64, 3_u64, 4_u64);
llvm::Expected<uint64_t> res_2 = lambda(4_u64, 5_u64, 6_u64, 7_u64);
llvm::Expected<uint64_t> res_3 = lambda(1_u64, 1_u64, 1_u64, 1_u64);
@@ -112,7 +112,7 @@ TEST(ParallelizeAndRunFHE, nn_small_parallel) {
mlir::concretelang::IntLambdaArgument<uint8_t>>
arg(input, shape2D);
if (_dfr_is_root_node()) {
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg});
ASSERT_EXPECTED_SUCCESS(res);
@@ -154,7 +154,7 @@ TEST(ParallelizeAndRunFHE, nn_small_sequential) {
arg(input, shape2D);
// This is sequential: only execute on root node.
if (_dfr_is_root_node()) {
if (mlir::concretelang::dfr::_dfr_is_root_node()) {
llvm::Expected<std::vector<uint64_t>> res =
lambda.operator()<std::vector<uint64_t>>({&arg});
ASSERT_EXPECTED_SUCCESS(res);