mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 20:25:34 -05:00
feat(runtime): enable distributed execution.
This commit is contained in:
@@ -19,6 +19,7 @@ extern "C" {
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
|
||||
@@ -115,7 +115,13 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
|
||||
let summary = "Create a dataflow task.";
|
||||
}
|
||||
|
||||
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
// RT dialect op "register_task_work_function": accepts a variadic list of
// operands of any type and produces no results.
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
|
||||
let arguments = (ins Variadic<AnyType>:$list);
|
||||
let results = (outs );
|
||||
let summary = "Register the task work-function with the runtime system.";
|
||||
}
|
||||
|
||||
// RT dialect op "deallocate_future": takes a single RT_Future operand and
// produces no results.
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
let arguments = (ins RT_Future: $input);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
@@ -6,102 +6,39 @@
|
||||
#ifndef CONCRETELANG_DFR_DFRUNTIME_HPP
|
||||
#define CONCRETELANG_DFR_DFRUNTIME_HPP
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <dlfcn.h>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
|
||||
/* Debug interface. */
|
||||
#include "concretelang/Runtime/dfr_debug_interface.h"
|
||||
bool _dfr_is_root_node();
|
||||
void _dfr_is_jit(bool);
|
||||
bool _dfr_is_jit();
|
||||
void _dfr_terminate();
|
||||
|
||||
extern void *dl_handle;
|
||||
struct WorkFunctionRegistry;
|
||||
extern WorkFunctionRegistry *node_level_work_function_registry;
|
||||
/// Kind tag of a dataflow task argument. The tag occupies the low 8 bits of
/// a 64-bit type word (see the & 0xFF mask in _dfr_get_arg_type).
typedef enum _dfr_task_arg_type {
|
||||
// Plain value: no type-specific handling beyond the raw argument bytes.
_DFR_TASK_ARG_BASE = 0,
|
||||
// Ranked memref: the type word additionally carries the element size.
_DFR_TASK_ARG_MEMREF = 1,
|
||||
// Unranked memref.
_DFR_TASK_ARG_UNRANKED_MEMREF = 2,
|
||||
// Runtime context argument.
_DFR_TASK_ARG_CONTEXT = 3
|
||||
} _dfr_task_arg_type;
|
||||
|
||||
/// Recover the name of the work function
|
||||
static inline const char *_dfr_get_function_name_from_address(void *fn) {
|
||||
Dl_info info;
|
||||
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_name_from_address",
|
||||
"Error recovering work function name from address.");
|
||||
return info.dli_sname;
|
||||
/// Decode the argument kind from a packed type word by keeping only its low
/// 8 bits.
static inline _dfr_task_arg_type _dfr_get_arg_type(uint64_t val) {
|
||||
return (_dfr_task_arg_type)(val & 0xFF);
|
||||
}
|
||||
|
||||
static inline wfnptr _dfr_get_function_pointer_from_name(const char *fn_name) {
|
||||
auto ptr = dlsym(dl_handle, fn_name);
|
||||
|
||||
if (ptr == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_function_pointer_from_name",
|
||||
"Error recovering work function pointer from name.");
|
||||
return (wfnptr)ptr;
|
||||
/// Decode the memref element size from a packed type word: everything above
/// the low 8 bits (which hold the argument kind).
static inline uint64_t _dfr_get_memref_element_size(uint64_t val) {
|
||||
return val >> 8;
|
||||
}
|
||||
|
||||
/// Determine where new task should run. For now just round-robin
|
||||
/// distribution - TODO: optimise.
|
||||
static inline size_t _dfr_find_next_execution_locality() {
|
||||
static size_t num_nodes = hpx::get_num_localities().get();
|
||||
static std::atomic<std::size_t> next_locality{0};
|
||||
|
||||
size_t next_loc = ++next_locality;
|
||||
|
||||
return next_loc % num_nodes;
|
||||
/// Return val with its low 8 bits replaced by the argument kind `type`; the
/// remaining upper bits of val are left untouched.
static inline uint64_t _dfr_set_arg_type(uint64_t val,
|
||||
_dfr_task_arg_type type) {
|
||||
// ~(0xFF) is an int (-256); converting it to uint64_t for the & yields
// 0xFFFFFFFFFFFFFF00, so all 56 upper bits of val are preserved.
return (val & ~(0xFF)) | type;
|
||||
}
|
||||
|
||||
static inline bool _dfr_is_root_node() {
|
||||
return hpx::find_here() == hpx::find_root_locality();
|
||||
/// Return val with everything above its low 8 bits replaced by `size`; the
/// low 8 bits (the argument kind) are kept.
static inline uint64_t _dfr_set_memref_element_size(uint64_t val, size_t size) {
|
||||
// Checked in debug builds only: the assert is compiled out under NDEBUG.
// The bound is 2^48 although 56 bits are available above the kind tag.
assert(size < (((uint64_t)1) << 48));
|
||||
return (val & 0xFF) | (((uint64_t)size) << 8);
|
||||
}
|
||||
|
||||
struct WorkFunctionRegistry {
|
||||
WorkFunctionRegistry() { node_level_work_function_registry = this; }
|
||||
|
||||
wfnptr getWorkFunctionPointer(const std::string &name) {
|
||||
std::lock_guard<std::mutex> guard(registry_guard);
|
||||
|
||||
auto fnptrit = name_to_ptr_registry.find(name);
|
||||
if (fnptrit != name_to_ptr_registry.end())
|
||||
return (wfnptr)fnptrit->second;
|
||||
|
||||
auto ptr = dlsym(dl_handle, name.c_str());
|
||||
if (ptr == nullptr)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"WorkFunctionRegistry::getWorkFunctionPointer",
|
||||
"Error recovering work function pointer from name.");
|
||||
ptr_to_name_registry.insert(
|
||||
std::pair<const void *, std::string>(ptr, name));
|
||||
name_to_ptr_registry.insert(
|
||||
std::pair<std::string, const void *>(name, ptr));
|
||||
return (wfnptr)ptr;
|
||||
}
|
||||
|
||||
std::string getWorkFunctionName(const void *fn) {
|
||||
std::lock_guard<std::mutex> guard(registry_guard);
|
||||
|
||||
auto fnnameit = ptr_to_name_registry.find(fn);
|
||||
if (fnnameit != ptr_to_name_registry.end())
|
||||
return fnnameit->second;
|
||||
|
||||
Dl_info info;
|
||||
std::string ret;
|
||||
// Assume that if we can't find the name, there is no dynamic
|
||||
// library to find it in. TODO: fix this to distinguish JIT/binary
|
||||
// and in case of distributed exec.
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr) {
|
||||
static std::atomic<unsigned int> fnid{0};
|
||||
ret = "_dfr_jit_wfnname_" + std::to_string(fnid++);
|
||||
} else {
|
||||
ret = info.dli_sname;
|
||||
}
|
||||
ptr_to_name_registry.insert(std::pair<const void *, std::string>(fn, ret));
|
||||
name_to_ptr_registry.insert(std::pair<std::string, const void *>(ret, fn));
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex registry_guard;
|
||||
std::map<const void *, std::string> ptr_to_name_registry;
|
||||
std::map<std::string, const void *> name_to_ptr_registry;
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
extern "C" {
|
||||
size_t _dfr_debug_get_node_id();
|
||||
size_t _dfr_debug_get_worker_id();
|
||||
void _dfr_debug_print_task(const char *name, int inputs, int outputs);
|
||||
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs);
|
||||
void _dfr_print_debug(size_t val);
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -26,101 +26,250 @@
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/key_manager.hpp"
|
||||
#include <mlir/ExecutionEngine/CRunnerUtils.h>
|
||||
|
||||
extern WorkFunctionRegistry *node_level_work_function_registry;
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Runtime/dfr_debug_interface.h"
|
||||
#include "concretelang/Runtime/key_manager.hpp"
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Runtime/workfunction_registry.hpp"
|
||||
|
||||
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
|
||||
extern std::list<void *> new_allocated;
|
||||
|
||||
using namespace hpx::naming;
|
||||
using namespace hpx::components;
|
||||
using namespace hpx::collectives;
|
||||
|
||||
/// Recover the rank of a ranked memref from the byte size of its descriptor.
///
/// The size is taken to be a fixed header (allocated pointer, aligned
/// pointer, offset) followed by one size entry and one stride entry per
/// dimension.
///
/// \param size byte size of the memref descriptor.
/// \return number of dimensions described by a descriptor of that size.
static inline size_t _dfr_get_memref_rank(size_t size) {
  // allocated_ptr & aligned_ptr, then the offset.
  constexpr size_t header_bytes = 2 * sizeof(char *) + sizeof(int64_t);
  // One size and one stride entry per dimension.
  constexpr size_t bytes_per_dim = 2 * sizeof(int64_t);

  // A descriptor smaller than its header would make the unsigned subtraction
  // below wrap around and report an enormous rank.
  assert(size >= header_bytes &&
         "memref descriptor smaller than its fixed header");
  return (size - header_bytes) / bytes_per_dim;
}
|
||||
|
||||
struct OpaqueInputData {
|
||||
OpaqueInputData() = default;
|
||||
|
||||
OpaqueInputData(std::string wfn_name, std::vector<void *> params,
|
||||
std::vector<size_t> param_sizes,
|
||||
std::vector<size_t> output_sizes, bool alloc_p = false)
|
||||
: wfn_name(wfn_name), params(std::move(params)),
|
||||
param_sizes(std::move(param_sizes)),
|
||||
output_sizes(std::move(output_sizes)), alloc_p(alloc_p) {}
|
||||
OpaqueInputData(std::string _wfn_name, std::vector<void *> _params,
|
||||
std::vector<size_t> _param_sizes,
|
||||
std::vector<uint64_t> _param_types,
|
||||
std::vector<size_t> _output_sizes,
|
||||
std::vector<uint64_t> _output_types, bool _alloc_p = false)
|
||||
: wfn_name(_wfn_name), params(std::move(_params)),
|
||||
param_sizes(std::move(_param_sizes)),
|
||||
param_types(std::move(_param_types)),
|
||||
output_sizes(std::move(_output_sizes)),
|
||||
output_types(std::move(_output_types)), alloc_p(_alloc_p),
|
||||
source_locality(hpx::find_here()) {}
|
||||
|
||||
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
param_sizes(std::move(oid.param_sizes)),
|
||||
output_sizes(std::move(oid.output_sizes)), alloc_p(oid.alloc_p) {}
|
||||
param_types(std::move(oid.param_types)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
|
||||
source_locality(oid.source_locality) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &wfn_name;
|
||||
ar ¶m_sizes;
|
||||
ar &output_sizes;
|
||||
for (auto p : param_sizes) {
|
||||
char *param = new char[p];
|
||||
// TODO: Optimise these serialisation operations
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar ¶m[i];
|
||||
ar >> wfn_name;
|
||||
ar >> param_sizes >> param_types;
|
||||
ar >> output_sizes >> output_types;
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
char *param = new char[param_sizes[p]];
|
||||
new_allocated.push_back((void *)param);
|
||||
ar >> hpx::serialization::make_array(param, param_sizes[p]);
|
||||
params.push_back((void *)param);
|
||||
|
||||
switch (_dfr_get_arg_type(param_types[p])) {
|
||||
case _DFR_TASK_ARG_BASE:
|
||||
break;
|
||||
case _DFR_TASK_ARG_MEMREF: {
|
||||
size_t rank = _dfr_get_memref_rank(param_sizes[p]);
|
||||
UnrankedMemRefType<char> umref = {rank, params[p]};
|
||||
DynamicMemRefType<char> mref(umref);
|
||||
size_t elementSize = _dfr_get_memref_element_size(param_types[p]);
|
||||
size_t size = 1;
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
uint64_t bsk_id, ksk_id;
|
||||
ar >> bsk_id >> ksk_id >> source_locality;
|
||||
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
new mlir::concretelang::RuntimeContext;
|
||||
new_allocated.push_back((void *)context);
|
||||
mlir::concretelang::RuntimeContext **_context =
|
||||
new mlir::concretelang::RuntimeContext *[1];
|
||||
new_allocated.push_back((void *)_context);
|
||||
_context[0] = context;
|
||||
|
||||
context->bsk = (LweBootstrapKey_u64 *)bsk_id;
|
||||
context->ksk = (LweKeyswitchKey_u64 *)ksk_id;
|
||||
params[p] = (void *)_context;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar &wfn_name;
|
||||
ar ¶m_sizes;
|
||||
ar &output_sizes;
|
||||
for (size_t p = 0; p < params.size(); ++p)
|
||||
for (size_t i = 0; i < param_sizes[p]; ++i)
|
||||
ar &static_cast<char *>(params[p])[i];
|
||||
ar << wfn_name;
|
||||
ar << param_sizes << param_types;
|
||||
ar << output_sizes << output_types;
|
||||
for (size_t p = 0; p < params.size(); ++p) {
|
||||
// Save the first level of the data structure - if the parameter
|
||||
// is a tensor/memref, there is a second level.
|
||||
ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]);
|
||||
switch (_dfr_get_arg_type(param_types[p])) {
|
||||
case _DFR_TASK_ARG_BASE:
|
||||
break;
|
||||
case _DFR_TASK_ARG_MEMREF: {
|
||||
size_t rank = _dfr_get_memref_rank(param_sizes[p]);
|
||||
UnrankedMemRefType<char> umref = {rank, params[p]};
|
||||
DynamicMemRefType<char> mref(umref);
|
||||
size_t elementSize = _dfr_get_memref_element_size(param_types[p]);
|
||||
size_t size = 1;
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
ar << hpx::serialization::make_array(
|
||||
mref.data + mref.offset * elementSize, size * elementSize);
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
|
||||
LweKeyswitchKey_u64 *ksk = context->ksk;
|
||||
LweBootstrapKey_u64 *bsk = context->bsk;
|
||||
|
||||
// TODO: find better unique identifiers. This is not a
|
||||
// correctness issue, but performance.
|
||||
uint64_t bsk_id = (uint64_t)bsk;
|
||||
uint64_t ksk_id = (uint64_t)ksk;
|
||||
|
||||
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
|
||||
_dfr_register_bsk(bsk, bsk_id);
|
||||
_dfr_register_ksk(ksk, ksk_id);
|
||||
ar << bsk_id << ksk_id << source_locality;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
|
||||
std::string wfn_name;
|
||||
std::vector<void *> params;
|
||||
std::vector<size_t> param_sizes;
|
||||
std::vector<uint64_t> param_types;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
hpx::naming::id_type source_locality;
|
||||
};
|
||||
|
||||
struct OpaqueOutputData {
|
||||
OpaqueOutputData() = default;
|
||||
OpaqueOutputData(std::vector<void *> outputs,
|
||||
std::vector<size_t> output_sizes, bool alloc_p = false)
|
||||
std::vector<size_t> output_sizes,
|
||||
std::vector<uint64_t> output_types, bool alloc_p = false)
|
||||
: outputs(std::move(outputs)), output_sizes(std::move(output_sizes)),
|
||||
alloc_p(alloc_p) {}
|
||||
output_types(std::move(output_types)), alloc_p(alloc_p) {}
|
||||
OpaqueOutputData(const OpaqueOutputData &ood)
|
||||
: outputs(std::move(ood.outputs)),
|
||||
output_sizes(std::move(ood.output_sizes)), alloc_p(ood.alloc_p) {}
|
||||
output_sizes(std::move(ood.output_sizes)),
|
||||
output_types(std::move(ood.output_types)), alloc_p(ood.alloc_p) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &output_sizes;
|
||||
for (auto p : output_sizes) {
|
||||
char *output = new char[p];
|
||||
for (size_t i = 0; i < p; ++i)
|
||||
ar &output[i];
|
||||
outputs.push_back((void *)output);
|
||||
ar >> output_sizes >> output_types;
|
||||
for (size_t p = 0; p < output_sizes.size(); ++p) {
|
||||
char *output = new char[output_sizes[p]];
|
||||
new_allocated.push_back((void *)output);
|
||||
ar >> hpx::serialization::make_array(output, output_sizes[p]);
|
||||
outputs.push_back((void *)output);
|
||||
|
||||
switch (_dfr_get_arg_type(output_types[p])) {
|
||||
case _DFR_TASK_ARG_BASE:
|
||||
break;
|
||||
case _DFR_TASK_ARG_MEMREF: {
|
||||
size_t rank = _dfr_get_memref_rank(output_sizes[p]);
|
||||
UnrankedMemRefType<char> umref = {rank, outputs[p]};
|
||||
DynamicMemRefType<char> mref(umref);
|
||||
size_t elementSize = _dfr_get_memref_element_size(output_types[p]);
|
||||
size_t size = 1;
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
|
||||
nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
alloc_p = true;
|
||||
}
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
ar &output_sizes;
|
||||
ar << output_sizes << output_types;
|
||||
for (size_t p = 0; p < outputs.size(); ++p) {
|
||||
for (size_t i = 0; i < output_sizes[p]; ++i)
|
||||
ar &static_cast<char *>(outputs[p])[i];
|
||||
// TODO: investigate if HPX is automatically deallocating
|
||||
// these. Here it could be safely assumed that these would no
|
||||
// longer be live.
|
||||
// delete (char*)outputs[p];
|
||||
ar << hpx::serialization::make_array((char *)outputs[p], output_sizes[p]);
|
||||
|
||||
switch (_dfr_get_arg_type(output_types[p])) {
|
||||
case _DFR_TASK_ARG_BASE:
|
||||
break;
|
||||
case _DFR_TASK_ARG_MEMREF: {
|
||||
size_t rank = _dfr_get_memref_rank(output_sizes[p]);
|
||||
UnrankedMemRefType<char> umref = {rank, outputs[p]};
|
||||
DynamicMemRefType<char> mref(umref);
|
||||
size_t elementSize = _dfr_get_memref_element_size(output_types[p]);
|
||||
size_t size = 1;
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
ar << hpx::serialization::make_array(
|
||||
mref.data + mref.offset * elementSize, size * elementSize);
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save",
|
||||
"Error: invalid task argument type.");
|
||||
}
|
||||
}
|
||||
}
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
|
||||
std::vector<void *> outputs;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
};
|
||||
|
||||
@@ -129,10 +278,24 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
|
||||
// Component actions exposed
|
||||
OpaqueOutputData execute_task(const OpaqueInputData &inputs) {
|
||||
auto wfn = node_level_work_function_registry->getWorkFunctionPointer(
|
||||
auto wfn = _dfr_node_level_work_function_registry->getWorkFunctionPointer(
|
||||
inputs.wfn_name);
|
||||
std::vector<void *> outputs;
|
||||
|
||||
// Off the root node, the ksk/bsk slots of the runtime context are passed to
// the key manager as key identifiers, together with the locality the task
// came from, and replaced by the key pointers it returns.
if (!_dfr_is_root_node()) {
  for (size_t p = 0; p < inputs.params.size(); ++p) {
    // Decode the kind from the type word first, then compare it. Comparing
    // inside the call would pass a boolean to the decoder instead.
    if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_CONTEXT) {
      mlir::concretelang::RuntimeContext *ctx =
          (*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
      ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
      ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
      // Only the first context argument is resolved.
      break;
    }
  }
}
|
||||
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
|
||||
inputs.output_sizes.size());
|
||||
hpx::cout << std::flush;
|
||||
switch (inputs.output_sizes.size()) {
|
||||
case 1: {
|
||||
void *output = (void *)(new char[inputs.output_sizes[0]]);
|
||||
@@ -449,12 +612,8 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
"Error: number of task outputs not supported.");
|
||||
}
|
||||
|
||||
if (inputs.alloc_p)
|
||||
for (auto p : inputs.params)
|
||||
delete ((char *)p);
|
||||
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
|
||||
inputs.alloc_p);
|
||||
std::move(inputs.output_types), inputs.alloc_p);
|
||||
}
|
||||
|
||||
HPX_DEFINE_COMPONENT_ACTION(GenericComputeServer, execute_task);
|
||||
|
||||
@@ -7,129 +7,206 @@
|
||||
#define CONCRETELANG_DFR_KEY_MANAGER_HPP
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
struct PbsKeyManager;
|
||||
extern PbsKeyManager *node_level_key_manager;
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
}
|
||||
|
||||
struct PbsKeyWrapper {
|
||||
std::shared_ptr<void *> key;
|
||||
size_t key_id;
|
||||
size_t size;
|
||||
extern std::list<void *> new_allocated;
|
||||
|
||||
PbsKeyWrapper() {}
|
||||
template <typename T> struct KeyManager;
|
||||
extern KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
|
||||
extern KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
|
||||
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
|
||||
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id);
|
||||
|
||||
PbsKeyWrapper(void *key, size_t key_id, size_t size)
|
||||
: key(std::make_shared<void *>(key)), key_id(key_id), size(size) {}
|
||||
|
||||
PbsKeyWrapper(std::shared_ptr<void *> key, size_t key_id, size_t size)
|
||||
: key(key), key_id(key_id), size(size) {}
|
||||
|
||||
PbsKeyWrapper(PbsKeyWrapper &&moved) noexcept
|
||||
: key(moved.key), key_id(moved.key_id), size(moved.size) {}
|
||||
|
||||
PbsKeyWrapper(const PbsKeyWrapper &pbsk)
|
||||
: key(pbsk.key), key_id(pbsk.key_id), size(pbsk.size) {}
|
||||
template <typename LweKeyType> struct KeyWrapper {
|
||||
LweKeyType *key;
|
||||
|
||||
KeyWrapper() : key(nullptr) {}
|
||||
KeyWrapper(LweKeyType *key) : key(key) {}
|
||||
KeyWrapper(KeyWrapper &&moved) noexcept : key(moved.key) {}
|
||||
KeyWrapper(const KeyWrapper &kw) : key(kw.key) {}
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive>
|
||||
void save(Archive &ar, const unsigned int version) const {
|
||||
char *_key_ = static_cast<char *>(*key);
|
||||
ar &key_id &size;
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
ar &_key_[i];
|
||||
}
|
||||
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar &key_id &size;
|
||||
char *_key_ = (char *)malloc(size);
|
||||
for (size_t i = 0; i < size; ++i)
|
||||
ar &_key_[i];
|
||||
key = std::make_shared<void *>(_key_);
|
||||
}
|
||||
void save(Archive &ar, const unsigned int version) const;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version);
|
||||
HPX_SERIALIZATION_SPLIT_MEMBER()
|
||||
};
|
||||
|
||||
inline bool operator==(const PbsKeyWrapper &lhs, const PbsKeyWrapper &rhs) {
|
||||
return lhs.key_id == rhs.key_id;
|
||||
/// Serialize the wrapped bootstrap key: the key is first converted to a byte
/// buffer, then the buffer length and the buffer contents are written to the
/// archive, in that order.
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey_u64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
// NOTE(review): the buffer is not released here - confirm whether the
// serialization API expects the caller to free it.
Buffer buffer = serialize_lwe_bootstrap_key_u64(key);
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
/// Deserialize a bootstrap key: reads the byte length, then that many bytes
/// into a freshly allocated buffer, and rebuilds the key from that buffer.
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
// NOTE(review): read as size_t - this must match the type of buffer.length
// written by save(); confirm against the Buffer declaration.
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
// The allocation is recorded in the new_allocated list rather than freed
// here.
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
}
|
||||
|
||||
PbsKeyWrapper _dfr_fetch_key(size_t);
|
||||
HPX_PLAIN_ACTION(_dfr_fetch_key, _dfr_fetch_key_action)
|
||||
/// Serialize the wrapped keyswitching key: the key is first converted to a
/// byte buffer, then the buffer length and the buffer contents are written to
/// the archive, in that order.
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey_u64>::save(Archive &ar,
|
||||
const unsigned int version) const {
|
||||
// NOTE(review): the buffer is not released here - confirm whether the
// serialization API expects the caller to free it.
Buffer buffer = serialize_lwe_keyswitching_key_u64(key);
|
||||
ar << buffer.length;
|
||||
ar << hpx::serialization::make_array(buffer.pointer, buffer.length);
|
||||
}
|
||||
/// Deserialize a keyswitching key: reads the byte length, then that many
/// bytes into a freshly allocated buffer, and rebuilds the key from that
/// buffer.
template <>
|
||||
template <class Archive>
|
||||
void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
|
||||
const unsigned int version) {
|
||||
// NOTE(review): read as size_t - this must match the type of buffer.length
// written by save(); confirm against the Buffer declaration.
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
// The allocation is recorded in the new_allocated list rather than freed
// here.
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
}
|
||||
|
||||
struct PbsKeyManager {
|
||||
// The initial keys registered on the root node and whether to push
|
||||
// them is TBD.
|
||||
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t);
|
||||
HPX_PLAIN_ACTION(_dfr_fetch_ksk, _dfr_fetch_ksk_action)
|
||||
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t);
|
||||
HPX_PLAIN_ACTION(_dfr_fetch_bsk, _dfr_fetch_bsk_action)
|
||||
|
||||
PbsKeyManager() { node_level_key_manager = this; }
|
||||
template <typename LweKeyType> struct KeyManager {
|
||||
KeyManager() {}
|
||||
LweKeyType *get_key(hpx::naming::id_type loc, const uint64_t key_id);
|
||||
|
||||
PbsKeyWrapper get_key(const size_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_key_action fet;
|
||||
PbsKeyWrapper &&pkw = fet(hpx::find_root_locality(), key_id);
|
||||
if (pkw.size == 0) {
|
||||
// Maybe retry or try other nodes... but for now it's an error.
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: key not found on remote node.");
|
||||
} else {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
keyit = keystore.insert(std::pair<size_t, PbsKeyWrapper>(key_id, pkw))
|
||||
.first;
|
||||
}
|
||||
}
|
||||
return keyit->second;
|
||||
}
|
||||
|
||||
// To be used only for remote requests
|
||||
PbsKeyWrapper fetch_key(const size_t key_id) {
|
||||
KeyWrapper<LweKeyType> fetch_key(const uint64_t key_id) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
|
||||
auto keyit = keystore.find(key_id);
|
||||
if (keyit != keystore.end())
|
||||
return keyit->second;
|
||||
// If this node does not contain this key, return an empty wrapper
|
||||
return PbsKeyWrapper(nullptr, 0, 0);
|
||||
// If this node does not contain this key, this is an error
|
||||
// (location was supplied as source for this key).
|
||||
HPX_THROW_EXCEPTION(
|
||||
hpx::no_success, "fetch_key",
|
||||
"Error: could not find key to be fetched on source location.");
|
||||
}
|
||||
|
||||
void register_key(void *key, size_t key_id, size_t size) {
|
||||
void register_key(LweKeyType *key, uint64_t key_id) {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
auto keyit = keystore
|
||||
.insert(std::pair<size_t, PbsKeyWrapper>(
|
||||
key_id, PbsKeyWrapper(key, key_id, size)))
|
||||
.first;
|
||||
auto keyit = keystore.find(key_id);
|
||||
if (keyit == keystore.end()) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
|
||||
"Error: could not register new key.");
|
||||
keyit = keystore
|
||||
.insert(std::pair<uint64_t, KeyWrapper<LweKeyType>>(
|
||||
key_id, KeyWrapper<LweKeyType>(key)))
|
||||
.first;
|
||||
if (keyit == keystore.end()) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_register_key",
|
||||
"Error: could not register new key.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void broadcast_keys() {
|
||||
void clear_keys() {
|
||||
std::lock_guard<std::mutex> guard(keystore_guard);
|
||||
if (_dfr_is_root_node())
|
||||
hpx::collectives::broadcast_to("keystore", this->keystore).get();
|
||||
else
|
||||
keystore = std::move(
|
||||
hpx::collectives::broadcast_from<std::map<size_t, PbsKeyWrapper>>(
|
||||
"keystore")
|
||||
.get());
|
||||
keystore.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
std::mutex keystore_guard;
|
||||
std::map<size_t, PbsKeyWrapper> keystore;
|
||||
std::map<uint64_t, KeyWrapper<LweKeyType>> keystore;
|
||||
};
|
||||
|
||||
PbsKeyWrapper _dfr_fetch_key(size_t key_id) {
|
||||
return node_level_key_manager->fetch_key(key_id);
|
||||
/// Constructing a bootstrap-key manager publishes it through the global
/// _dfr_node_level_bsk_manager pointer (the last one constructed wins).
template <> KeyManager<LweBootstrapKey_u64>::KeyManager() {
|
||||
_dfr_node_level_bsk_manager = this;
|
||||
}
|
||||
|
||||
template <>
|
||||
LweBootstrapKey_u64 *
|
||||
KeyManager<LweBootstrapKey_u64>::get_key(hpx::naming::id_type loc,
|
||||
const uint64_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_bsk_action fetch;
|
||||
KeyWrapper<LweBootstrapKey_u64> &&bskw = fetch(loc, key_id);
|
||||
if (bskw.key == nullptr) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: Bootstrap key not found on root node.");
|
||||
} else {
|
||||
_dfr_register_bsk(bskw.key, key_id);
|
||||
}
|
||||
return bskw.key;
|
||||
}
|
||||
return keyit->second.key;
|
||||
}
|
||||
|
||||
/// Constructing a keyswitching-key manager publishes it through the global
/// _dfr_node_level_ksk_manager pointer (the last one constructed wins).
template <> KeyManager<LweKeyswitchKey_u64>::KeyManager() {
|
||||
_dfr_node_level_ksk_manager = this;
|
||||
}
|
||||
|
||||
template <>
|
||||
LweKeyswitchKey_u64 *
|
||||
KeyManager<LweKeyswitchKey_u64>::get_key(hpx::naming::id_type loc,
|
||||
const uint64_t key_id) {
|
||||
keystore_guard.lock();
|
||||
auto keyit = keystore.find(key_id);
|
||||
keystore_guard.unlock();
|
||||
|
||||
if (keyit == keystore.end()) {
|
||||
_dfr_fetch_ksk_action fetch;
|
||||
KeyWrapper<LweKeyswitchKey_u64> &&kskw = fetch(loc, key_id);
|
||||
if (kskw.key == nullptr) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_get_key",
|
||||
"Error: Keyswitching key not found on root node.");
|
||||
} else {
|
||||
_dfr_register_ksk(kskw.key, key_id);
|
||||
}
|
||||
return kskw.key;
|
||||
}
|
||||
return keyit->second.key;
|
||||
}
|
||||
|
||||
/// Look up a bootstrap key by id in this node's bootstrap-key manager and
/// return it wrapped for serialization.
KeyWrapper<LweBootstrapKey_u64> _dfr_fetch_bsk(uint64_t key_id) {
|
||||
return _dfr_node_level_bsk_manager->fetch_key(key_id);
|
||||
}
|
||||
|
||||
/// Look up a keyswitching key by id in this node's keyswitching-key manager
/// and return it wrapped for serialization.
KeyWrapper<LweKeyswitchKey_u64> _dfr_fetch_ksk(uint64_t key_id) {
|
||||
return _dfr_node_level_ksk_manager->fetch_key(key_id);
|
||||
}
|
||||
|
||||
/************************/
|
||||
/* Key management API. */
|
||||
/************************/
|
||||
|
||||
/// Register a bootstrap key under key_id with this node's bootstrap-key
/// manager.
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id) {
|
||||
_dfr_node_level_bsk_manager->register_key(key, key_id);
|
||||
}
|
||||
/// Register a keyswitching key under key_id with this node's
/// keyswitching-key manager.
void _dfr_register_ksk(LweKeyswitchKey_u64 *key, uint64_t key_id) {
|
||||
_dfr_node_level_ksk_manager->register_key(key, key_id);
|
||||
}
|
||||
|
||||
/// Return the bootstrap key with the given id from this node's manager,
/// which fetches it from locality `loc` when it is not available locally.
LweBootstrapKey_u64 *_dfr_get_bsk(hpx::naming::id_type loc, uint64_t key_id) {
|
||||
return _dfr_node_level_bsk_manager->get_key(loc, key_id);
|
||||
}
|
||||
/// Return the keyswitching key with the given id from this node's manager,
/// which fetches it from locality `loc` when it is not available locally.
LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
|
||||
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -16,14 +16,9 @@ typedef void (*wfnptr)(...);
|
||||
|
||||
void *_dfr_make_ready_future(void *);
|
||||
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
|
||||
void _dfr_register_work_function(wfnptr);
|
||||
void *_dfr_await_future(void *);
|
||||
|
||||
/* Keys can have node-local copies which can be retrieved. This
|
||||
should only be called on the node where the key is required. */
|
||||
void _dfr_register_key(void *, size_t, size_t);
|
||||
void _dfr_broadcast_keys();
|
||||
void *_dfr_get_key(size_t);
|
||||
|
||||
/* Memory management:
|
||||
_dfr_make_ready_future allocates the future, not the underlying storage.
|
||||
_dfr_create_async_task allocates both future and storage for outputs. */
|
||||
@@ -36,4 +31,5 @@ void _dfr_stop();
|
||||
|
||||
void _dfr_terminate();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -0,0 +1,92 @@
|
||||
// Part of the Concrete Compiler Project, under the BSD3 License with Zama
|
||||
// Exceptions. See
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/master/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#ifndef CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP
|
||||
#define CONCRETELANG_DFR_WORKFUNCTION_REGISTRY_HPP
|
||||
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <utility>
|
||||
|
||||
#include <hpx/include/runtime.hpp>
|
||||
#include <hpx/modules/collectives.hpp>
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
extern void *dl_handle;
|
||||
struct WorkFunctionRegistry;
|
||||
extern WorkFunctionRegistry *_dfr_node_level_work_function_registry;
|
||||
|
||||
struct WorkFunctionRegistry {
|
||||
WorkFunctionRegistry() { _dfr_node_level_work_function_registry = this; }
|
||||
|
||||
wfnptr getWorkFunctionPointer(const std::string &name) {
|
||||
std::lock_guard<std::recursive_mutex> guard(registry_guard);
|
||||
|
||||
auto fnptrit = name_to_ptr_registry.find(name);
|
||||
if (fnptrit != name_to_ptr_registry.end())
|
||||
return (wfnptr)fnptrit->second;
|
||||
|
||||
auto ptr = dlsym(dl_handle, name.c_str());
|
||||
if (ptr == nullptr) {
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"WorkFunctionRegistry::getWorkFunctionPointer",
|
||||
"Error recovering work function pointer from name.");
|
||||
}
|
||||
registerWorkFunction(ptr, name);
|
||||
return (wfnptr)ptr;
|
||||
}
|
||||
|
||||
std::string getWorkFunctionName(const void *fn) {
|
||||
std::lock_guard<std::recursive_mutex> guard(registry_guard);
|
||||
|
||||
auto fnnameit = ptr_to_name_registry.find(fn);
|
||||
if (fnnameit != ptr_to_name_registry.end())
|
||||
return fnnameit->second;
|
||||
|
||||
Dl_info info;
|
||||
std::string ret;
|
||||
// Assume that if we can't find the name, there is no dynamic
|
||||
// library to find it in. TODO: fix this to distinguish JIT/binary
|
||||
// and in case of distributed exec.
|
||||
if (!dladdr(fn, &info) || info.dli_sname == nullptr) {
|
||||
ret = registerAnonymousWorkFunction(fn);
|
||||
} else {
|
||||
ret = info.dli_sname;
|
||||
registerWorkFunction(fn, ret);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
void registerWorkFunction(const void *fn, std::string name) {
|
||||
std::lock_guard<std::recursive_mutex> guard(registry_guard);
|
||||
|
||||
auto fnnameit = ptr_to_name_registry.find(fn);
|
||||
if (fnnameit == ptr_to_name_registry.end())
|
||||
ptr_to_name_registry.insert(
|
||||
std::pair<const void *, std::string>(fn, name));
|
||||
|
||||
auto fnptrit = name_to_ptr_registry.find(name);
|
||||
if (fnptrit == name_to_ptr_registry.end())
|
||||
name_to_ptr_registry.insert(
|
||||
std::pair<std::string, const void *>(name, fn));
|
||||
}
|
||||
|
||||
std::string registerAnonymousWorkFunction(const void *fn) {
|
||||
std::lock_guard<std::recursive_mutex> guard(registry_guard);
|
||||
static std::atomic<unsigned int> fnid{0};
|
||||
std::string name = "_dfr_jit_wfnname_" + std::to_string(fnid++);
|
||||
registerWorkFunction(fn, name);
|
||||
return name;
|
||||
}
|
||||
|
||||
private:
|
||||
std::recursive_mutex registry_guard;
|
||||
std::map<const void *, std::string> ptr_to_name_registry;
|
||||
std::map<std::string, const void *> name_to_ptr_registry;
|
||||
};
|
||||
|
||||
#endif
|
||||
@@ -14,5 +14,5 @@ add_mlir_library(RTDialectAnalysis
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
RTDialect
|
||||
ConcretelangRuntime
|
||||
)
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <concretelang/Dialect/RT/IR/RTDialect.h>
|
||||
#include <concretelang/Dialect/RT/IR/RTOps.h>
|
||||
#include <concretelang/Dialect/RT/IR/RTTypes.h>
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/math.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
|
||||
@@ -132,7 +133,8 @@ static void replaceAllUsesNotInDFTsInRegionWith(Value orig, Value replacement,
|
||||
}
|
||||
|
||||
// TODO: Fix type sizes. For now we're using some default values.
|
||||
static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
static std::pair<mlir::Value, mlir::Value>
|
||||
getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
DataLayout dataLayout = DataLayout::closest(val.getDefiningOp());
|
||||
Type type = (val.getType().isa<RT::FutureType>())
|
||||
? val.getType().dyn_cast<RT::FutureType>().getElementType()
|
||||
@@ -153,26 +155,61 @@ static mlir::Value getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
Value rank = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_rank));
|
||||
Value sizes_shapes = builder.create<LLVM::MulOp>(loc, rank, multiplier);
|
||||
Value result = builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
|
||||
return result;
|
||||
Value typeSize =
|
||||
builder.create<LLVM::AddOp>(loc, ptrs_offset, sizes_shapes);
|
||||
|
||||
Type elementType = type.dyn_cast<mlir::MemRefType>().getElementType();
|
||||
// Assume here that the base type is a simple scalar-type or at
|
||||
// least its size can be determined.
|
||||
// size_t elementAttr = dataLayout.getTypeSize(elementType);
|
||||
// Make room for a byte to store the type of this argument/output
|
||||
// elementAttr <<= 8;
|
||||
// elementAttr |= _DFR_TASK_ARG_MEMREF;
|
||||
uint64_t elementAttr = 0;
|
||||
size_t element_size = dataLayout.getTypeSize(elementType);
|
||||
elementAttr = _dfr_set_arg_type(elementAttr, _DFR_TASK_ARG_MEMREF);
|
||||
elementAttr = _dfr_set_memref_element_size(elementAttr, element_size);
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(elementAttr));
|
||||
return std::pair<mlir::Value, mlir::Value>(typeSize, arg_type);
|
||||
}
|
||||
|
||||
// Unranked memrefs should be lowered to just pointer + size, so we need 16
|
||||
// bytes.
|
||||
if (type.isa<mlir::UnrankedMemRefType>())
|
||||
return builder.create<arith::ConstantOp>(loc,
|
||||
builder.getI64IntegerAttr(16));
|
||||
if (type.isa<mlir::UnrankedMemRefType>()) {
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_UNRANKED_MEMREF));
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(16));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
}
|
||||
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_BASE));
|
||||
|
||||
// FHE types are converted to pointers, so we take their size as 8
|
||||
// bytes until we can get the actual size of the actual types.
|
||||
if (type.isa<mlir::concretelang::Concrete::ContextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>())
|
||||
return builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextListType>()) {
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
} else if (type.isa<mlir::concretelang::Concrete::ContextType>()) {
|
||||
Value arg_type = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(_DFR_TASK_ARG_CONTEXT));
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
}
|
||||
|
||||
// For all other types, get type size.
|
||||
return builder.create<arith::ConstantOp>(
|
||||
Value result = builder.create<arith::ConstantOp>(
|
||||
loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type)));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
}
|
||||
|
||||
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
@@ -189,9 +226,22 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
builder.setInsertionPoint(DFTOp);
|
||||
for (Value val : DFTOp.getOperands()) {
|
||||
if (!val.getType().isa<RT::FutureType>()) {
|
||||
Type futType = RT::FutureType::get(val.getType());
|
||||
auto mrf =
|
||||
builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType, val);
|
||||
Value newval;
|
||||
// If we are building a future on a MemRef, then we need to clone
|
||||
// the memref in order to allow the deallocation pass which does
|
||||
// not synchronize with task execution.
|
||||
if (val.getType().isa<mlir::MemRefType>() && !val.isa<BlockArgument>()) {
|
||||
newval = builder
|
||||
.create<memref::AllocOp>(
|
||||
DFTOp.getLoc(), val.getType().cast<mlir::MemRefType>())
|
||||
.getResult();
|
||||
builder.create<memref::CopyOp>(DFTOp.getLoc(), val, newval);
|
||||
} else {
|
||||
newval = val;
|
||||
}
|
||||
Type futType = RT::FutureType::get(newval.getType());
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType,
|
||||
newval);
|
||||
map.map(mrf->getResult(0), val);
|
||||
replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
|
||||
}
|
||||
@@ -201,7 +251,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
// DataflowTaskOp. This also includes the necessary handling of
|
||||
// operands and results (conversion to/from futures and propagation).
|
||||
SmallVector<Value, 4> catOperands;
|
||||
int size = 3 + DFTOp.getNumResults() * 2 + DFTOp.getNumOperands() * 2;
|
||||
int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3;
|
||||
catOperands.reserve(size);
|
||||
auto fnptr = builder.create<mlir::func::ConstantOp>(
|
||||
DFTOp.getLoc(), workFunction.getFunctionType(),
|
||||
@@ -214,8 +264,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
catOperands.push_back(numIns.getResult());
|
||||
catOperands.push_back(numOuts.getResult());
|
||||
for (auto operand : DFTOp.getOperands()) {
|
||||
auto op_size = getSizeInBytes(operand, DFTOp.getLoc(), builder);
|
||||
catOperands.push_back(operand);
|
||||
catOperands.push_back(getSizeInBytes(operand, DFTOp.getLoc(), builder));
|
||||
catOperands.push_back(op_size.first);
|
||||
catOperands.push_back(op_size.second);
|
||||
}
|
||||
|
||||
// We need to adjust the results for the CreateAsyncTaskOp which
|
||||
@@ -228,9 +280,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
|
||||
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
|
||||
futType);
|
||||
auto op_size = getSizeInBytes(result, DFTOp.getLoc(), builder);
|
||||
map.map(result, brpp->getResult(0));
|
||||
catOperands.push_back(brpp->getResult(0));
|
||||
catOperands.push_back(getSizeInBytes(result, DFTOp.getLoc(), builder));
|
||||
catOperands.push_back(op_size.first);
|
||||
catOperands.push_back(op_size.second);
|
||||
}
|
||||
builder.create<RT::CreateAsyncTaskOp>(
|
||||
DFTOp.getLoc(),
|
||||
@@ -272,6 +326,18 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
DFTOp.erase();
|
||||
}
|
||||
|
||||
// Inserts, at the very start of the first block of `parentFunc`, a constant
// referencing `workFunction` by symbol followed by an
// RT::RegisterTaskWorkFunctionOp that takes this constant as operand.
static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) {
  OpBuilder builder(parentFunc.body());
  Block &entryBlock = parentFunc.body().front();
  builder.setInsertionPointToStart(&entryBlock);

  Location loc = parentFunc.getLoc();
  auto workFunctionSymbol =
      SymbolRefAttr::get(builder.getContext(), workFunction.getName());
  auto fnptr = builder.create<mlir::ConstantOp>(loc, workFunction.getType(),
                                                workFunctionSymbol);

  builder.create<RT::RegisterTaskWorkFunctionOp>(loc, fnptr.getResult());
}
|
||||
|
||||
/// For documentation see Autopar.td
|
||||
struct LowerDataflowTasksPass
|
||||
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
|
||||
@@ -306,6 +372,33 @@ struct LowerDataflowTasksPass
|
||||
for (auto mapping : outliningMap)
|
||||
lowerDataflowTaskOp(mapping.first, mapping.second);
|
||||
|
||||
// If this is a JIT invocation and we're not on the root node,
|
||||
// we do not need to do any computation, only register all work
|
||||
// functions with the runtime system
|
||||
if (!outliningMap.empty()) {
|
||||
if (!_dfr_is_root_node()) {
|
||||
// auto regFunc = builder.create<FuncOp>(func.getLoc(),
|
||||
// func.getName(), func.getType());
|
||||
|
||||
func.eraseBody();
|
||||
Block *b = new Block;
|
||||
b->addArguments(func.getType().getInputs());
|
||||
func.body().push_front(b);
|
||||
for (int i = func.getType().getNumInputs() - 1; i >= 0; --i)
|
||||
func.eraseArgument(i);
|
||||
for (int i = func.getType().getNumResults() - 1; i >= 0; --i)
|
||||
func.eraseResult(i);
|
||||
OpBuilder builder(func.body());
|
||||
builder.setInsertionPointToEnd(&func.body().front());
|
||||
builder.create<ReturnOp>(func.getLoc());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate code to register all work-functions with the
|
||||
// runtime.
|
||||
for (auto mapping : outliningMap)
|
||||
registerWorkFunction(func, mapping.second);
|
||||
|
||||
// Issue _dfr_start/stop calls for this function
|
||||
if (!outliningMap.empty()) {
|
||||
OpBuilder builder(func.getBody());
|
||||
|
||||
@@ -171,6 +171,26 @@ struct CreateAsyncTaskOpInterfaceLowering
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct RegisterTaskWorkFunctionOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::RegisterTaskWorkFunctionOp> {
|
||||
using ConvertOpToLLVMPattern<
|
||||
RT::RegisterTaskWorkFunctionOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp,
|
||||
ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
RT::RegisterTaskWorkFunctionOp::Adaptor transformed(operands);
|
||||
|
||||
auto rtwfFuncType =
|
||||
LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
|
||||
auto rtwfFuncOp = getOrInsertFuncOpDecl(
|
||||
rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(rtwfOp, rtwfFuncOp,
|
||||
transformed.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
struct DeallocateFutureOpInterfaceLowering
|
||||
: public ConvertOpToLLVMPattern<RT::DeallocateFutureOp> {
|
||||
using ConvertOpToLLVMPattern<RT::DeallocateFutureOp>::ConvertOpToLLVMPattern;
|
||||
@@ -296,6 +316,7 @@ void mlir::concretelang::populateRTToLLVMConversionPatterns(
|
||||
DerefReturnPtrPlaceholderOpInterfaceLowering,
|
||||
DerefWorkFunctionArgumentPtrPlaceholderOpInterfaceLowering,
|
||||
CreateAsyncTaskOpInterfaceLowering,
|
||||
RegisterTaskWorkFunctionOpInterfaceLowering,
|
||||
DeallocateFutureOpInterfaceLowering,
|
||||
DeallocateFutureDataOpInterfaceLowering,
|
||||
WorkFunctionReturnOpInterfaceLowering>(converter);
|
||||
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#ifdef CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
|
||||
#include <hpx/barrier.hpp>
|
||||
#include <hpx/future.hpp>
|
||||
#include <hpx/hpx_start.hpp>
|
||||
#include <hpx/hpx_suspend.hpp>
|
||||
@@ -22,11 +23,16 @@
|
||||
|
||||
std::vector<GenericComputeClient> gcc;
|
||||
void *dl_handle;
|
||||
PbsKeyManager *node_level_key_manager;
|
||||
WorkFunctionRegistry *node_level_work_function_registry;
|
||||
|
||||
KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
|
||||
KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
|
||||
|
||||
WorkFunctionRegistry *_dfr_node_level_work_function_registry;
|
||||
std::list<void *> new_allocated;
|
||||
std::list<void *> fut_allocated;
|
||||
std::list<void *> m_allocated;
|
||||
hpx::lcos::barrier *_dfr_jit_workfunction_registration_barrier;
|
||||
hpx::lcos::barrier *_dfr_jit_phase_barrier;
|
||||
std::atomic<uint64_t> init_guard = {0};
|
||||
|
||||
using namespace hpx;
|
||||
@@ -52,6 +58,17 @@ void _dfr_deallocate_future(void *in) {
|
||||
delete (static_cast<hpx::shared_future<void *> *>(in));
|
||||
}
|
||||
|
||||
// Determine where new task should run. For now just round-robin
|
||||
// distribution - TODO: optimise.
|
||||
static inline size_t _dfr_find_next_execution_locality() {
|
||||
static size_t num_nodes = hpx::get_num_localities().get();
|
||||
static std::atomic<std::size_t> next_locality{0};
|
||||
|
||||
size_t next_loc = ++next_locality;
|
||||
|
||||
return next_loc % num_nodes;
|
||||
}
|
||||
|
||||
/// Runtime generic async_task. Each of the first NUM_PARAMS triples of
/// arguments in the variadic list corresponds to a void* pointer on a
/// hpx::future<void*>, the size of the data within the future, and a
/// 64-bit descriptor of the argument's type. After
|
||||
@@ -60,28 +77,39 @@ void _dfr_deallocate_future(void *in) {
|
||||
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
...) {
|
||||
std::vector<void *> params;
|
||||
std::vector<void *> outputs;
|
||||
std::vector<size_t> param_sizes;
|
||||
std::vector<uint64_t> param_types;
|
||||
std::vector<void *> outputs;
|
||||
std::vector<size_t> output_sizes;
|
||||
std::vector<uint64_t> output_types;
|
||||
|
||||
va_list args;
|
||||
va_start(args, num_outputs);
|
||||
for (size_t i = 0; i < num_params; ++i) {
|
||||
params.push_back(va_arg(args, void *));
|
||||
param_sizes.push_back(va_arg(args, size_t));
|
||||
param_sizes.push_back(va_arg(args, uint64_t));
|
||||
param_types.push_back(va_arg(args, uint64_t));
|
||||
}
|
||||
for (size_t i = 0; i < num_outputs; ++i) {
|
||||
outputs.push_back(va_arg(args, void *));
|
||||
output_sizes.push_back(va_arg(args, size_t));
|
||||
output_sizes.push_back(va_arg(args, uint64_t));
|
||||
output_types.push_back(va_arg(args, uint64_t));
|
||||
}
|
||||
va_end(args);
|
||||
|
||||
for (size_t i = 0; i < num_params; ++i) {
|
||||
if (_dfr_get_arg_type(param_types[i] == _DFR_TASK_ARG_MEMREF)) {
|
||||
m_allocated.push_back(
|
||||
(void *)static_cast<StridedMemRefType<char, 1> *>(params[i])->data);
|
||||
}
|
||||
}
|
||||
|
||||
// We pass functions by name - which is not strictly necessary in
|
||||
// shared memory as pointers suffice, but is needed in the
|
||||
// distributed case where the functions need to be located/loaded on
|
||||
// the node.
|
||||
auto wfnname =
|
||||
node_level_work_function_registry->getWorkFunctionName((void *)wfn);
|
||||
_dfr_node_level_work_function_registry->getWorkFunctionName((void *)wfn);
|
||||
hpx::future<hpx::future<OpaqueOutputData>> oodf;
|
||||
|
||||
// In order to allow complete dataflow semantics for
|
||||
@@ -93,20 +121,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
switch (num_params) {
|
||||
case 0:
|
||||
oodf = std::move(
|
||||
hpx::dataflow([wfnname, param_sizes,
|
||||
output_sizes]() -> hpx::future<OpaqueOutputData> {
|
||||
hpx::dataflow([wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types]() -> hpx::future<OpaqueOutputData> {
|
||||
std::vector<void *> params = {};
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
|
||||
output_sizes, output_types);
|
||||
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
|
||||
}));
|
||||
break;
|
||||
|
||||
case 1:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0)
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0)
|
||||
-> hpx::future<OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get()};
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
|
||||
output_sizes, output_types);
|
||||
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0]));
|
||||
@@ -114,11 +145,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
|
||||
case 2:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1)
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1)
|
||||
-> hpx::future<OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get()};
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
|
||||
output_sizes, output_types);
|
||||
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
@@ -127,13 +160,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
|
||||
case 3:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, output_sizes](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2)
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2)
|
||||
-> hpx::future<OpaqueOutputData> {
|
||||
std::vector<void *> params = {param0.get(), param1.get(),
|
||||
param2.get()};
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, output_sizes);
|
||||
OpaqueInputData oid(wfnname, params, param_sizes, param_types,
|
||||
output_sizes, output_types);
|
||||
return gcc[_dfr_find_next_execution_locality()].execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
@@ -642,25 +677,44 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
}
|
||||
}
|
||||
|
||||
/***************************/
|
||||
/* JIT execution support. */
|
||||
/***************************/
|
||||
static inline bool _dfr_is_root_node_impl() {
|
||||
static bool is_root_node_p = (hpx::find_here() == hpx::find_root_locality());
|
||||
return is_root_node_p;
|
||||
}
|
||||
|
||||
bool _dfr_is_root_node() { return _dfr_is_root_node_impl(); }
|
||||
|
||||
// Registers the work function `wfn` with the node-level work-function
// registry under a generated (anonymous) name. The generated name is not
// returned to the caller.
void _dfr_register_work_function(wfnptr wfn) {
  const void *work_function = (void *)wfn;
  _dfr_node_level_work_function_registry->registerAnonymousWorkFunction(
      work_function);
}
|
||||
|
||||
/********************************/
|
||||
/* Distributed key management. */
|
||||
/********************************/
|
||||
void _dfr_register_key(void *key, size_t key_id, size_t size) {
|
||||
node_level_key_manager->register_key(key, key_id, size);
|
||||
}
|
||||
|
||||
void _dfr_broadcast_keys() { node_level_key_manager->broadcast_keys(); }
|
||||
|
||||
void *_dfr_get_key(size_t key_id) {
|
||||
return *node_level_key_manager->get_key(key_id).key.get();
|
||||
}
|
||||
|
||||
/************************************/
|
||||
/* Initialization & Finalization. */
|
||||
/************************************/
|
||||
/* Runtime initialization and finalization. */
|
||||
// TODO: need to set a flag for when executing in JIT to allow remote
|
||||
// nodes to execute generated function that registers work functions.
|
||||
// This also means that compute nodes need to go through all the
|
||||
// phases of computation synchronized with the root node.
|
||||
// Holds the "running under JIT" flag and returns its current value. The flag
// is initialised from the argument of the very first call; afterwards a call
// with `true` switches it on, and nothing in this function switches it off
// again.
static inline bool _dfr_is_jit_impl(bool is_jit = false) {
  static bool jit_enabled = is_jit;
  if (!jit_enabled && is_jit)
    jit_enabled = true;
  return jit_enabled;
}

