feat(dfr): add memory management for futures and associated data.

This commit is contained in:
Antoniu Pop
2022-06-08 21:58:33 +01:00
committed by Antoniu Pop
parent b405a2daf2
commit fbca52f4a0
14 changed files with 1047 additions and 348 deletions

View File

@@ -19,7 +19,6 @@ extern "C" {
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include <concretelang/Runtime/DFRuntime.hpp>
namespace concretelang {
namespace clientlib {
@@ -86,9 +85,15 @@ public:
}
/// Build the evaluation keys (keyswitch + bootstrap) held by this key set.
///
/// Looks up the "ksk_v0" and "bsk_v0" entries. If either one is absent a
/// default-constructed EvaluationKeys is returned instead of throwing, so
/// callers can obtain evaluation keys from a key set that has none.
///
/// Previously an `.at()`-based lookup (which throws std::out_of_range on a
/// missing entry) returned unconditionally ahead of this guarded lookup,
/// which left the guarded path unreachable.
EvaluationKeys evaluationKeys() {
  const auto kskIt = this->keyswitchKeys.find("ksk_v0");
  const auto bskIt = this->bootstrapKeys.find("bsk_v0");
  if (kskIt == this->keyswitchKeys.end() ||
      bskIt == this->bootstrapKeys.end()) {
    // At least one key is missing: hand back empty evaluation keys.
    return EvaluationKeys();
  }
  auto sharedKsk = std::get<1>(kskIt->second);
  auto sharedBsk = std::get<1>(bskIt->second);
  return EvaluationKeys(sharedKsk, sharedBsk);
}
const std::map<LweSecretKeyID,

View File

@@ -23,6 +23,8 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createBufferizeDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createFixupBufferDeallocationPass(bool debug = false);
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns);
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,

View File

@@ -82,6 +82,15 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
}];
}
// Declaration of the "FixupBufferDeallocation" pass, anchored on
// mlir::ModuleOp. The summary and description strings below are the pass's
// user-facing documentation.
// NOTE(review): presumably instantiated through
// createFixupBufferDeallocationPass() - confirm against the pass
// implementation, which is not visible here.
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
let summary =
"Prevent deallocation of buffers returned as futures by tasks.";
let description = [{ This pass removes buffer deallocation calls on
buffers being used for dataflow communication between
tasks. These buffers cannot be deallocated directly without
synchronization as they can be needed by asynchronous
computation. Instead, these will be deallocated by the runtime
when no longer needed.}]; }
#endif

View File

@@ -6,10 +6,12 @@
#ifndef CONCRETELANG_DIALECT_RT_IR_RTOPS_H
#define CONCRETELANG_DIALECT_RT_IR_RTOPS_H
#include <mlir/Dialect/Bufferization/IR/AllocationOpInterface.h>
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/DataLayoutInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/RT/IR/RTTypes.h"

View File

@@ -1,9 +1,10 @@
#ifndef CONCRETELANG_DIALECT_RT_IR_RT_OPS
#define CONCRETELANG_DIALECT_RT_IR_RT_OPS
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -15,9 +16,13 @@ class RT_Op<string mnemonic, list<Trait> traits = []> :
def RT_DataflowTaskOp : RT_Op<"dataflow_task", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"DataflowYieldOp">]> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>,
SingleBlockImplicitTerminator<"DataflowYieldOp">,
AutomaticAllocationScope] > {
let arguments = (ins Variadic<AnyType>: $inputs);
let results = (outs Variadic<AnyType>:$outputs);
let results = (outs Variadic<AnyType>: $outputs);
let regions = (region AnyRegion:$body);
@@ -85,8 +90,12 @@ Example:
}];
}
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future"> {
let arguments = (ins AnyType: $input);
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>]> {
let arguments = (ins AnyType: $input,
AnyType: $memrefCloned);
let results = (outs RT_Future: $output);
let summary = "Build a ready future.";
let description = [{
@@ -115,14 +124,27 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
let summary = "Create a dataflow task.";
}
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
// `register_task_work_function` op: takes a variadic list of operands of any
// type and produces no results.
def RT_RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
let arguments = (ins Variadic<AnyType>:$list);
let results = (outs );
let summary = "Register the task work-function with the runtime system.";
}
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
// `clone_future` op: takes one RT_Future operand and yields one RT_Future
// result. It declares the MemoryEffectsOpInterface methods and the
// AllocationOpInterface methods `buildDealloc` / `buildClone`; their
// definitions live outside this file.
def RT_CloneFutureOp : RT_Op<"clone_future",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>] > {
// Convenience builder: the result type is taken from the type of `input`.
let builders = [
OpBuilder<(ins "Value": $input), [{
return build($_builder, $_state, input.getType(), input);
}]>];
let arguments = (ins RT_Future: $input);
let results = (outs RT_Future: $output);
}
// `deallocate_future` op: takes a single operand of any type and produces no
// results.
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
let arguments = (ins AnyType: $input);
let results = (outs );
}

