mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(dfr): add memory management for futures and associated data.
This commit is contained in:
@@ -19,7 +19,6 @@ extern "C" {
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
@@ -86,9 +85,15 @@ public:
|
||||
}
|
||||
|
||||
/// Builds the evaluation keys from the "ksk_v0" keyswitch-key entry and the
/// "bsk_v0" bootstrap-key entry.
///
/// @return EvaluationKeys constructed from element 1 of both mapped tuples,
///         or a default-constructed EvaluationKeys when either entry is
///         absent from its map.
EvaluationKeys evaluationKeys() {
  // Look the entries up with find() instead of at(): at() throws
  // std::out_of_range on a missing key, whereas an absent key is answered
  // here with an empty EvaluationKeys.
  auto kskIt = this->keyswitchKeys.find("ksk_v0");
  auto bskIt = this->bootstrapKeys.find("bsk_v0");
  if (kskIt == this->keyswitchKeys.end() ||
      bskIt == this->bootstrapKeys.end()) {
    return EvaluationKeys();
  }

  auto sharedKsk = std::get<1>(kskIt->second);
  auto sharedBsk = std::get<1>(bskIt->second);
  return EvaluationKeys(sharedKsk, sharedBsk);
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
|
||||
@@ -23,6 +23,8 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createBufferizeDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createFixupBufferDeallocationPass(bool debug = false);
|
||||
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
|
||||
mlir::RewritePatternSet &patterns);
|
||||
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,
|
||||
|
||||
@@ -82,6 +82,15 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
|
||||
}];
|
||||
}
|
||||
|
||||
// Pass record "FixupBufferDeallocation", declared on mlir::ModuleOp. The
// summary and description strings below are the pass's own documentation.
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Prevent deallocation of buffers returned as futures by tasks.";
|
||||
|
||||
let description = [{ This pass removes buffer deallocation calls on
|
||||
buffers being used for dataflow communication between
|
||||
tasks. These buffers cannot be deallocated directly without
|
||||
synchronization as they can be needed by asynchronous
|
||||
computation. Instead, these will be deallocated by the runtime
|
||||
when no longer needed.}]; }
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
#ifndef CONCRETELANG_DIALECT_RT_IR_RTOPS_H
|
||||
#define CONCRETELANG_DIALECT_RT_IR_RTOPS_H
|
||||
|
||||
#include <mlir/Dialect/Bufferization/IR/AllocationOpInterface.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/DataLayoutInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef CONCRETELANG_DIALECT_RT_IR_RT_OPS
|
||||
#define CONCRETELANG_DIALECT_RT_IR_RT_OPS
|
||||
|
||||
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
@@ -15,9 +16,13 @@ class RT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
|
||||
def RT_DataflowTaskOp : RT_Op<"dataflow_task", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"DataflowYieldOp">]> {
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<AllocationOpInterface,
|
||||
["buildDealloc", "buildClone"]>,
|
||||
SingleBlockImplicitTerminator<"DataflowYieldOp">,
|
||||
AutomaticAllocationScope] > {
|
||||
let arguments = (ins Variadic<AnyType>: $inputs);
|
||||
let results = (outs Variadic<AnyType>:$outputs);
|
||||
let results = (outs Variadic<AnyType>: $outputs);
|
||||
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
@@ -85,8 +90,12 @@ Example:
|
||||
}];
|
||||
}
|
||||
|
||||
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future"> {
|
||||
let arguments = (ins AnyType: $input);
|
||||
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<AllocationOpInterface,
|
||||
["buildDealloc", "buildClone"]>]> {
|
||||
let arguments = (ins AnyType: $input,
|
||||
AnyType: $memrefCloned);
|
||||
let results = (outs RT_Future: $output);
|
||||
let summary = "Build a ready future.";
|
||||
let description = [{
|
||||
@@ -115,14 +124,27 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
|
||||
let summary = "Create a dataflow task.";
|
||||
}
|
||||
|
||||
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
|
||||
def RT_RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
|
||||
let arguments = (ins Variadic<AnyType>:$list);
|
||||
let results = (outs );
|
||||
let summary = "Register the task work-function with the runtime system.";
|
||||
}
|
||||
|
||||
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
// Op with mnemonic "clone_future": one RT_Future operand, one RT_Future
// result. Declares the MemoryEffectsOpInterface methods and the
// AllocationOpInterface methods buildDealloc / buildClone.
def RT_CloneFutureOp : RT_Op<"clone_future", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    DeclareOpInterfaceMethods<AllocationOpInterface,
                              ["buildDealloc", "buildClone"]>]> {
  let arguments = (ins RT_Future: $input);
  let results = (outs RT_Future: $output);

  // Convenience builder: the result type is taken from the operand's type.
  let builders = [
    OpBuilder<(ins "Value": $input), [{
      return build($_builder, $_state, input.getType(), input);
    }]>
  ];
}
|
||||
|
||||
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
let arguments = (ins AnyType: $input);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <malloc.h>
|
||||
#include <string>
|
||||
|
||||
#include <hpx/async_colocated/get_colocation_id.hpp>
|
||||
@@ -28,6 +29,7 @@
|
||||
|
||||
#include <mlir/ExecutionEngine/CRunnerUtils.h>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Runtime/dfr_debug_interface.h"
|
||||
@@ -49,6 +51,17 @@ static inline size_t _dfr_get_memref_rank(size_t size) {
|
||||
(2 * sizeof(int64_t) /*size&stride/rank*/);
|
||||
}
|
||||
|
||||
static inline void _dfr_checked_aligned_alloc(void **out, size_t align,
|
||||
size_t size) {
|
||||
int res = posix_memalign(out, align, size);
|
||||
if (res == ENOMEM)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
|
||||
"Error: insufficient memory available.");
|
||||
if (res == EINVAL)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
|
||||
"Error: invalid memory alignment.");
|
||||
}
|
||||
|
||||
struct OpaqueInputData {
|
||||
OpaqueInputData() = default;
|
||||
|
||||
@@ -62,7 +75,7 @@ struct OpaqueInputData {
|
||||
param_types(std::move(_param_types)),
|
||||
output_sizes(std::move(_output_sizes)),
|
||||
output_types(std::move(_output_types)), alloc_p(_alloc_p),
|
||||
source_locality(hpx::find_here()) {}
|
||||
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
|
||||
|
||||
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
@@ -70,16 +83,18 @@ struct OpaqueInputData {
|
||||
param_types(std::move(oid.param_types)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
|
||||
source_locality(oid.source_locality) {}
|
||||
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
|
||||
bsk_id(oid.bsk_id) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar >> wfn_name;
|
||||
ar >> param_sizes >> param_types;
|
||||
ar >> output_sizes >> output_types;
|
||||
ar >> source_locality;
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
char *param = new char[param_sizes[p]];
|
||||
new_allocated.push_back((void *)param);
|
||||
char *param;
|
||||
_dfr_checked_aligned_alloc((void **)¶m, 64, param_sizes[p]);
|
||||
ar >> hpx::serialization::make_array(param, param_sizes[p]);
|
||||
params.push_back((void *)param);
|
||||
|
||||
@@ -95,28 +110,23 @@ struct OpaqueInputData {
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
char *data;
|
||||
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
  // Only the key identifiers are transferred; they are read into the
  // bsk_id / ksk_id members of this OpaqueInputData.
  ar >> bsk_id >> ksk_id;

  // The first-level slot for this parameter is replaced below. It was
  // allocated with _dfr_checked_aligned_alloc (posix_memalign) in the
  // parameter loop of load(), so it must be released with free(); delete on
  // it is undefined behaviour.
  free(params[p]);
  // TODO: this might be relaxed with newer versions of HPX.
  // Do not set the context here as remote operations are
  // unstable when initiated within a HPX helper thread.
  params[p] = static_cast<void *>(
      _dfr_node_level_runtime_context_manager->getContextAddress());
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -131,6 +141,7 @@ struct OpaqueInputData {
|
||||
ar << wfn_name;
|
||||
ar << param_sizes << param_types;
|
||||
ar << output_sizes << output_types;
|
||||
ar << source_locality;
|
||||
for (size_t p = 0; p < params.size(); ++p) {
|
||||
// Save the first level of the data structure - if the parameter
|
||||
// is a tensor/memref, there is a second level.
|
||||
@@ -152,18 +163,16 @@ struct OpaqueInputData {
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
|
||||
LweKeyswitchKey_u64 *ksk = context->ksk;
|
||||
LweBootstrapKey_u64 *bsk = context->bsk;
|
||||
|
||||
// TODO: find better unique identifiers. This is not a
|
||||
// correctness issue, but performance.
|
||||
uint64_t bsk_id = (uint64_t)bsk;
|
||||
uint64_t ksk_id = (uint64_t)ksk;
|
||||
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
|
||||
_dfr_register_bsk(bsk, bsk_id);
|
||||
_dfr_register_ksk(ksk, ksk_id);
|
||||
ar << bsk_id << ksk_id << source_locality;
|
||||
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
|
||||
<< (uint64_t)bsk << "\n"
|
||||
<< std::flush;
|
||||
_dfr_register_bsk(bsk, (uint64_t)bsk);
|
||||
_dfr_register_ksk(ksk, (uint64_t)ksk);
|
||||
ar << (uint64_t)bsk << (uint64_t)ksk;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -182,6 +191,8 @@ struct OpaqueInputData {
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
hpx::naming::id_type source_locality;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
};
|
||||
|
||||
struct OpaqueOutputData {
|
||||
@@ -200,8 +211,9 @@ struct OpaqueOutputData {
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar >> output_sizes >> output_types;
|
||||
for (size_t p = 0; p < output_sizes.size(); ++p) {
|
||||
char *output = new char[output_sizes[p]];
|
||||
new_allocated.push_back((void *)output);
|
||||
char *output;
|
||||
_dfr_checked_aligned_alloc((void **)&output, 64, (output_sizes[p]));
|
||||
|
||||
ar >> hpx::serialization::make_array(output, output_sizes[p]);
|
||||
outputs.push_back((void *)output);
|
||||
|
||||
@@ -217,8 +229,8 @@ struct OpaqueOutputData {
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
char *data;
|
||||
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
|
||||
@@ -283,23 +295,20 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.wfn_name);
|
||||
std::vector<void *> outputs;
|
||||
|
||||
if (!_dfr_is_root_node()) {
|
||||
for (size_t p = 0; p < inputs.params.size(); ++p) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) {
|
||||
mlir::concretelang::RuntimeContext *ctx =
|
||||
(*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
|
||||
ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
|
||||
ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (inputs.source_locality != hpx::find_here() &&
|
||||
(inputs.ksk_id || inputs.bsk_id)) {
|
||||
_dfr_node_level_runtime_context_manager->getContext(
|
||||
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
|
||||
}
|
||||
|
||||
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
|
||||
inputs.output_sizes.size());
|
||||
hpx::cout << std::flush;
|
||||
|
||||
switch (inputs.output_sizes.size()) {
|
||||
case 1: {
|
||||
void *output = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output;
|
||||
_dfr_checked_aligned_alloc(&output, 512, inputs.output_sizes[0]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output);
|
||||
@@ -387,18 +396,52 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output};
|
||||
new_allocated.push_back(output);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output1, *output2;
|
||||
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2);
|
||||
@@ -491,20 +534,54 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2};
|
||||
new_allocated.push_back(output1);
|
||||
new_allocated.push_back(output2);
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
|
||||
void *output1, *output2, *output3;
|
||||
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[2]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2, output3);
|
||||
@@ -597,15 +674,47 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2, output3);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2, output3);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2, output3);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2, output3);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2, output3);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2, output3};
|
||||
new_allocated.push_back(output1);
|
||||
new_allocated.push_back(output2);
|
||||
new_allocated.push_back(output3);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -613,6 +722,18 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
"Error: number of task outputs not supported.");
|
||||
}
|
||||
|
||||
// Deallocate input data buffers from OID deserialization (load)
|
||||
if (!_dfr_is_root_node()) {
|
||||
for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
|
||||
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
|
||||
->data);
|
||||
delete ((char *)inputs.params[p]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
|
||||
std::move(inputs.output_types), inputs.alloc_p);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
@@ -25,13 +26,12 @@ namespace concretelang {
|
||||
namespace dfr {
|
||||
|
||||
template <typename T> struct KeyManager;
|
||||
struct RuntimeContextManager;
|
||||
namespace {
|
||||
static void *dl_handle;
|
||||
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
|
||||
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
|
||||
static std::list<void *> new_allocated;
|
||||
static std::list<void *> fut_allocated;
|
||||
static std::list<void *> m_allocated;
|
||||
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
} // namespace
|
||||
|
||||
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
|
||||
@@ -66,7 +66,6 @@ void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
@@ -87,7 +86,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
@@ -224,6 +222,68 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
|
||||
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
|
||||
}
|
||||
|
||||
/************************/
|
||||
/* Context management. */
|
||||
/************************/
|
||||
|
||||
struct RuntimeContextManager {
|
||||
// TODO: this is only ok so long as we don't change keys. Once we
|
||||
// use multiple keys, should have a map.
|
||||
RuntimeContext *context;
|
||||
std::mutex context_guard;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
|
||||
RuntimeContextManager() {
|
||||
ksk_id = 0;
|
||||
bsk_id = 0;
|
||||
context = nullptr;
|
||||
_dfr_node_level_runtime_context_manager = this;
|
||||
}
|
||||
|
||||
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
|
||||
hpx::naming::id_type source_locality) {
|
||||
std::cout << "GetContext on node " << hpx::get_locality_id()
|
||||
<< " with context " << context << " " << bsk_id << " " << ksk_id
|
||||
<< "\n"
|
||||
<< std::flush;
|
||||
if (context != nullptr) {
|
||||
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
|
||||
<< bsk << "\n"
|
||||
<< std::flush;
|
||||
assert(ksk == ksk_id && bsk == bsk_id &&
|
||||
"Context manager can only used with single keys for now.");
|
||||
} else {
|
||||
assert(ksk_id == 0 && bsk_id == 0 &&
|
||||
"Context empty but context manager has key ids.");
|
||||
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
|
||||
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
|
||||
std::lock_guard<std::mutex> guard(context_guard);
|
||||
if (context == nullptr) {
|
||||
auto ctx = new RuntimeContext();
|
||||
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
|
||||
ksk_id = ksk;
|
||||
bsk_id = bsk;
|
||||
context = ctx;
|
||||
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
|
||||
<< std::flush;
|
||||
} else {
|
||||
std::cout << " GOT context after LOCK on node "
|
||||
<< hpx::get_locality_id() << " with context " << context
|
||||
<< " " << bsk_id << " " << ksk_id << "\n"
|
||||
<< std::flush;
|
||||
}
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
RuntimeContext **getContextAddress() { return &context; }
|
||||
};
|
||||
|
||||
} // namespace dfr
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -14,7 +14,7 @@ extern "C" {
|
||||
|
||||
typedef void (*wfnptr)(...);
|
||||
|
||||
void *_dfr_make_ready_future(void *);
|
||||
void *_dfr_make_ready_future(void *, size_t);
|
||||
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
|
||||
void _dfr_register_work_function(wfnptr);
|
||||
void *_dfr_await_future(void *);
|
||||
|
||||
Reference in New Issue
Block a user