feat(runtime): enable distributed execution.

This commit is contained in:
Antoniu Pop
2022-03-15 14:46:47 +00:00
committed by Antoniu Pop
parent 2cc8c69ff3
commit 954b2098c6
15 changed files with 765 additions and 285 deletions

View File

@@ -19,6 +19,7 @@ extern "C" {
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include <concretelang/Runtime/DFRuntime.hpp>
namespace concretelang {
namespace clientlib {

View File

@@ -115,7 +115,13 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
let summary = "Create a dataflow task.";
}
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
let arguments = (ins Variadic<AnyType>:$list);
let results = (outs );
let summary = "Register the task work-function with the runtime system.";
}
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
let arguments = (ins RT_Future: $input);
let results = (outs );
}

View File

@@ -6,102 +6,39 @@
#ifndef CONCRETELANG_DFR_DFRUNTIME_HPP
#define CONCRETELANG_DFR_DFRUNTIME_HPP
#include <cassert>
#include <cstdint>
#include <dlfcn.h>
#include <memory>
#include <utility>
#include "concretelang/Runtime/runtime_api.h"
/* Debug interface. */
#include "concretelang/Runtime/dfr_debug_interface.h"
bool _dfr_is_root_node();
void _dfr_is_jit(bool);
bool _dfr_is_jit();
void _dfr_terminate();
extern void *dl_handle;
struct WorkFunctionRegistry;
extern WorkFunctionRegistry *node_level_work_function_registry;
/// Tag stored in the low byte of a 64-bit task argument/output "type" word.
/// The upper bits may carry additional encoded data (for memref arguments,
/// the element size - see _dfr_set_memref_element_size below).
typedef enum _dfr_task_arg_type {
_DFR_TASK_ARG_BASE = 0,            // plain value, copied verbatim
_DFR_TASK_ARG_MEMREF = 1,          // ranked memref descriptor (data deep-copied)
_DFR_TASK_ARG_UNRANKED_MEMREF = 2, // currently rejected by serialization
_DFR_TASK_ARG_CONTEXT = 3          // runtime context carrying evaluation keys
} _dfr_task_arg_type;
/// Recover the name of the work function
static inline const char *_dfr_get_function_name_from_address(void *fn) {
Dl_info info;
if (!dladdr(fn, &info) || info.dli_sname == nullptr)
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_name_from_address",
"Error recovering work function name from address.");
return info.dli_sname;
/// Decode the argument-kind tag held in the low byte of a task-argument
/// type word.
static inline _dfr_task_arg_type _dfr_get_arg_type(uint64_t val) {
  return static_cast<_dfr_task_arg_type>(val & uint64_t{0xFF});
}
static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) {
auto ptr = dlsym(dl_handle, fn_name);
if (ptr == nullptr)
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_pointer_from_name",
"Error recovering work function pointer from name.");
return (wfnptr)ptr;
/// Extract the memref element size (in bytes) encoded in the upper bits of
/// a task-argument type word; the low byte holds the argument-kind tag.
static inline uint64_t _dfr_get_memref_element_size(uint64_t val) {
  constexpr unsigned kArgTypeTagBits = 8;
  return val >> kArgTypeTagBits;
}
/// Determine where new task should run. For now just round-robin
/// distribution - TODO: optimise.
static inline size_t _dfr_find_next_execution_locality() {
static size_t num_nodes = hpx::get_num_localities().get();
static std::atomic<std::size_t> next_locality{0};
size_t next_loc = ++next_locality;
return next_loc % num_nodes;
/// Store an argument-kind tag in the low byte of a type word, preserving
/// whatever is already encoded in the upper bits.
static inline uint64_t _dfr_set_arg_type(uint64_t val,
                                         _dfr_task_arg_type type) {
  const uint64_t upper_bits = val & ~uint64_t{0xFF};
  return upper_bits | static_cast<uint64_t>(type);
}
static inline bool _dfr_is_root_node() {
return hpx::find_here() == hpx::find_root_locality();
/// Encode a memref element size (bytes) into the upper bits of a type word
/// while keeping the argument-kind tag stored in the low byte.  Sizes must
/// fit in 48 bits so the shifted value cannot overflow the 64-bit word.
static inline uint64_t _dfr_set_memref_element_size(uint64_t val, size_t size) {
  const uint64_t size_bits = static_cast<uint64_t>(size);
  assert(size_bits < (uint64_t{1} << 48));
  return (val & uint64_t{0xFF}) | (size_bits << 8);
}
struct WorkFunctionRegistry {
WorkFunctionRegistry() { node_level_work_function_registry = this; }
wfnptr getWorkFunctionPointer(const std::string &name) {
std::lock_guard<std::mutex> guard(registry_guard);
auto fnptrit = name_to_ptr_registry.find(name);
if (fnptrit != name_to_ptr_registry.end())
return (wfnptr)fnptrit->second;
auto ptr = dlsym(dl_handle, name.c_str());
if (ptr == nullptr)
HPX_THROW_EXCEPTION(hpx::no_success,
"WorkFunctionRegistry::getWorkFunctionPointer",
"Error recovering work function pointer from name.");
ptr_to_name_registry.insert(
std::pair<const void *, std::string>(ptr, name));
name_to_ptr_registry.insert(
std::pair<std::string, const void *>(name, ptr));
return (wfnptr)ptr;
}
std::string getWorkFunctionName(const void *fn) {
std::lock_guard<std::mutex> guard(registry_guard);
auto fnnameit = ptr_to_name_registry.find(fn);
if (fnnameit != ptr_to_name_registry.end())
return fnnameit->second;
Dl_info info;
std::string ret;
// Assume that if we can't find the name, there is no dynamic
// library to find it in. TODO: fix this to distinguish JIT/binary
// and in case of distributed exec.
if (!dladdr(fn, &info) || info.dli_sname == nullptr) {
static std::atomic<unsigned int> fnid{0};
ret = "_dfr_jit_wfnname_" + std::to_string(fnid++);
} else {
ret = info.dli_sname;
}
ptr_to_name_registry.insert(std::pair<const void *, std::string>(fn, ret));
name_to_ptr_registry.insert(std::pair<std::string, const void *>(ret, fn));
return ret;
}
private:
std::mutex registry_guard;
std::map<const void *, std::string> ptr_to_name_registry;
std::map<std::string, const void *> name_to_ptr_registry;
};
#endif

View File

@@ -12,7 +12,7 @@
extern "C" {
size_t _dfr_debug_get_node_id();
size_t _dfr_debug_get_worker_id();
void _dfr_debug_print_task(const char *name, int inputs, int outputs);
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs);
void _dfr_print_debug(size_t val);
}
#endif

View File

@@ -26,101 +26,250 @@
#include <hpx/include/runtime.hpp>
#include <hpx/modules/collectives.hpp>
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/key_manager.hpp"
#include <mlir/ExecutionEngine/CRunnerUtils.h>
extern WorkFunctionRegistry *node_level_work_function_registry;
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Runtime/dfr_debug_interface.h"
#include "concretelang/Runtime/key_manager.hpp"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Runtime/workfunction_registry.hpp"
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
extern std::list<void *> new_allocated;
using namespace hpx::naming;
using namespace hpx::components;
using namespace hpx::collectives;
/// Compute the rank of a ranked memref descriptor from its byte size.  The
/// descriptor layout is { allocated ptr, aligned ptr, int64 offset,
/// int64 sizes[rank], int64 strides[rank] }.
static inline size_t _dfr_get_memref_rank(size_t size) {
  const size_t header_bytes = 2 * sizeof(char *) /*allocated & aligned ptrs*/
                              + sizeof(int64_t) /*offset*/;
  const size_t per_dim_bytes = 2 * sizeof(int64_t) /*size & stride*/;
  return (size - header_bytes) / per_dim_bytes;
}
struct OpaqueInputData {
OpaqueInputData() = default;
OpaqueInputData(std::string wfn_name, std::vector<void *> params,
std::vector<size_t> param_sizes,
std::vector<size_t> output_sizes, bool alloc_p = false)
: wfn_name(wfn_name), params(std::move(params)),
param_sizes(std::move(param_sizes)),
output_sizes(std::move(output_sizes)), alloc_p(alloc_p) {}
OpaqueInputData(std::string _wfn_name, std::vector<void *> _params,
std::vector<size_t> _param_sizes,
std::vector<uint64_t> _param_types,
std::vector<size_t> _output_sizes,
std::vector<uint64_t> _output_types, bool _alloc_p = false)
: wfn_name(_wfn_name), params(std::move(_params)),
param_sizes(std::move(_param_sizes)),
param_types(std::move(_param_types)),
output_sizes(std::move(_output_sizes)),
output_types(std::move(_output_types)), alloc_p(_alloc_p),
source_locality(hpx::find_here()) {}
OpaqueInputData(const OpaqueInputData &oid)
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
param_sizes(std::move(oid.param_sizes)),
output_sizes(std::move(oid.output_sizes)), alloc_p(oid.alloc_p) {}
param_types(std::move(oid.param_types)),
output_sizes(std::move(oid.output_sizes)),
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
source_locality(oid.source_locality) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar &wfn_name;
ar &param_sizes;
ar &output_sizes;
for (auto p : param_sizes) {
char *param = new char[p];
// TODO: Optimise these serialisation operations
for (size_t i = 0; i < p; ++i)
ar &param[i];
ar >> wfn_name;
ar >> param_sizes >> param_types;
ar >> output_sizes >> output_types;
for (size_t p = 0; p < param_sizes.size(); ++p) {
char *param = new char[param_sizes[p]];
new_allocated.push_back((void *)param);
ar >> hpx::serialization::make_array(param, param_sizes[p]);
params.push_back((void *)param);
switch (_dfr_get_arg_type(param_types[p])) {
case _DFR_TASK_ARG_BASE:
break;
case _DFR_TASK_ARG_MEMREF: {
size_t rank = _dfr_get_memref_rank(param_sizes[p]);
UnrankedMemRefType<char> umref = {rank, params[p]};
DynamicMemRefType<char> mref(umref);
size_t elementSize = _dfr_get_memref_element_size(param_types[p]);
size_t size = 1;
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
uint64_t bsk_id, ksk_id;
ar >> bsk_id >> ksk_id >> source_locality;
mlir::concretelang::RuntimeContext *context =
new mlir::concretelang::RuntimeContext;
new_allocated.push_back((void *)context);
mlir::concretelang::RuntimeContext **_context =
new mlir::concretelang::RuntimeContext *[1];
new_allocated.push_back((void *)_context);
_context[0] = context;
context->bsk = (LweBootstrapKey_u64 *)bsk_id;
context->ksk = (LweKeyswitchKey_u64 *)ksk_id;
params[p] = (void *)_context;
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
}
}
alloc_p = true;
}
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
ar &wfn_name;
ar &param_sizes;
ar &output_sizes;
for (size_t p = 0; p < params.size(); ++p)
for (size_t i = 0; i < param_sizes[p]; ++i)
ar &static_cast<char *>(params[p])[i];
ar << wfn_name;
ar << param_sizes << param_types;
ar << output_sizes << output_types;
for (size_t p = 0; p < params.size(); ++p) {
// Save the first level of the data structure - if the parameter
// is a tensor/memref, there is a second level.
ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]);
switch (_dfr_get_arg_type(param_types[p])) {
case _DFR_TASK_ARG_BASE:
break;
case _DFR_TASK_ARG_MEMREF: {
size_t rank = _dfr_get_memref_rank(param_sizes[p]);
UnrankedMemRefType<char> umref = {rank, params[p]};
DynamicMemRefType<char> mref(umref);
size_t elementSize = _dfr_get_memref_element_size(param_types[p]);
size_t size = 1;
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
ar << hpx::serialization::make_array(
mref.data + mref.offset * elementSize, size * elementSize);
} break;
case _DFR_TASK_ARG_CONTEXT: {
mlir::concretelang::RuntimeContext *context =
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
LweKeyswitchKey_u64 *ksk = context->ksk;
LweBootstrapKey_u64 *bsk = context->bsk;
// TODO: find better unique identifiers. This is not a
// correctness issue, but performance.
uint64_t bsk_id = (uint64_t)bsk;
uint64_t ksk_id = (uint64_t)ksk;
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
_dfr_register_bsk(bsk, bsk_id);
_dfr_register_ksk(ksk, ksk_id);
ar << bsk_id << ksk_id << source_locality;
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
}
}
}
HPX_SERIALIZATION_SPLIT_MEMBER()
std::string wfn_name;
std::vector<void *> params;
std::vector<size_t> param_sizes;
std::vector<uint64_t> param_types;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
bool alloc_p = false;
hpx::naming::id_type source_locality;
};
struct OpaqueOutputData {
OpaqueOutputData() = default;
OpaqueOutputData(std::vector<void *> outputs,
std::vector<size_t> output_sizes, bool alloc_p = false)
std::vector<size_t> output_sizes,
std::vector<uint64_t> output_types, bool alloc_p = false)
: outputs(std::move(outputs)), output_sizes(std::move(output_sizes)),
alloc_p(alloc_p) {}
output_types(std::move(output_types)), alloc_p(alloc_p) {}
OpaqueOutputData(const OpaqueOutputData &ood)
: outputs(std::move(ood.outputs)),
output_sizes(std::move(ood.output_sizes)), alloc_p(ood.alloc_p) {}
output_sizes(std::move(ood.output_sizes)),
output_types(std::move(ood.output_types)), alloc_p(ood.alloc_p) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar &output_sizes;
for (auto p : output_sizes) {
char *output = new char[p];
for (size_t i = 0; i < p; ++i)
ar &output[i];
outputs.push_back((void *)output);
ar >> output_sizes >> output_types;
for (size_t p = 0; p < output_sizes.size(); ++p) {
char *output = new char[output_sizes[p]];
new_allocated.push_back((void *)output);
ar >> hpx::serialization::make_array(output, output_sizes[p]);
outputs.push_back((void *)output);
switch (_dfr_get_arg_type(output_types[p])) {
case _DFR_TASK_ARG_BASE:
break;
case _DFR_TASK_ARG_MEMREF: {
size_t rank = _dfr_get_memref_rank(output_sizes[p]);
UnrankedMemRefType<char> umref = {rank, outputs[p]};
DynamicMemRefType<char> mref(umref);
size_t elementSize = _dfr_get_memref_element_size(output_types[p]);
size_t size = 1;
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
nullptr;
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
}
}
alloc_p = true;
}
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
ar &output_sizes;
ar << output_sizes << output_types;
for (size_t p = 0; p < outputs.size(); ++p) {
for (size_t i = 0; i < output_sizes[p]; ++i)
ar &static_cast<char *>(outputs[p])[i];
// TODO: investigate if HPX is automatically deallocating
// these. Here it could be safely assumed that these would no
// longer be live.
// delete (char*)outputs[p];
ar << hpx::serialization::make_array((char *)outputs[p], output_sizes[p]);
switch (_dfr_get_arg_type(output_types[p])) {
case _DFR_TASK_ARG_BASE:
break;
case _DFR_TASK_ARG_MEMREF: {
size_t rank = _dfr_get_memref_rank(output_sizes[p]);
UnrankedMemRefType<char> umref = {rank, outputs[p]};
DynamicMemRefType<char> mref(umref);
size_t elementSize = _dfr_get_memref_element_size(output_types[p]);
size_t size = 1;
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
ar << hpx::serialization::make_array(
mref.data + mref.offset * elementSize, size * elementSize);
} break;
case _DFR_TASK_ARG_CONTEXT: {
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
"Error: invalid task argument type.");
}
}
}
HPX_SERIALIZATION_SPLIT_MEMBER()
std::vector<void *> outputs;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
bool alloc_p = false;
};
@@ -129,10 +278,24 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
// Component actions exposed
OpaqueOutputData execute_task(const OpaqueInputData &inputs) {
auto wfn = node_level_work_function_registry->getWorkFunctionPointer(
auto wfn = _dfr_node_level_work_function_registry->getWorkFunctionPointer(
inputs.wfn_name);
std::vector<void *> outputs;
// On non-root nodes the deserialized context carries key *identifiers*
// rather than key pointers: resolve them to actual keys, fetching from the
// task's source locality if they are not yet cached on this node.
if (!_dfr_is_root_node()) {
  for (size_t p = 0; p < inputs.params.size(); ++p) {
    // Bug fix: the closing parenthesis was misplaced -
    // `_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)`
    // compared the whole 64-bit type word to the tag before extracting the
    // low tag byte, so any context argument with extra bits encoded in the
    // upper bits of its type word would be silently skipped.
    if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_CONTEXT) {
      mlir::concretelang::RuntimeContext *ctx =
          (*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
      ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
      ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
      // At most one context argument per task: stop after resolving it.
      break;
    }
  }
}
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
inputs.output_sizes.size());
hpx::cout << std::flush;
switch (inputs.output_sizes.size()) {
case 1: {
void *output = (void *)(new char[inputs.output_sizes[0]]);
@@ -449,12 +612,8 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
"Error: number of task outputs not supported.");
}
if (inputs.alloc_p)
for (auto p : inputs.params)
delete ((char *)p);
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
inputs.alloc_p);
std::move(inputs.output_types), inputs.alloc_p);
}
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);