View File

@@ -8,6 +8,7 @@
#include <cstdarg>
#include <cstdlib>
#include <malloc.h>
#include <string>
#include <hpx/async_colocated/get_colocation_id.hpp>
@@ -28,6 +29,7 @@
#include <mlir/ExecutionEngine/CRunnerUtils.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Runtime/dfr_debug_interface.h"
@@ -49,6 +51,17 @@ static inline size_t _dfr_get_memref_rank(size_t size) {
(2 * sizeof(int64_t) /*size&stride/rank*/);
}
// Allocate `size` bytes aligned on `align` with posix_memalign and store the
// resulting pointer in `*out`. An ENOMEM or EINVAL result is reported through
// HPX_THROW_EXCEPTION; any other result leaves the function silently.
// Memory obtained here comes from posix_memalign, so it has to be released
// with free().
static inline void _dfr_checked_aligned_alloc(void **out, size_t align,
                                              size_t size) {
  switch (posix_memalign(out, align, size)) {
  case ENOMEM:
    HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
                        "Error: insufficient memory available.");
    break;
  case EINVAL:
    HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
                        "Error: invalid memory alignment.");
    break;
  default:
    break;
  }
}
struct OpaqueInputData {
OpaqueInputData() = default;
@@ -62,7 +75,7 @@ struct OpaqueInputData {
param_types(std::move(_param_types)),
output_sizes(std::move(_output_sizes)),
output_types(std::move(_output_types)), alloc_p(_alloc_p),
source_locality(hpx::find_here()) {}
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
OpaqueInputData(const OpaqueInputData &oid)
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
@@ -70,16 +83,18 @@ struct OpaqueInputData {
param_types(std::move(oid.param_types)),
output_sizes(std::move(oid.output_sizes)),
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
source_locality(oid.source_locality) {}
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
bsk_id(oid.bsk_id) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar >> wfn_name;
ar >> param_sizes >> param_types;
ar >> output_sizes >> output_types;
ar >> source_locality;
for (size_t p = 0; p < param_sizes.size(); ++p) {
char *param = new char[param_sizes[p]];
new_allocated.push_back((void *)param);
char *param;
_dfr_checked_aligned_alloc((void **)&param, 64, param_sizes[p]);
ar >> hpx::serialization::make_array(param, param_sizes[p]);
params.push_back((void *)param);
@@ -95,28 +110,23 @@ struct OpaqueInputData {
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
char *data;
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
uint64_t bsk_id, ksk_id;
ar >> bsk_id >> ksk_id >> source_locality;
ar >> bsk_id >> ksk_id;
mlir::concretelang::RuntimeContext *context =
new mlir::concretelang::RuntimeContext;
new_allocated.push_back((void *)context);
mlir::concretelang::RuntimeContext **_context =
new mlir::concretelang::RuntimeContext *[1];
new_allocated.push_back((void *)_context);
_context[0] = context;
context->bsk = (LweBootstrapKey_u64 *)bsk_id;
context->ksk = (LweKeyswitchKey_u64 *)ksk_id;
params[p] = (void *)_context;
delete ((char *)params[p]);
// TODO: this might be relaxed with newer versions of HPX.
// Do not set the context here as remote operations are
// unstable when initiated within a HPX helper thread.
params[p] =
(void *)
_dfr_node_level_runtime_context_manager->getContextAddress();
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -131,6 +141,7 @@ struct OpaqueInputData {
ar << wfn_name;
ar << param_sizes << param_types;
ar << output_sizes << output_types;
ar << source_locality;
for (size_t p = 0; p < params.size(); ++p) {
// Save the first level of the data structure - if the parameter
// is a tensor/memref, there is a second level.
@@ -152,18 +163,16 @@ struct OpaqueInputData {
case _DFR_TASK_ARG_CONTEXT: {
mlir::concretelang::RuntimeContext *context =
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
LweKeyswitchKey_u64 *ksk = context->ksk;
LweBootstrapKey_u64 *bsk = context->bsk;
// TODO: find better unique identifiers. This is not a
// correctness issue, but performance.
uint64_t bsk_id = (uint64_t)bsk;
uint64_t ksk_id = (uint64_t)ksk;
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
_dfr_register_bsk(bsk, bsk_id);
_dfr_register_ksk(ksk, ksk_id);
ar << bsk_id << ksk_id << source_locality;
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
<< (uint64_t)bsk << "\n"
<< std::flush;
_dfr_register_bsk(bsk, (uint64_t)bsk);
_dfr_register_ksk(ksk, (uint64_t)ksk);
ar << (uint64_t)bsk << (uint64_t)ksk;
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -182,6 +191,8 @@ struct OpaqueInputData {
std::vector<uint64_t> output_types;
bool alloc_p = false;
hpx::naming::id_type source_locality;
uint64_t ksk_id;
uint64_t bsk_id;
};
struct OpaqueOutputData {
@@ -200,8 +211,9 @@ struct OpaqueOutputData {
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar >> output_sizes >> output_types;
for (size_t p = 0; p < output_sizes.size(); ++p) {
char *output = new char[output_sizes[p]];
new_allocated.push_back((void *)output);
char *output;
_dfr_checked_aligned_alloc((void **)&output, 64, (output_sizes[p]));
ar >> hpx::serialization::make_array(output, output_sizes[p]);
outputs.push_back((void *)output);
@@ -217,8 +229,8 @@ struct OpaqueOutputData {
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
char *data;
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
@@ -283,23 +295,20 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.wfn_name);
std::vector<void *> outputs;
if (!_dfr_is_root_node()) {
for (size_t p = 0; p < inputs.params.size(); ++p) {
if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) {
mlir::concretelang::RuntimeContext *ctx =
(*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
break;
}
}
if (inputs.source_locality != hpx::find_here() &&
(inputs.ksk_id || inputs.bsk_id)) {
_dfr_node_level_runtime_context_manager->getContext(
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
}
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
inputs.output_sizes.size());
hpx::cout << std::flush;
switch (inputs.output_sizes.size()) {
case 1: {
void *output = (void *)(new char[inputs.output_sizes[0]]);
void *output;
_dfr_checked_aligned_alloc(&output, 512, inputs.output_sizes[0]);
switch (inputs.params.size()) {
case 0:
wfn(output);
@@ -387,18 +396,52 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output};
new_allocated.push_back(output);
break;
}
case 2: {
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
void *output1, *output2;
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
switch (inputs.params.size()) {
case 0:
wfn(output1, output2);
@@ -491,20 +534,54 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output1, output2};
new_allocated.push_back(output1);
new_allocated.push_back(output2);
break;
}
case 3: {
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
// One aligned buffer per task output. Each buffer must be written to its
// own pointer: allocating into output2 twice would leave output3
// uninitialised and leak the second buffer.
void *output1, *output2, *output3;
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
_dfr_checked_aligned_alloc(&output3, 512, inputs.output_sizes[2]);
switch (inputs.params.size()) {
case 0:
wfn(output1, output2, output3);
@@ -597,15 +674,47 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2, output3);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2, output3);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2, output3);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2, output3);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2, output3);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output1, output2, output3};
new_allocated.push_back(output1);
new_allocated.push_back(output2);
new_allocated.push_back(output3);
break;
}
default:
@@ -613,6 +722,18 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
"Error: number of task outputs not supported.");
}
// Deallocate input data buffers from OID deserialization (load).
// Those buffers are obtained from _dfr_checked_aligned_alloc
// (posix_memalign), so they must be released with free(); delete on them
// is undefined behaviour. Context parameters are skipped: load() points
// them at the node-level context manager, which this task does not own.
if (!_dfr_is_root_node()) {
  for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
    const auto argType = _dfr_get_arg_type(inputs.param_types[p]);
    if (argType == _DFR_TASK_ARG_CONTEXT)
      continue;
    // Memrefs carry a second-level data buffer that is released first.
    if (argType == _DFR_TASK_ARG_MEMREF)
      free(static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
               ->data);
    free(inputs.params[p]);
  }
}
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
std::move(inputs.output_types), inputs.alloc_p);
}

View File

@@ -15,6 +15,7 @@
#include <hpx/modules/serialization.hpp>
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
extern "C" {
#include "concrete-ffi.h"
@@ -25,13 +26,12 @@ namespace concretelang {
namespace dfr {
template <typename T> struct KeyManager;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
static std::list<void *> new_allocated;
static std::list<void *> fut_allocated;
static std::list<void *> m_allocated;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
@@ -66,7 +66,6 @@ void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_bootstrap_key_u64(buffer);
@@ -87,7 +86,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_keyswitching_key_u64(buffer);
@@ -224,6 +222,68 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
}
/************************/
/* Context management. */
/************************/
// Holds a single RuntimeContext together with the ids of the keyswitch and
// bootstrap keys it was built from, and creates that context on first
// request.
struct RuntimeContextManager {
// TODO: this is only ok so long as we don't change keys. Once we
// use multiple keys, should have a map.
RuntimeContext *context;
// Taken in getContext() around the creation of `context`.
std::mutex context_guard;
// Ids of the keys backing `context`; both are 0 until a context exists.
uint64_t ksk_id;
uint64_t bsk_id;
// Starts with no context and no key ids, and publishes this instance
// through the _dfr_node_level_runtime_context_manager global.
RuntimeContextManager() {
ksk_id = 0;
bsk_id = 0;
context = nullptr;
_dfr_node_level_runtime_context_manager = this;
}
// Return the managed context, creating it if it does not exist yet.
//
// ksk / bsk: ids of the keyswitch and bootstrap keys the caller expects.
// source_locality: locality passed to _dfr_get_ksk / _dfr_get_bsk when the
// keys have to be fetched.
//
// When a context already exists, the requested ids are only checked against
// the stored ones by assert. Otherwise both keys are fetched first, then
// context_guard is taken and the context is created unless it has appeared
// in the meantime.
//
// NOTE(review): the initial `context != nullptr` test and both asserts read
// context / ksk_id / bsk_id without holding context_guard - confirm this is
// safe with concurrent callers.
// NOTE(review): the std::cout output in this function looks like debugging
// trace - confirm whether it is meant to stay.
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
hpx::naming::id_type source_locality) {
std::cout << "GetContext on node " << hpx::get_locality_id()
<< " with context " << context << " " << bsk_id << " " << ksk_id
<< "\n"
<< std::flush;
if (context != nullptr) {
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
<< bsk << "\n"
<< std::flush;
assert(ksk == ksk_id && bsk == bsk_id &&
"Context manager can only used with single keys for now.");
} else {
assert(ksk_id == 0 && bsk_id == 0 &&
"Context empty but context manager has key ids.");
// The keys are fetched before the lock is taken.
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
std::lock_guard<std::mutex> guard(context_guard);
// Re-check under the lock: the context may have been created since the
// test above.
if (context == nullptr) {
auto ctx = new RuntimeContext();
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
// The ids are recorded before `context` itself is assigned.
ksk_id = ksk;
bsk_id = bsk;
context = ctx;
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
<< std::flush;
} else {
std::cout << " GOT context after LOCK on node "
<< hpx::get_locality_id() << " with context " << context
<< " " << bsk_id << " " << ksk_id << "\n"
<< std::flush;
}
}
return context;
}
// Address of the `context` member itself (not of the RuntimeContext). The
// pointee is null until getContext() has created the context.
RuntimeContext **getContextAddress() { return &context; }
};
} // namespace dfr
} // namespace concretelang
} // namespace mlir

View File

@@ -14,7 +14,7 @@ extern "C" {
typedef void (*wfnptr)(...);
void *_dfr_make_ready_future(void *);
void *_dfr_make_ready_future(void *, size_t);
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
void _dfr_register_work_function(wfnptr);
void *_dfr_await_future(void *);