// Setter overload: passing `true` turns the flag on, passing `false` leaves
// it unchanged.
void _dfr_is_jit(bool is_jit) { (void)_dfr_is_jit_impl(is_jit); }

// Getter overload: reports the current value of the flag.
bool _dfr_is_jit() { return _dfr_is_jit_impl(false); }
|
||||
|
||||
static inline void _dfr_stop_impl() {
|
||||
if (_dfr_is_root_node())
|
||||
if (_dfr_is_root_node_impl())
|
||||
hpx::apply([]() { hpx::finalize(); });
|
||||
hpx::stop();
|
||||
}
|
||||
@@ -701,10 +755,16 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
}
|
||||
|
||||
// Instantiate on each node
|
||||
new PbsKeyManager();
|
||||
new KeyManager<LweBootstrapKey_u64>();
|
||||
new KeyManager<LweKeyswitchKey_u64>();
|
||||
new WorkFunctionRegistry();
|
||||
_dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier(
|
||||
"wait_register_remote_work_functions", hpx::get_num_localities().get(),
|
||||
hpx::get_locality_id());
|
||||
_dfr_jit_phase_barrier = new hpx::lcos::barrier(
|
||||
"phase_barrier", hpx::get_num_localities().get(), hpx::get_locality_id());
|
||||
|
||||
if (_dfr_is_root_node()) {
|
||||
if (_dfr_is_root_node_impl()) {
|
||||
// Create compute server components on each node - from the root
|
||||
// node only - and the corresponding compute client on the root
|
||||
// node.
|
||||
@@ -712,9 +772,6 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
gcc = hpx::new_<GenericComputeClient[]>(
|
||||
hpx::default_layout(hpx::find_all_localities()), num_nodes)
|
||||
.get();
|
||||
} else {
|
||||
hpx::stop();
|
||||
exit(EXIT_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -722,16 +779,36 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
JIT invocation). These serve to pause/resume the runtime
|
||||
scheduler and to clean up used resources. */
|
||||
void _dfr_start() {
|
||||
uint64_t uninitialised = 0;
|
||||
if (init_guard.compare_exchange_strong(uninitialised, 1))
|
||||
_dfr_start_impl(0, nullptr);
|
||||
else
|
||||
hpx::resume();
|
||||
hpx::resume();
|
||||
|
||||
if (!_dfr_is_jit())
|
||||
_dfr_stop_impl();
|
||||
|
||||
// TODO: conditional -- If this is the root node, and this is JIT
|
||||
// execution, we need to wait for the compute nodes to compile and
|
||||
// register work functions
|
||||
if (_dfr_is_root_node_impl() && _dfr_is_jit()) {
|
||||
_dfr_jit_workfunction_registration_barrier->wait();
|
||||
}
|
||||
}
|
||||
|
||||
void _dfr_stop() {
|
||||
if (!_dfr_is_root_node_impl() /*&& _dfr_is_jit() /** implicitly true*/) {
|
||||
_dfr_jit_workfunction_registration_barrier->wait();
|
||||
}
|
||||
|
||||
// The barrier is only needed to synchronize the different
|
||||
// computation phases when the compute nodes need to generate and
|
||||
// register new work functions in each phase.
|
||||
if (_dfr_is_jit()) {
|
||||
_dfr_jit_phase_barrier->wait();
|
||||
}
|
||||
|
||||
hpx::suspend();
|
||||
|
||||
_dfr_node_level_bsk_manager->clear_keys();
|
||||
_dfr_node_level_ksk_manager->clear_keys();
|
||||
|
||||
while (!new_allocated.empty()) {
|
||||
delete[] static_cast<char *>(new_allocated.front());
|
||||
new_allocated.pop_front();
|
||||
@@ -758,7 +835,7 @@ void _dfr_terminate() {
|
||||
/* Main wrapper. */
|
||||
/*******************/
|
||||
extern "C" {
|
||||
extern int main(int argc, char *argv[]); // __attribute__((weak));
|
||||
extern int main(int argc, char *argv[]) __attribute__((weak));
|
||||
extern int __real_main(int argc, char *argv[]) __attribute__((weak));
|
||||
int __wrap_main(int argc, char *argv[]) {
|
||||
int r;
|
||||
@@ -790,7 +867,7 @@ size_t _dfr_debug_get_node_id() { return hpx::get_locality_id(); }
|
||||
|
||||
size_t _dfr_debug_get_worker_id() { return hpx::get_worker_thread_num(); }
|
||||
|
||||
void _dfr_debug_print_task(const char *name, int inputs, int outputs) {
|
||||
void _dfr_debug_print_task(const char *name, size_t inputs, size_t outputs) {
|
||||
// clang-format off
|
||||
hpx::cout << "Task \"" << name << "\""
|
||||
<< " [" << inputs << " inputs, " << outputs << " outputs]"
|
||||
@@ -806,7 +883,10 @@ void _dfr_print_debug(size_t val) {
|
||||
|
||||
#else // CONCRETELANG_PARALLEL_EXECUTION_ENABLED
|
||||
|
||||
#include <concretelang/Runtime/runtime_api.h>
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
|
||||
// Stubs used when the runtime is built without parallel-execution support.

// Without the distributed runtime there is a single node, which is reported
// as the root.
bool _dfr_is_root_node() { return true; }

// The JIT flag is not tracked in this configuration; the request is ignored.
void _dfr_is_jit(bool) {}

// Getter counterpart of the setter above. The header declares it, so it needs
// a definition in this configuration too; since the flag is never stored it
// always reports false.
bool _dfr_is_jit() { return false; }

// Nothing to shut down in this configuration.
void _dfr_terminate() {}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt
|
||||
// for license information.
|
||||
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/JITSupport.h>
|
||||
#include <llvm/Support/TargetSelect.h>
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
@@ -48,8 +49,11 @@ JITSupport::compile(llvm::SourceMgr &program, CompilationOptions options) {
|
||||
// Mark the lambda as compiled using DF parallelization
|
||||
result->lambda->setUseDataflow(options.dataflowParallelize ||
|
||||
options.autoParallelize);
|
||||
result->clientParameters =
|
||||
compilationResult.get().clientParameters.getValue();
|
||||
if (!mlir::concretelang::dfr::_dfr_is_root_node())
|
||||
result->clientParameters = clientlib::ClientParameters();
|
||||
else
|
||||
result->clientParameters =
|
||||
compilationResult.get().clientParameters.getValue();
|
||||
return std::move(result);
|
||||
}
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h>
|
||||
|
||||
#include "concretelang/Common/BitsSize.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
#include <concretelang/Support/Error.h>
|
||||
#include <concretelang/Support/Jit.h>
|
||||
#include <concretelang/Support/logging.h>
|
||||
@@ -89,6 +90,19 @@ JITLambda::call(clientlib::PublicArguments &args,
|
||||
"call: current runtime doesn't support dataflow execution, while "
|
||||
"compilation used dataflow parallelization");
|
||||
}
|
||||
#else
|
||||
dfr::_dfr_set_jit(true);
|
||||
// When using JIT on distributed systems, the compiler only
|
||||
// generates work-functions and their registration calls. No results
|
||||
// are returned and no inputs are needed.
|
||||
if (!dfr::_dfr_is_root_node()) {
|
||||
std::vector<void *> rawArgs;
|
||||
if (auto err = invokeRaw(rawArgs)) {
|
||||
return std::move(err);
|
||||
}
|
||||
std::vector<clientlib::TensorData> buffers;
|
||||
return clientlib::PublicResult::fromBuffers(args.clientParameters, buffers);
|
||||
}
|
||||
#endif
|
||||
// invokeRaw needs to have pointers on arguments and a pointers on the result
|
||||
// as last argument.
|
||||
|
||||
@@ -31,7 +31,7 @@
|
||||
#include "concretelang/Dialect/RT/IR/RTDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHEDialect.h"
|
||||
#include "concretelang/Dialect/TFHE/IR/TFHETypes.h"
|
||||
#include "concretelang/Runtime/runtime_api.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Support/CompilerEngine.h"
|
||||
#include "concretelang/Support/Error.h"
|
||||
#include "concretelang/Support/JITSupport.h"
|
||||
|
||||
Reference in New Issue
Block a user