View File

@@ -7,129 +7,206 @@
#define CONCRETELANG_DFR_KEY_MANAGER_HPP
#include <memory>
#include <mutex>
#include <utility>
#include <hpx/include/runtime.hpp>
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/serialization.hpp>
#include "concretelang/Runtime/DFRuntime.hpp"
struct PbsKeyManager;
extern PbsKeyManager *node_level_key_manager;
extern "C" {
#include "concrete-ffi.h"
}
struct PbsKeyWrapper {
std::shared_ptr<void *> key;
size_t key_id;
size_t size;
extern std::list<void *> new_allocated;
PbsKeyWrapper() {}
template <typename T> struct KeyManager;
extern KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
extern KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id);
PbsKeyWrapper(void *key, size_t key_id, size_t size)
: key(std::make_shared<void *>(key)), key_id(key_id), size(size) {}
PbsKeyWrapper(std::shared_ptr<void *> key, size_t key_id, size_t size)
: key(key), key_id(key_id), size(size) {}
PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept
: key(moved.key), key_id(moved.key_id), size(moved.size) {}
PbsKeyWrapper(const PbsKeyWrapper &pbsk)
: key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {}
template <typename LweKeyType> struct KeyWrapper {
LweKeyType *key;
KeyWrapper() : key(nullptr) {}
KeyWrapper(LweKeyType *key) : key(key) {}
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
friend class hpx::serialization::access;
template <class Archive>
void save(Archive &ar, const unsigned int version) const {
char *_key_ = static_cast<char *>(*key);
ar &key_id &size;
for (size_t i = 0; i < size; ++i)
ar &_key_[i];
}
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar &key_id &size;
char *_key_ = (char *)malloc(size);
for (size_t i = 0; i < size; ++i)
ar &_key_[i];
key = std::make_shared<void *>(_key_);
}
void save(Archive &ar, const unsigned int version) const;
template <class Archive> void load(Archive &ar, const unsigned int version);
HPX_SERIALIZATION_SPLIT_MEMBER()
};
inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) {
return lhs.key_id == rhs.key_id;
/// Serialize a bootstrap key via the concrete-ffi buffer API, as a length
/// followed by the raw byte array.
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
const unsigned int version) const {
// NOTE(review): buffer.pointer is allocated by the FFI call and never
// released here - confirm the ownership contract of
// serialize_lwe_bootstrap_key_u64 (possible per-send leak).
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
/// Deserialize a bootstrap key from a length-prefixed byte array.  The
/// intermediate byte buffer is recorded in new_allocated so the runtime can
/// reclaim it later.
template <>
template <class Archive>
void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
const unsigned int version) {
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_bootstrap_key_u64(buffer);
}
PbsKeyWrapper _dfr_fetch_key(size_t);
HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action)
/// Serialize a keyswitch key via the concrete-ffi buffer API, as a length
/// followed by the raw byte array.
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey_u64>::save(Archive &ar,
const unsigned int version) const {
// NOTE(review): buffer.pointer is allocated by the FFI call and never
// released here - confirm the ownership contract of
// serialize_lwe_keyswitching_key_u64 (possible per-send leak).
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
ar << buffer.length;
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
}
/// Deserialize a keyswitch key from a length-prefixed byte array.  The
/// intermediate byte buffer is recorded in new_allocated so the runtime can
/// reclaim it later.
template <>
template <class Archive>
void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
const unsigned int version) {
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_keyswitching_key_u64(buffer);
}
struct PbsKeyManager {
// The initial keys registered on the root node and whether to push
// them is TBD.
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t);
HPX_PLAIN_ACTION(_dfr_fetch_ksk, _dfr_fetch_ksk_action)
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t);
HPX_PLAIN_ACTION(_dfr_fetch_bsk, _dfr_fetch_bsk_action)
PbsKeyManager() { node_level_key_manager = this; }
template <typename LweKeyType> struct KeyManager {
KeyManager() {}
LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id);
PbsKeyWrapper get_key(const size_t key_id) {
keystore_guard.lock();
auto keyit = keystore.find(key_id);
keystore_guard.unlock();
if (keyit == keystore.end()) {
_dfr_fetch_key_action fet;
PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id);
if (pkw.size == 0) {
// Maybe retry or try other nodes... but for now it's an error.
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
"Error: key not found on remote node.");
} else {
std::lock_guard<std::mutex> guard(keystore_guard);
keyit = keystore.insert(std::pair<size_t, PbsKeyWrapper>(key_id, pkw))
.first;
}
}
return keyit->second;
}
// To be used only for remote requests
PbsKeyWrapper fetch_key(const size_t key_id) {
KeyWrapper<LweKeyType> fetch_key(const uint64_t key_id) {
std::lock_guard<std::mutex> guard(keystore_guard);
auto keyit = keystore.find(key_id);
if (keyit != keystore.end())
return keyit->second;
// If this node does not contain this key, return an empty wrapper
return PbsKeyWrapper(nullptr, 0, 0);
// If this node does not contain this key, this is an error
// (location was supplied as source for this key).
HPX_THROW_EXCEPTION(
hpx::no_success, "fetch_key",
"Error: could not find key to be fetched on source location.");
}
void register_key(void *key, size_t key_id, size_t size) {
void register_key(LweKeyType *key, uint64_t key_id) {
std::lock_guard<std::mutex> guard(keystore_guard);
auto keyit = keystore
.insert(std::pair<size_t, PbsKeyWrapper>(
key_id, PbsKeyWrapper(key, key_id, size)))
.first;
auto keyit = keystore.find(key_id);
if (keyit == keystore.end()) {
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
"Error: could not register new key.");
keyit = keystore
.insert(std::pair<uint64_t, KeyWrapper<LweKeyType>>(
key_id, KeyWrapper<LweKeyType>(key)))
.first;
if (keyit == keystore.end()) {
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
"Error: could not register new key.");
}
}
}
void broadcast_keys() {
void clear_keys() {
std::lock_guard<std::mutex> guard(keystore_guard);
if (_dfr_is_root_node())
hpx::collectives::broadcast_to("keystore", this->keystore).get();
else
keystore = std::move(
hpx::collectives::broadcast_from<std::map<size_t, PbsKeyWrapper>>(
"keystore")
.get());
keystore.clear();
}
private:
std::mutex keystore_guard;
std::map<size_t, PbsKeyWrapper> keystore;
std::map<uint64_t, KeyWrapper<LweKeyType>> keystore;
};
PbsKeyWrapper _dfr_fetch_key(size_t key_id) {
return node_level_key_manager->fetch_key(key_id);
/// Bootstrap-key manager constructor: publishes this instance in the
/// node-level global used by the C-style key management API below.
template <> KeyManager<LweBootstrapKey_u64>::KeyManager() {
_dfr_node_level_bsk_manager = this;
}
/// Return the bootstrap key identified by key_id, fetching it from locality
/// `loc` (and caching it locally via _dfr_register_bsk) on first use.
template <>
LweBootstrapKey_u64 *
KeyManager<LweBootstrapKey_u64>::get_key(hpx::naming::id_type loc,
const uint64_t key_id) {
// Lock only around the lookup: holding keystore_guard across the remote
// fetch below would deadlock with register_key, which takes the same mutex.
keystore_guard.lock();
auto keyit = keystore.find(key_id);
keystore_guard.unlock();
if (keyit == keystore.end()) {
_dfr_fetch_bsk_action fetch;
KeyWrapper<LweBootstrapKey_u64> &&bskw = fetch(loc, key_id);
if (bskw.key == nullptr) {
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
"Error: Bootstrap key not found on root node.");
} else {
// Cache for subsequent tasks on this node (duplicates are ignored).
_dfr_register_bsk(bskw.key, key_id);
}
return bskw.key;
}
// NOTE(review): keyit is dereferenced after the lock is released; this is
// safe only if clear_keys() cannot run concurrently with task execution -
// confirm.
return keyit->second.key;
}
/// Keyswitch-key manager constructor: publishes this instance in the
/// node-level global used by the C-style key management API below.
template <> KeyManager<LweKeyswitchKey_u64>::KeyManager() {
_dfr_node_level_ksk_manager = this;
}
/// Return the keyswitch key identified by key_id, fetching it from locality
/// `loc` (and caching it locally via _dfr_register_ksk) on first use.
template <>
LweKeyswitchKey_u64 *
KeyManager<LweKeyswitchKey_u64>::get_key(hpx::naming::id_type loc,
const uint64_t key_id) {
// Lock only around the lookup: holding keystore_guard across the remote
// fetch below would deadlock with register_key, which takes the same mutex.
keystore_guard.lock();
auto keyit = keystore.find(key_id);
keystore_guard.unlock();
if (keyit == keystore.end()) {
_dfr_fetch_ksk_action fetch;
KeyWrapper<LweKeyswitchKey_u64> &&kskw = fetch(loc, key_id);
if (kskw.key == nullptr) {
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
"Error: Keyswitching key not found on root node.");
} else {
// Cache for subsequent tasks on this node (duplicates are ignored).
_dfr_register_ksk(kskw.key, key_id);
}
return kskw.key;
}
// NOTE(review): keyit is dereferenced after the lock is released; this is
// safe only if clear_keys() cannot run concurrently with task execution -
// confirm.
return keyit->second.key;
}
/// Remote-action entry point: return this node's copy of a bootstrap key
/// (throws via fetch_key if the key is not present here).
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
return _dfr_node_level_bsk_manager->fetch_key(key_id);
}
/// Remote-action entry point: return this node's copy of a keyswitch key
/// (throws via fetch_key if the key is not present here).
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
return _dfr_node_level_ksk_manager->fetch_key(key_id);
}
/************************/
/* Key management API.  */
/************************/
/// Register a bootstrap key under key_id in this node's keystore.
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id) {
_dfr_node_level_bsk_manager->register_key(key, key_id);
}
/// Register a keyswitch key under key_id in this node's keystore.
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id) {
_dfr_node_level_ksk_manager->register_key(key, key_id);
}
/// Get (fetching from `loc` if needed) the bootstrap key for key_id.
LweBootstrapKey_u64 *_dfr_get_bsk(hpx::naming::id_type loc, uint64_t key_id) {
return _dfr_node_level_bsk_manager->get_key(loc, key_id);
}
/// Get (fetching from `loc` if needed) the keyswitch key for key_id.
LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
}
#endif

View File

@@ -16,14 +16,9 @@ typedef void (*wfnptr)(...);
void *_dfr_make_ready_future(void *);
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
void _dfr_register_work_function(wfnptr);
void *_dfr_await_future(void *);
/* Keys can have node-local copies which can be retrieved. This
should only be called on the node where the key is required. */
void _dfr_register_key(void *, size_t, size_t);
void _dfr_broadcast_keys();
void *_dfr_get_key(size_t);
/* Memory management:
_dfr_make_ready_future allocates the future, not the underlying storage.
_dfr_create_async_task allocates both future and storage for outputs. */
@@ -36,4 +31,5 @@ void _dfr_stop();
void _dfr_terminate();
}
#endif

View File

@@ -0,0 +1,92 @@
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
// Exceptions. See
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
// for license information.
#ifndef CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP
#define CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP
#include <memory>
#include <mutex>
#include <utility>
#include <hpx/include/runtime.hpp>
#include <hpx/modules/collectives.hpp>
#include <hpx/modules/serialization.hpp>
#include "concretelang/Runtime/DFRuntime.hpp"
extern void *dl_handle;
struct WorkFunctionRegistry;
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
/// Per-node registry mapping work-function pointers to stable names and
/// back.  Names are what travel between localities; pointers are resolved
/// locally, either from a previous registration or via dlsym.
struct WorkFunctionRegistry {
  WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; }

  /// Resolve a work-function name to a local pointer, consulting the
  /// registry first and falling back to a dlsym lookup in dl_handle.
  wfnptr getWorkFunctionPointer(const std::string &name) {
    std::lock_guard<std::recursive_mutex> guard(registry_guard);

    auto cached = name_to_ptr_registry.find(name);
    if (cached != name_to_ptr_registry.end())
      return (wfnptr)cached->second;

    void *resolved = dlsym(dl_handle, name.c_str());
    if (resolved == nullptr) {
      HPX_THROW_EXCEPTION(hpx::no_success,
                          "WorkFunctionRegistry::getWorkFunctionPointer",
                          "Error recovering work function pointer from name.");
    }
    registerWorkFunction(resolved, name);
    return (wfnptr)resolved;
  }

  /// Map a function pointer to a name usable on remote nodes.  When the
  /// symbol cannot be resolved (e.g. JIT-generated code) a fresh anonymous
  /// name is minted and registered instead.
  std::string getWorkFunctionName(const void *fn) {
    std::lock_guard<std::recursive_mutex> guard(registry_guard);

    auto known = ptr_to_name_registry.find(fn);
    if (known != ptr_to_name_registry.end())
      return known->second;

    // Assume that if we can't find the name, there is no dynamic
    // library to find it in. TODO: fix this to distinguish JIT/binary
    // and in case of distributed exec.
    Dl_info info;
    if (!dladdr(fn, &info) || info.dli_sname == nullptr)
      return registerAnonymousWorkFunction(fn);

    std::string symbol_name = info.dli_sname;
    registerWorkFunction(fn, symbol_name);
    return symbol_name;
  }

  /// Record a (pointer, name) pair in both lookup directions.  Existing
  /// entries are left untouched (map emplace does not overwrite).
  void registerWorkFunction(const void *fn, std::string name) {
    std::lock_guard<std::recursive_mutex> guard(registry_guard);
    ptr_to_name_registry.emplace(fn, name);
    name_to_ptr_registry.emplace(name, fn);
  }

  /// Mint a unique name for a function with no resolvable symbol and
  /// register it.
  std::string registerAnonymousWorkFunction(const void *fn) {
    std::lock_guard<std::recursive_mutex> guard(registry_guard);
    static std::atomic<unsigned int> fnid{0};
    std::string name = "_dfr_jit_wfnname_" + std::to_string(fnid++);
    registerWorkFunction(fn, name);
    return name;
  }

private:
  // Recursive mutex: the public entry points call registerWorkFunction /
  // registerAnonymousWorkFunction while already holding the lock.
  std::recursive_mutex registry_guard;
  std::map<const void *, std::string> ptr_to_name_registry;
  std::map<std::string, const void *> name_to_ptr_registry;
};
#endif

View File

@@ -14,5 +14,5 @@ add_mlir_library(RTDialectAnalysis
LINK_LIBS PUBLIC
MLIRIR
RTDialect
ConcretelangRuntime
)

View File

@@ -15,6 +15,7 @@
#include <concretelang/Dialect/RT/IR/RTDialect.h>
#include <concretelang/Dialect/RT/IR/RTOps.h>
#include <concretelang/Dialect/RT/IR/RTTypes.h>
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/math.h>
#include <mlir/IR/BuiltinOps.h>
@@ -132,7 +133,8 @@ static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement,
}
// TODO: Fix type sizes. For now we're using some default values.
static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
static std::pair<mlir::Value, mlir::Value>
getSizeInBytes(Value val, Location loc, OpBuilder builder) {
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
Type type = (val.getType().isa<RT::FutureType>())
? val.getType().dyn_cast<RT::FutureType>().getElementType()
@@ -153,26 +155,61 @@ static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
Value rank = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_rank));
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
Value result = builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
return result;
Value typeSize =
builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
Type elementType = type.dyn_cast<mlir::MemRefType>().getElementType();
// Assume here that the base type is a simple scalar-type or at
// least its size can be determined.
// size_t elementAttr = dataLayout.getTypeSize(elementType);
// Make room for a byte to store the type of this argument/output
// elementAttr <<= 8;
// elementAttr |= _DFR_TASK_ARG_MEMREF;
uint64_t elementAttr = 0;
size_t element_size = dataLayout.getTypeSize(elementType);
elementAttr = _dfr_set_arg_type(elementAttr, _DFR_TASK_ARG_MEMREF);
elementAttr = _dfr_set_memref_element_size(elementAttr, element_size);
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(elementAttr));
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
}
// Unranked memrefs should be lowered to just pointer + size, so we need 16
// bytes.
if (type.isa<mlir::UnrankedMemRefType>())
return builder.create<arith::ConstantOp>(loc,
builder.getI64IntegerAttr(16));
if (type.isa<mlir::UnrankedMemRefType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_UNRANKED_MEMREF));
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_BASE));
// FHE types are converted to pointers, so we take their size as 8
// bytes until we can get the actual size of the actual types.
if (type.isa<mlir::concretelang::Concrete::ContextType>() ||
type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>())
return builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
type.isa<mlir::concretelang::Concrete::PlaintextListType>()) {
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
} else if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
Value arg_type = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_CONTEXT));
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
// For all other types, get type size.
return builder.create<arith::ConstantOp>(
Value result = builder.create<arith::ConstantOp>(
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
@@ -189,9 +226,22 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
builder.setInsertionPoint(DFTOp);
for (Value val : DFTOp.getOperands()) {
if (!val.getType().isa<RT::FutureType>()) {
Type futType = RT::FutureType::get(val.getType());
auto mrf =
builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType, val);
Value newval;
// If we are building a future on a MemRef, then we need to clone
// the memref in order to allow the deallocation pass which does
// not synchronize with task execution.
if (val.getType().isa<mlir::MemRefType>() && !val.isa<BlockArgument>()) {
newval = builder
.create<memref::AllocOp>(
DFTOp.getLoc(), val.getType().cast<mlir::MemRefType>())
.getResult();
builder.create<memref::CopyOp>(DFTOp.getLoc(), val, newval);
} else {
newval = val;
}
Type futType = RT::FutureType::get(newval.getType());
auto mrf = builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType,
newval);
map.map(mrf->getResult(0), val);
replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
}
@@ -201,7 +251,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
// DataflowTaskOp. This also includes the necessary handling of
// operands and results (conversion to/from futures and propagation).
SmallVector<Value, 4> catOperands;
int size = 3 + DFTOp.getNumResults() * 2 + DFTOp.getNumOperands() * 2;
int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3;
catOperands.reserve(size);
auto fnptr = builder.create<mlir::func::ConstantOp>(
DFTOp.getLoc(), workFunction.getFunctionType(),
@@ -214,8 +264,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
catOperands.push_back(numIns.getResult());
catOperands.push_back(numOuts.getResult());
for (auto operand : DFTOp.getOperands()) {
auto op_size = getSizeInBytes(operand, DFTOp.getLoc(), builder);
catOperands.push_back(operand);
catOperands.push_back(getSizeInBytes(operand, DFTOp.getLoc(), builder));
catOperands.push_back(op_size.first);
catOperands.push_back(op_size.second);
}
// We need to adjust the results for the CreateAsyncTaskOp which
@@ -228,9 +280,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
futType);
auto op_size = getSizeInBytes(result, DFTOp.getLoc(), builder);
map.map(result, brpp->getResult(0));
catOperands.push_back(brpp->getResult(0));
catOperands.push_back(getSizeInBytes(result, DFTOp.getLoc(), builder));
catOperands.push_back(op_size.first);
catOperands.push_back(op_size.second);
}
builder.create<RT::CreateAsyncTaskOp>(
DFTOp.getLoc(),
@@ -272,6 +326,18 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
DFTOp.erase();
}
/// Insert, at the very start of \p parentFunc, code that registers
/// \p workFunction with the runtime system via
/// RT::RegisterTaskWorkFunctionOp, so remote nodes can later locate
/// the work function (lookup is by symbol name in the distributed case).
static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) {
  OpBuilder builder(parentFunc.body());
  builder.setInsertionPointToStart(&parentFunc.body().front());
  // Materialize the work function's address as a constant holding a
  // symbol reference; lowering turns this into the function pointer
  // handed to _dfr_register_work_function.
  // NOTE(review): uses mlir::ConstantOp/getType() while other code in
  // this pass uses mlir::func::ConstantOp/getFunctionType() — confirm
  // which MLIR API vintage this file targets.
  auto fnptr = builder.create<mlir::ConstantOp>(
      parentFunc.getLoc(), workFunction.getType(),
      SymbolRefAttr::get(builder.getContext(), workFunction.getName()));
  builder.create<RT::RegisterTaskWorkFunctionOp>(parentFunc.getLoc(),
                                                 fnptr.getResult());
}
/// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
@@ -306,6 +372,33 @@ struct LowerDataflowTasksPass
for (auto mapping : outliningMap)
lowerDataflowTaskOp(mapping.first, mapping.second);
// If this is a JIT invocation and we're not on the root node,
// we do not need to do any computation, only register all work
// functions with the runtime system
if (!outliningMap.empty()) {
if (!_dfr_is_root_node()) {
// auto regFunc = builder.create<FuncOp>(func.getLoc(),
// func.getName(), func.getType());
func.eraseBody();
Block *b = new Block;
b->addArguments(func.getType().getInputs());
func.body().push_front(b);
for (int i = func.getType().getNumInputs() - 1; i >= 0; --i)
func.eraseArgument(i);
for (int i = func.getType().getNumResults() - 1; i >= 0; --i)
func.eraseResult(i);
OpBuilder builder(func.body());
builder.setInsertionPointToEnd(&func.body().front());
builder.create<ReturnOp>(func.getLoc());
}
}
// Generate code to register all work-functions with the
// runtime.
for (auto mapping : outliningMap)
registerWorkFunction(func, mapping.second);
// Issue _dfr_start/stop calls for this function
if (!outliningMap.empty()) {
OpBuilder builder(func.getBody());

View File

@@ -171,6 +171,26 @@ struct CreateAsyncTaskOpInterfaceLowering
return success();
}
};
/// Lowering pattern for RT::RegisterTaskWorkFunctionOp: replaces the op
/// with a call to the runtime entry point `_dfr_register_work_function`.
/// The callee is declared as a variadic void function so the op's
/// operands can be forwarded unchanged regardless of their types.
struct RegisterTaskWorkFunctionOpInterfaceLowering
    : public ConvertOpToLLVMPattern<RT::RegisterTaskWorkFunctionOp> {
  using ConvertOpToLLVMPattern<
      RT::RegisterTaskWorkFunctionOp>::ConvertOpToLLVMPattern;

  mlir::LogicalResult
  matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp,
                  ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    RT::RegisterTaskWorkFunctionOp::Adaptor transformed(operands);
    // void _dfr_register_work_function(...) — variadic so all operand
    // lists type-check against one declaration.
    auto rtwfFuncType =
        LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
    // Declare the runtime function if this module has not seen it yet.
    auto rtwfFuncOp = getOrInsertFuncOpDecl(
        rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter);
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(rtwfOp, rtwfFuncOp,
                                              transformed.getOperands());
    return success();
  }
};
struct DeallocateFutureOpInterfaceLowering
: public ConvertOpToLLVMPattern<RT::DeallocateFutureOp> {
using ConvertOpToLLVMPattern<RT::DeallocateFutureOp>::ConvertOpToLLVMPattern;
@@ -296,6 +316,7 @@ void mlir::concretelang::populateRTToLLVMConversionPatterns(
DerefReturnPtrPlaceholderOpInterfaceLowering,
DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering,
CreateAsyncTaskOpInterfaceLowering,
RegisterTaskWorkFunctionOpInterfaceLowering,
DeallocateFutureOpInterfaceLowering,
DeallocateFutureDataOpInterfaceLowering,
WorkFunctionReturnOpInterfaceLowering>(converter);

View File

@@ -11,6 +11,7 @@
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
#include <hpx/barrier.hpp>
#include <hpx/future.hpp>
#include <hpx/hpx_start.hpp>
#include <hpx/hpx_suspend.hpp>
@@ -22,11 +23,16 @@
std::vector<GenericComputeClient> gcc;
void *dl_handle;
PbsKeyManager *node_level_key_manager;
WorkFunctionRegistry *node_level_work_function_registry;
KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
WorkFunctionRegistry *_dfr_node_level_work_function_registry;
std::list<void *> new_allocated;
std::list<void *> fut_allocated;
std::list<void *> m_allocated;
hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier;
hpx::lcos::barrier *_dfr_jit_phase_barrier;
std::atomic<uint64_t> init_guard = {0};
using namespace hpx;
@@ -52,6 +58,17 @@ void _dfr_deallocate_future(void *in) {
delete (static_cast<hpx::shared_future<void *> *>(in));
}
// Pick the locality where the next task should execute. Placement is
// currently a simple round-robin over all localities — TODO: optimise.
static inline size_t _dfr_find_next_execution_locality() {
  static size_t locality_count = hpx::get_num_localities().get();
  static std::atomic<std::size_t> task_counter{0};

  // fetch_add returns the previous value; +1 matches the original
  // pre-increment behaviour (first task goes to locality 1 % count).
  const size_t ticket = task_counter.fetch_add(1) + 1;
  return ticket % locality_count;
}
/// Runtime generic async_task. Each first NUM_PARAMS pairs of
/// arguments in the variadic list corresponds to a void* pointer on a
/// hpx::future<void*> and the size of data within the future. After
@@ -60,28 +77,39 @@ void _dfr_deallocate_future(void *in) {
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
...) {
std::vector<void *> params;
std::vector<void *> outputs;
std::vector<size_t> param_sizes;
std::vector<uint64_t> param_types;
std::vector<void *> outputs;
std::vector<size_t> output_sizes;
std::vector<uint64_t> output_types;
va_list args;
va_start(args, num_outputs);
for (size_t i = 0; i < num_params; ++i) {
params.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, size_t));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
for (size_t i = 0; i < num_outputs; ++i) {
outputs.push_back(va_arg(args, void *));
output_sizes.push_back(va_arg(args, size_t));
output_sizes.push_back(va_arg(args, uint64_t));
output_types.push_back(va_arg(args, uint64_t));
}
va_end(args);
// Record the backing buffers of memref-typed arguments so the runtime
// can deallocate them once the task graph no longer references them.
for (size_t i = 0; i < num_params; ++i) {
  // BUG FIX: the type decoder must be applied to param_types[i] and
  // the result compared against _DFR_TASK_ARG_MEMREF. Previously the
  // boolean (param_types[i] == _DFR_TASK_ARG_MEMREF) was passed *into*
  // _dfr_get_arg_type, so the condition tested a decoded 0/1 instead
  // of the argument's actual type tag.
  if (_dfr_get_arg_type(param_types[i]) == _DFR_TASK_ARG_MEMREF) {
    m_allocated.push_back(
        (void *)static_cast<StridedMemRefType<char, 1> *>(params[i])->data);
  }
}
// We pass functions by name - which is not strictly necessary in
// shared memory as pointers suffice, but is needed in the
// distributed case where the functions need to be located/loaded on
// the node.
auto wfnname =
node_level_work_function_registry->getWorkFunctionName((void *)wfn);
_dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn);
hpx::future<hpx::future<OpaqueOutputData>> oodf;
// In order to allow complete dataflow semantics for
@@ -93,20 +121,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
switch (num_params) {
case 0:
oodf = std::move(
hpx::dataflow([wfnname, param_sizes,
output_sizes]() -> hpx::future<OpaqueOutputData> {
hpx::dataflow([wfnname, param_sizes, param_types, output_sizes,
output_types]() -> hpx::future<OpaqueOutputData> {
std::vector<void *> params = {};
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
}));
break;
case 1:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0)
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0)
-> hpx::future<OpaqueOutputData> {
std::vector<void *> params = {param0.get()};
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0]));
@@ -114,11 +145,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 2:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1)
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1)
-> hpx::future<OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get()};
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
@@ -127,13 +160,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
case 3:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2)
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2)
-> hpx::future<OpaqueOutputData> {
std::vector<void *> params = {param0.get(), param1.get(),
param2.get()};
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
output_sizes, output_types);
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
@@ -642,25 +677,44 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
}
}
/***************************/
/* JIT execution support. */
/***************************/
/// True iff this process is the HPX root locality. The answer is
/// computed on first call and cached for the lifetime of the process;
/// the HPX runtime must already be started when this is first invoked.
static inline bool _dfr_is_root_node_impl() {
  static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality());
  return is_root_node_p;
}
/// Public wrapper over the cached root-locality check.
bool _dfr_is_root_node() { return _dfr_is_root_node_impl(); }
/// Register a locally generated (e.g. JIT-compiled) work function with
/// this node's registry so tasks can later resolve it by name.
void _dfr_register_work_function(wfnptr wfn) {
  _dfr_node_level_work_function_registry->registerAnonymousWorkFunction(
      (void *)wfn);
}
/********************************/
/* Distributed key management. */
/********************************/
/// Register an evaluation key (by id) with the node-level key manager.
// NOTE(review): these entry points still use `node_level_key_manager`,
// while other code in this file uses the typed
// `_dfr_node_level_bsk_manager`/`_dfr_node_level_ksk_manager` globals —
// confirm this manager is still the intended one.
void _dfr_register_key(void *key, size_t key_id, size_t size) {
  node_level_key_manager->register_key(key, key_id, size);
}
/// Broadcast all registered keys to the other localities.
void _dfr_broadcast_keys() { node_level_key_manager->broadcast_keys(); }
/// Fetch a previously registered/broadcast key by id.
void *_dfr_get_key(size_t key_id) {
  return *node_level_key_manager->get_key(key_id).key.get();
}
/************************************/
/* Initialization & Finalization. */
/************************************/
/* Runtime initialization and finalization. */
// TODO: need to set a flag for when executing in JIT to allow remote
// nodes to execute generated function that registers work functions.
// This also means that compute nodes need to go through all the
// phases of computation synchronized with the root node.
/// Sticky process-wide JIT flag: once any caller passes true, every
/// subsequent call reports true. The flag can never be cleared.
/// NOTE(review): the latch write is not synchronized — presumably
/// set-up happens before worker threads query it; confirm.
static inline bool _dfr_is_jit_impl(bool is_jit = false) {
  static bool jit_mode = is_jit;
  // Latch to true exactly once; never transitions back to false.
  if (is_jit && !jit_mode)
    jit_mode = true;
  return jit_mode;
}
/// Record that this run is a JIT invocation (sticky; cannot be unset).
void _dfr_is_jit(bool is_jit) { _dfr_is_jit_impl(is_jit); }
/// Query whether this run was marked as a JIT invocation.
bool _dfr_is_jit() { return _dfr_is_jit_impl(); }
static inline void _dfr_stop_impl() {
if (_dfr_is_root_node())
if (_dfr_is_root_node_impl())
hpx::apply([]() { hpx::finalize(); });
hpx::stop();
}
@@ -701,10 +755,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
}
// Instantiate on each node
new PbsKeyManager();
new KeyManager<LweBootstrapKey_u64>();
new KeyManager<LweKeyswitchKey_u64>();
new WorkFunctionRegistry();
_dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier(
"wait_register_remote_work_functions", hpx::get_num_localities().get(),
hpx::get_locality_id());
_dfr_jit_phase_barrier = new hpx::lcos::barrier(
"phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id());
if (_dfr_is_root_node()) {
if (_dfr_is_root_node_impl()) {
// Create compute server components on each node - from the root
// node only - and the corresponding compute client on the root
// node.
@@ -712,9 +772,6 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
gcc = hpx::new_<GenericComputeClient[]>(
hpx::default_layout(hpx::find_all_localities()), num_nodes)
.get();
} else {
hpx::stop();
exit(EXIT_SUCCESS);
}
}
@@ -722,16 +779,36 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
JIT invocation). These serve to pause/resume the runtime
scheduler and to clean up used resources. */
void _dfr_start() {
uint64_t uninitialised = 0;
if (init_guard.compare_exchange_strong(uninitialised, 1))
_dfr_start_impl(0, nullptr);
else
hpx::resume();
hpx::resume();
if (!_dfr_is_jit())
_dfr_stop_impl();
// TODO: conditional -- If this is the root node, and this is JIT
// execution, we need to wait for the compute nodes to compile and
// register work functions
if (_dfr_is_root_node_impl() && _dfr_is_jit()) {
_dfr_jit_workfunction_registration_barrier->wait();
}
}
void _dfr_stop() {
if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) {
_dfr_jit_workfunction_registration_barrier->wait();
}
// The barrier is only needed to synchronize the different
// computation phases when the compute nodes need to generate and
// register new work functions in each phase.
if (_dfr_is_jit()) {
_dfr_jit_phase_barrier->wait();
}
hpx::suspend();
_dfr_node_level_bsk_manager->clear_keys();
_dfr_node_level_ksk_manager->clear_keys();
while (!new_allocated.empty()) {
delete[] static_cast<char *>(new_allocated.front());
new_allocated.pop_front();
@@ -758,7 +835,7 @@ void _dfr_terminate() {
/* Main wrapper. */
/*******************/
extern "C" {
extern int main(int argc, char *argv[]); // __attribute__((weak));
extern int main(int argc, char *argv[]) __attribute__((weak));
extern int __real_main(int argc, char *argv[]) __attribute__((weak));
int __wrap_main(int argc, char *argv[]) {
int r;
@@ -790,7 +867,7 @@ size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); }
size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }
void _dfr_debug_print_task(const char *name, int inputs, int outputs) {
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) {
// clang-format off
hpx::cout << "Task \"" << name << "\""
<< " [" << inputs << " inputs, " << outputs << " outputs]"
@@ -806,7 +883,10 @@ void _dfr_print_debug(size_t val) {
#else // CONCRETELANG_PARALLEL_EXECUTION_ENABLED
#include <concretelang/Runtime/runtime_api.h>
#include "concretelang/Runtime/DFRuntime.hpp"
// Sequential fallback (parallel execution disabled): the single
// process is always the "root node" and JIT marking/termination are
// no-ops.
bool _dfr_is_root_node() { return true; }
void _dfr_is_jit(bool) {}
void _dfr_terminate() {}
#endif

View File

@@ -3,6 +3,7 @@
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
// for license information.
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/JITSupport.h>
#include <llvm/Support/TargetSelect.h>
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
@@ -48,8 +49,11 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
// Mark the lambda as compiled using DF parallelization
result->lambda->setUseDataflow(options.dataflowParallelize ||
options.autoParallelize);
result->clientParameters =
compilationResult.get().clientParameters.getValue();
if (!mlir::concretelang::dfr::_dfr_is_root_node())
result->clientParameters = clientlib::ClientParameters();
else
result->clientParameters =
compilationResult.get().clientParameters.getValue();
return std::move(result);
}

View File

@@ -13,6 +13,7 @@
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
#include "concretelang/Common/BitsSize.h"
#include <concretelang/Runtime/DFRuntime.hpp>
#include <concretelang/Support/Error.h>
#include <concretelang/Support/Jit.h>
#include <concretelang/Support/logging.h>
@@ -89,6 +90,19 @@ JITLambda::call(clientlib::PublicArguments &args,
"call: current runtime doesn't support dataflow execution, while "
"compilation used dataflow parallelization");
}
#else
dfr::_dfr_set_jit(true);
// When using JIT on distributed systems, the compiler only
// generates work-functions and their registration calls. No results
// are returned and no inputs are needed.
if (!dfr::_dfr_is_root_node()) {
std::vector<void *> rawArgs;
if (auto err = invokeRaw(rawArgs)) {
return std::move(err);
}
std::vector<clientlib::TensorData> buffers;
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
}
#endif
// invokeRaw needs to have pointers on arguments and a pointers on the result
// as last argument.

View File

@@ -31,7 +31,7 @@
#include "concretelang/Dialect/RT/IR/RTDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
#include "concretelang/Runtime/runtime_api.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Support/CompilerEngine.h"
#include "concretelang/Support/Error.h"
#include "concretelang/Support/JITSupport.h"