feat(dfr): add memory management for futures and associated data.

This commit is contained in:
Antoniu Pop
2022-06-08 21:58:33 +01:00
committed by Antoniu Pop
parent b405a2daf2
commit fbca52f4a0
14 changed files with 1047 additions and 348 deletions

View File

@@ -19,7 +19,6 @@ extern "C" {
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/ClientLib/KeySetCache.h"
#include "concretelang/Common/Error.h"
#include <concretelang/Runtime/DFRuntime.hpp>
namespace concretelang {
namespace clientlib {
@@ -86,9 +85,15 @@ public:
}
EvaluationKeys evaluationKeys() {
auto sharedKsk = std::get<1>(this->keyswitchKeys.at("ksk_v0"));
auto sharedBsk = std::get<1>(this->bootstrapKeys.at("bsk_v0"));
return EvaluationKeys(sharedKsk, sharedBsk);
auto kskIt = this->keyswitchKeys.find("ksk_v0");
auto bskIt = this->bootstrapKeys.find("bsk_v0");
if (kskIt != this->keyswitchKeys.end() &&
bskIt != this->bootstrapKeys.end()) {
auto sharedKsk = std::get<1>(kskIt->second);
auto sharedBsk = std::get<1>(bskIt->second);
return EvaluationKeys(sharedKsk, sharedBsk);
}
return EvaluationKeys();
}
const std::map<LweSecretKeyID,

View File

@@ -23,6 +23,8 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createBufferizeDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
std::unique_ptr<mlir::Pass>
createFixupBufferDeallocationPass(bool debug = false);
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
mlir::RewritePatternSet &patterns);
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,

View File

@@ -82,6 +82,15 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
}];
}
// Declares the FixupBufferDeallocation pass: command-line argument
// "FixupBufferDeallocation", anchored on mlir::ModuleOp. The summary and
// description strings below are the user-facing help text of the pass.
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
let summary =
"Prevent deallocation of buffers returned as futures by tasks.";
let description = [{ This pass removes buffer deallocation calls on
buffers being used for dataflow communication between
tasks. These buffers cannot be deallocated directly without
synchronization as they can be needed by asynchronous
computation. Instead, these will be deallocated by the runtime
when no longer needed.}]; }
#endif

View File

@@ -6,10 +6,12 @@
#ifndef CONCRETELANG_DIALECT_RT_IR_RTOPS_H
#define CONCRETELANG_DIALECT_RT_IR_RTOPS_H
#include <mlir/Dialect/Bufferization/IR/AllocationOpInterface.h>
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/BuiltinTypes.h>
#include <mlir/Interfaces/ControlFlowInterfaces.h>
#include <mlir/Interfaces/DataLayoutInterfaces.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include "concretelang/Dialect/RT/IR/RTTypes.h"

View File

@@ -1,9 +1,10 @@
#ifndef CONCRETELANG_DIALECT_RT_IR_RT_OPS
#define CONCRETELANG_DIALECT_RT_IR_RT_OPS
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/DataLayoutInterfaces.td"
@@ -15,9 +16,13 @@ class RT_Op<string mnemonic, list<Trait> traits = []> :
def RT_DataflowTaskOp : RT_Op<"dataflow_task", [
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
SingleBlockImplicitTerminator<"DataflowYieldOp">]> {
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>,
SingleBlockImplicitTerminator<"DataflowYieldOp">,
AutomaticAllocationScope] > {
let arguments = (ins Variadic<AnyType>: $inputs);
let results = (outs Variadic<AnyType>:$outputs);
let results = (outs Variadic<AnyType>: $outputs);
let regions = (region AnyRegion:$body);
@@ -85,8 +90,12 @@ Example:
}];
}
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future"> {
let arguments = (ins AnyType: $input);
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>]> {
let arguments = (ins AnyType: $input,
AnyType: $memrefCloned);
let results = (outs RT_Future: $output);
let summary = "Build a ready future.";
let description = [{
@@ -115,14 +124,27 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
let summary = "Create a dataflow task.";
}
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
def RT_RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
let arguments = (ins Variadic<AnyType>:$list);
let results = (outs );
let summary = "Register the task work-function with the runtime system.";
}
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
// rt.clone_future: takes one RT_Future operand and produces one RT_Future
// result. The op declares the MemoryEffectsOpInterface methods and the
// AllocationOpInterface methods `buildDealloc` and `buildClone`, which are
// implemented in C++.
def RT_CloneFutureOp : RT_Op<"clone_future",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<AllocationOpInterface,
["buildDealloc", "buildClone"]>] > {
// Convenience builder: only the input is given, the result type is taken
// from the input value's type.
let builders = [
OpBuilder<(ins "Value": $input), [{
return build($_builder, $_state, input.getType(), input);
}]>];
let arguments = (ins RT_Future: $input);
let results = (outs RT_Future: $output);
}
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
let arguments = (ins AnyType: $input);
let results = (outs );
}

View File

@@ -8,6 +8,7 @@
#include <cstdarg>
#include <cstdlib>
#include <malloc.h>
#include <string>
#include <hpx/async_colocated/get_colocation_id.hpp>
@@ -28,6 +29,7 @@
#include <mlir/ExecutionEngine/CRunnerUtils.h>
#include "concretelang/ClientLib/EvaluationKeys.h"
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
#include "concretelang/Runtime/dfr_debug_interface.h"
@@ -49,6 +51,17 @@ static inline size_t _dfr_get_memref_rank(size_t size) {
(2 * sizeof(int64_t) /*size&stride/rank*/);
}
/// Allocate `size` bytes aligned on `align` with posix_memalign and store
/// the resulting pointer in `*out`.
///
/// \param out   receives the address of the allocated buffer.
/// \param align requested alignment, forwarded to posix_memalign.
/// \param size  number of bytes to allocate.
///
/// Every allocation failure is reported through HPX_THROW_EXCEPTION, so
/// `*out` is only meaningful when this function returns normally. Buffers
/// obtained here belong to the malloc family and are released with free().
static inline void _dfr_checked_aligned_alloc(void **out, size_t align,
                                              size_t size) {
  const int res = posix_memalign(out, align, size);
  if (res == 0)
    return;
  if (res == ENOMEM)
    HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
                        "Error: insufficient memory available.");
  if (res == EINVAL)
    HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
                        "Error: invalid memory alignment.");
  // Any other non-zero code still means that `*out` was not set: never let
  // the caller continue with an unset pointer.
  HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
                      "Error: aligned memory allocation failed.");
}
struct OpaqueInputData {
OpaqueInputData() = default;
@@ -62,7 +75,7 @@ struct OpaqueInputData {
param_types(std::move(_param_types)),
output_sizes(std::move(_output_sizes)),
output_types(std::move(_output_types)), alloc_p(_alloc_p),
source_locality(hpx::find_here()) {}
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
OpaqueInputData(const OpaqueInputData &oid)
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
@@ -70,16 +83,18 @@ struct OpaqueInputData {
param_types(std::move(oid.param_types)),
output_sizes(std::move(oid.output_sizes)),
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
source_locality(oid.source_locality) {}
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
bsk_id(oid.bsk_id) {}
friend class hpx::serialization::access;
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar >> wfn_name;
ar >> param_sizes >> param_types;
ar >> output_sizes >> output_types;
ar >> source_locality;
for (size_t p = 0; p < param_sizes.size(); ++p) {
char *param = new char[param_sizes[p]];
new_allocated.push_back((void *)param);
char *param;
_dfr_checked_aligned_alloc((void **)&param, 64, param_sizes[p]);
ar >> hpx::serialization::make_array(param, param_sizes[p]);
params.push_back((void *)param);
@@ -95,28 +110,23 @@ struct OpaqueInputData {
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
char *data;
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
} break;
case _DFR_TASK_ARG_CONTEXT: {
uint64_t bsk_id, ksk_id;
ar >> bsk_id >> ksk_id >> source_locality;
ar >> bsk_id >> ksk_id;
mlir::concretelang::RuntimeContext *context =
new mlir::concretelang::RuntimeContext;
new_allocated.push_back((void *)context);
mlir::concretelang::RuntimeContext **_context =
new mlir::concretelang::RuntimeContext *[1];
new_allocated.push_back((void *)_context);
_context[0] = context;
context->bsk = (LweBootstrapKey_u64 *)bsk_id;
context->ksk = (LweKeyswitchKey_u64 *)ksk_id;
params[p] = (void *)_context;
delete ((char *)params[p]);
// TODO: this might be relaxed with newer versions of HPX.
// Do not set the context here as remote operations are
// unstable when initiated within a HPX helper thread.
params[p] =
(void *)
_dfr_node_level_runtime_context_manager->getContextAddress();
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -131,6 +141,7 @@ struct OpaqueInputData {
ar << wfn_name;
ar << param_sizes << param_types;
ar << output_sizes << output_types;
ar << source_locality;
for (size_t p = 0; p < params.size(); ++p) {
// Save the first level of the data structure - if the parameter
// is a tensor/memref, there is a second level.
@@ -152,18 +163,16 @@ struct OpaqueInputData {
case _DFR_TASK_ARG_CONTEXT: {
mlir::concretelang::RuntimeContext *context =
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
LweKeyswitchKey_u64 *ksk = context->ksk;
LweBootstrapKey_u64 *bsk = context->bsk;
// TODO: find better unique identifiers. This is not a
// correctness issue, but performance.
uint64_t bsk_id = (uint64_t)bsk;
uint64_t ksk_id = (uint64_t)ksk;
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
_dfr_register_bsk(bsk, bsk_id);
_dfr_register_ksk(ksk, ksk_id);
ar << bsk_id << ksk_id << source_locality;
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
<< (uint64_t)bsk << "\n"
<< std::flush;
_dfr_register_bsk(bsk, (uint64_t)bsk);
_dfr_register_ksk(ksk, (uint64_t)ksk);
ar << (uint64_t)bsk << (uint64_t)ksk;
} break;
case _DFR_TASK_ARG_UNRANKED_MEMREF:
default:
@@ -182,6 +191,8 @@ struct OpaqueInputData {
std::vector<uint64_t> output_types;
bool alloc_p = false;
hpx::naming::id_type source_locality;
uint64_t ksk_id;
uint64_t bsk_id;
};
struct OpaqueOutputData {
@@ -200,8 +211,9 @@ struct OpaqueOutputData {
template <class Archive> void load(Archive &ar, const unsigned int version) {
ar >> output_sizes >> output_types;
for (size_t p = 0; p < output_sizes.size(); ++p) {
char *output = new char[output_sizes[p]];
new_allocated.push_back((void *)output);
char *output;
_dfr_checked_aligned_alloc((void **)&output, 64, (output_sizes[p]));
ar >> hpx::serialization::make_array(output, output_sizes[p]);
outputs.push_back((void *)output);
@@ -217,8 +229,8 @@ struct OpaqueOutputData {
for (size_t r = 0; r < rank; ++r)
size *= mref.sizes[r];
size_t alloc_size = (size + mref.offset) * elementSize;
char *data = new char[alloc_size];
new_allocated.push_back((void *)data);
char *data;
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
size * elementSize);
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
@@ -283,23 +295,20 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.wfn_name);
std::vector<void *> outputs;
if (!_dfr_is_root_node()) {
for (size_t p = 0; p < inputs.params.size(); ++p) {
if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) {
mlir::concretelang::RuntimeContext *ctx =
(*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
break;
}
}
if (inputs.source_locality != hpx::find_here() &&
(inputs.ksk_id || inputs.bsk_id)) {
_dfr_node_level_runtime_context_manager->getContext(
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
}
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
inputs.output_sizes.size());
hpx::cout << std::flush;
switch (inputs.output_sizes.size()) {
case 1: {
void *output = (void *)(new char[inputs.output_sizes[0]]);
void *output;
_dfr_checked_aligned_alloc(&output, 512, inputs.output_sizes[0]);
switch (inputs.params.size()) {
case 0:
wfn(output);
@@ -387,18 +396,52 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output};
new_allocated.push_back(output);
break;
}
case 2: {
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
void *output1, *output2;
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
switch (inputs.params.size()) {
case 0:
wfn(output1, output2);
@@ -491,20 +534,54 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output1, output2};
new_allocated.push_back(output1);
new_allocated.push_back(output2);
break;
}
case 3: {
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
// One aligned buffer per task output. Each call must target its own
// pointer: allocating into output2 twice would leak the first buffer and
// leave output3 uninitialised when it is handed to the work function.
void *output1 = nullptr, *output2 = nullptr, *output3 = nullptr;
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
_dfr_checked_aligned_alloc(&output3, 512, inputs.output_sizes[2]);
switch (inputs.params.size()) {
case 0:
wfn(output1, output2, output3);
@@ -597,15 +674,47 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], output1, output2, output3);
break;
case 17:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], output1, output2, output3);
break;
case 18:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17], output1,
output2, output3);
break;
case 19:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], output1, output2, output3);
break;
case 20:
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
inputs.params[3], inputs.params[4], inputs.params[5],
inputs.params[6], inputs.params[7], inputs.params[8],
inputs.params[9], inputs.params[10], inputs.params[11],
inputs.params[12], inputs.params[13], inputs.params[14],
inputs.params[15], inputs.params[16], inputs.params[17],
inputs.params[18], inputs.params[19], output1, output2, output3);
break;
default:
HPX_THROW_EXCEPTION(hpx::no_success,
"GenericComputeServer::execute_task",
"Error: number of task parameters not supported.");
}
outputs = {output1, output2, output3};
new_allocated.push_back(output1);
new_allocated.push_back(output2);
new_allocated.push_back(output3);
break;
}
default:
@@ -613,6 +722,18 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
"Error: number of task outputs not supported.");
}
// Deallocate input data buffers from OID deserialization (load).
// Those buffers are obtained through _dfr_checked_aligned_alloc
// (posix_memalign), so they are released with free(), not delete.
if (!_dfr_is_root_node()) {
  for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
    // Context parameters are skipped: load() points them at the
    // node-level context manager's slot instead of a per-task buffer.
    if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_CONTEXT)
      continue;
    // Memrefs carry a second-level data buffer that goes first.
    if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
      free(static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])->data);
    free(inputs.params[p]);
  }
}
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
std::move(inputs.output_types), inputs.alloc_p);
}

View File

@@ -15,6 +15,7 @@
#include <hpx/modules/serialization.hpp>
#include "concretelang/Runtime/DFRuntime.hpp"
#include "concretelang/Runtime/context.h"
extern "C" {
#include "concrete-ffi.h"
@@ -25,13 +26,12 @@ namespace concretelang {
namespace dfr {
template <typename T> struct KeyManager;
struct RuntimeContextManager;
namespace {
static void *dl_handle;
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
static std::list<void *> new_allocated;
static std::list<void *> fut_allocated;
static std::list<void *> m_allocated;
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
} // namespace
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
@@ -66,7 +66,6 @@ void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_bootstrap_key_u64(buffer);
@@ -87,7 +86,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
size_t length;
ar >> length;
uint8_t *pointer = new uint8_t[length];
new_allocated.push_back((void *)pointer);
ar >> hpx::serialization::make_array(pointer, length);
BufferView buffer = {(const uint8_t *)pointer, length};
key = deserialize_lwe_keyswitching_key_u64(buffer);
@@ -224,6 +222,68 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
}
/************************/
/* Context management. */
/************************/
struct RuntimeContextManager {
// TODO: this is only ok so long as we don't change keys. Once we
// use multiple keys, should have a map.
RuntimeContext *context;
std::mutex context_guard;
uint64_t ksk_id;
uint64_t bsk_id;
RuntimeContextManager() {
ksk_id = 0;
bsk_id = 0;
context = nullptr;
_dfr_node_level_runtime_context_manager = this;
}
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
hpx::naming::id_type source_locality) {
std::cout << "GetContext on node " << hpx::get_locality_id()
<< " with context " << context << " " << bsk_id << " " << ksk_id
<< "\n"
<< std::flush;
if (context != nullptr) {
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
<< bsk << "\n"
<< std::flush;
assert(ksk == ksk_id && bsk == bsk_id &&
"Context manager can only used with single keys for now.");
} else {
assert(ksk_id == 0 && bsk_id == 0 &&
"Context empty but context manager has key ids.");
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
std::lock_guard<std::mutex> guard(context_guard);
if (context == nullptr) {
auto ctx = new RuntimeContext();
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
ksk_id = ksk;
bsk_id = bsk;
context = ctx;
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
<< std::flush;
} else {
std::cout << " GOT context after LOCK on node "
<< hpx::get_locality_id() << " with context " << context
<< " " << bsk_id << " " << ksk_id << "\n"
<< std::flush;
}
}
return context;
}
RuntimeContext **getContextAddress() { return &context; }
};
} // namespace dfr
} // namespace concretelang
} // namespace mlir

View File

@@ -14,7 +14,7 @@ extern "C" {
typedef void (*wfnptr)(...);
void *_dfr_make_ready_future(void *);
void *_dfr_make_ready_future(void *, size_t);
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
void _dfr_register_work_function(wfnptr);
void *_dfr_await_future(void *);

View File

@@ -57,7 +57,7 @@ static bool isCandidateForTask(Operation *op) {
/// operations must not have side-effects and not be `isCandidateForTask`
static bool isSinkingBeneficiary(Operation *op) {
return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
mlir::arith::CmpIOp>(op);
mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op);
}
static bool
@@ -90,6 +90,92 @@ extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
return true;
}
/// Resolve the function targeted by `callOp`.
///
/// Returns a null FuncOp when the callee is not a symbol reference, or
/// when the nearest symbol with that name is not a func::FuncOp.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  if (auto calleeSym =
          callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
    Operation *target = SymbolTable::lookupNearestSymbolFrom(callOp, calleeSym);
    return dyn_cast_or_null<func::FuncOp>(target);
  }
  return nullptr;
}
/// Collect into `aliasedUses` every use of `val`, following the results
/// of memref cast, view and subview operations that consume it.
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
  for (OpOperand &operand : val.getUses()) {
    aliasedUses.insert(&operand);
    Operation *user = operand.getOwner();
    // Only cast/view/subview users are followed; their first result is
    // treated as another name for the same buffer.
    if (!isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(user))
      continue;
    getAliasedUses(user->getResult(0), aliasedUses);
  }
}
/// Decide whether the memref allocation `op` should be sunk into `taskOp`.
///
/// Returns true when `op` is already in `beneficiaryOps`, or when it is a
/// memref::AllocOp whose buffer (directly or through a cast/view/subview)
/// is written by an operation nested in `taskOp`. In the latter case `op`
/// is added to `beneficiaryOps` and its results to `availableValues`.
/// Returns false otherwise. `existingDependencies` is not consulted.
static bool extractOutputMemrefAllocations(
    Operation *op, SetVector<Value> existingDependencies,
    SetVector<Operation *> &beneficiaryOps,
    llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
  if (beneficiaryOps.count(op))
    return true;

  if (!isa<mlir::memref::AllocOp>(op))
    return false;

  Value val = op->getResults().front();
  DenseSet<OpOperand *> aliasedUses;
  getAliasedUses(val, aliasedUses);

  // Helper function checking if a memref use writes to memory
  auto hasMemoryWriteEffect = [&](OpOperand *use) {
    Operation *user = use->getOwner();

    // Call ops targeting concrete-ffi do not have memory effects
    // interface, so handle apart.
    // TODO: this could be handled better in BConcrete or higher.
    if (isa<mlir::func::CallOp>(user)) {
      // Resolve the callee once; it is null when the symbol does not
      // resolve to a func::FuncOp, in which case no name can match.
      if (func::FuncOp callee = getCalledFunction(user)) {
        // Functions handled as writing through their first operand.
        static const char *const writesFirstOperand[] = {
            "memref_expand_lut_in_trivial_glwe_ct_u64",
            "memref_add_lwe_ciphertexts_u64",
            "memref_add_plaintext_lwe_ciphertext_u64",
            "memref_mul_cleartext_lwe_ciphertext_u64",
            "memref_negate_lwe_ciphertext_u64",
            "memref_keyswitch_lwe_u64",
            "memref_bootstrap_lwe_u64"};
        auto calleeName = callee.getName();
        for (const char *name : writesFirstOperand)
          if (calleeName == name && user->getOperand(0) == use->get())
            return true;
        // memref_copy_one_rank is matched on its second operand.
        if (calleeName == "memref_copy_one_rank" &&
            user->getOperand(1) == use->get())
          return true;
      }
    }

    // Otherwise we rely on the memory effect interface
    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user);
    if (!effectInterface)
      return false;
    SmallVector<MemoryEffects::EffectInstance, 2> effects;
    effectInterface.getEffects(effects);
    for (const auto &eff : effects)
      if (isa<MemoryEffects::Write>(eff.getEffect()) &&
          eff.getValue() == use->get())
        return true;
    return false;
  };

  // We need to check if this allocated memref is written in this task.
  // TODO: for now we'll assume that we don't do partial writes or read/writes.
  for (OpOperand *use : aliasedUses) {
    if (!hasMemoryWriteEffect(use))
      continue;
    if (use->getOwner()->getParentOfType<RT::DataflowTaskOp>() != taskOp)
      continue;
    // We will sink the operation, mark its results as now available.
    beneficiaryOps.insert(op);
    for (Value result : op->getResults())
      availableValues.insert(result);
    return true;
  }
  return false;
}
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
Region &taskOpBody = taskOp.body();
@@ -104,6 +190,8 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
if (!operandOp)
continue;
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk,
availableValues, taskOp);
}
// Insert operations so that the defs get cloned before uses.

View File

@@ -25,6 +25,7 @@
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
#include <mlir/Conversion/LLVMCommon/Pattern.h>
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
#include <mlir/Dialect/Affine/Utils.h>
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
@@ -117,7 +118,8 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
if (isa<RT::DataflowTaskOp>(use.getOwner()) &&
if ((isa<RT::DataflowTaskOp>(use.getOwner()) ||
isa<RT::DeallocateFutureOp>(use.getOwner())) &&
region.isAncestor(use.getOwner()->getParentRegion()))
use.set(replacement);
}
@@ -183,10 +185,7 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
// bytes until we can get the actual size of the actual types.
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
type.isa<mlir::concretelang::Concrete::PlaintextListType>()) {
type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
Value result =
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
@@ -204,38 +203,69 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
}
/// Gather in `aliasedUses` all the uses of `val`, including the uses of
/// values derived from it by memref cast, view or subview operations.
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
  // Explicit worklist of values whose uses still have to be recorded.
  SmallVector<Value, 4> pending;
  pending.push_back(val);
  while (!pending.empty()) {
    Value current = pending.pop_back_val();
    for (OpOperand &operand : current.getUses()) {
      aliasedUses.insert(&operand);
      Operation *user = operand.getOwner();
      if (isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(user))
        pending.push_back(user->getResult(0));
    }
  }
}
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
func::FuncOp workFunction) {
DataLayout dataLayout = DataLayout::closest(DFTOp);
Region &opBody = DFTOp->getParentOfType<func::FuncOp>().getBody();
BlockAndValueMapping map;
OpBuilder builder(DFTOp);
// First identify DFT operands that are not futures and are not
// defined by another DFT. These need to be made into futures and
// propagated to all other DFTs. We can allow PRE to eliminate the
// previous definitions if there are no non-future type uses.
builder.setInsertionPoint(DFTOp);
for (Value val : DFTOp.getOperands()) {
if (!val.getType().isa<RT::FutureType>()) {
Value newval;
OpBuilder::InsertionGuard guard(builder);
Type futType = RT::FutureType::get(val.getType());
Value memrefCloned, newval = val;
// Find out if this value is needed in any other task
SmallVector<Operation *, 2> taskOps;
for (auto &use : val.getUses())
if (isa<RT::DataflowTaskOp>(use.getOwner()))
taskOps.push_back(use.getOwner());
Operation *first = DFTOp;
for (auto op : taskOps)
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
first = op;
builder.setInsertionPoint(first);
// If we are building a future on a MemRef, then we need to clone
// the memref in order to allow the deallocation pass which does
// not synchronize with task execution.
if (val.getType().isa<mlir::MemRefType>() && !val.isa<BlockArgument>()) {
newval = builder
.create<memref::AllocOp>(
DFTOp.getLoc(), val.getType().cast<mlir::MemRefType>())
if (val.getType().isa<mlir::MemRefType>()) {
// Get the type of memref that we will clone. In case this is
// a subview, we discard the mapping so we copy to a contiguous
// layout which pre-serializes this.
MemRefType mrType = val.getType().dyn_cast<mlir::MemRefType>();
if (!mrType.getLayout().isIdentity()) {
unsigned rank = mrType.getRank();
mrType = MemRefType::Builder(mrType)
.setShape(mrType.getShape())
.setLayout(AffineMapAttr::get(
builder.getMultiDimIdentityMap(rank)));
}
newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
.getResult();
builder.create<memref::CopyOp>(DFTOp.getLoc(), val, newval);
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
memrefCloned = builder.create<arith::ConstantOp>(
val.getLoc(), builder.getI64IntegerAttr(1));
} else {
newval = val;
memrefCloned = builder.create<arith::ConstantOp>(
val.getLoc(), builder.getI64IntegerAttr(0));
}
Type futType = RT::FutureType::get(newval.getType());
auto mrf = builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType,
newval);
map.map(mrf->getResult(0), val);
replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
newval, memrefCloned);
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
}
}
@@ -268,6 +298,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
// unsupported even in the LLVMIR Dialect - this needs to use two
// placeholders for each output, before and after the
// CreateAsyncTaskOp.
BlockAndValueMapping map;
for (auto result : DFTOp.getResults()) {
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
@@ -297,6 +328,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
if (!isa<RT::DataflowTaskOp>(use.getOwner()) &&
!isa<RT::DeallocateFutureOp>(use.getOwner()) &&
use.getOwner()->getParentOfType<RT::DataflowTaskOp>() == nullptr) {
// Wait for this future before its uses
OpBuilder::InsertionGuard guard(builder);
@@ -315,24 +347,35 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
DFTOp.erase();
}
static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) {
OpBuilder builder(parentFunc.body());
builder.setInsertionPointToStart(&parentFunc.body().front());
static void registerWorkFunction(mlir::func::FuncOp parentFunc,
mlir::func::FuncOp workFunction) {
OpBuilder builder(parentFunc.getBody());
builder.setInsertionPointToStart(&parentFunc.getBody().front());
auto fnptr = builder.create<mlir::ConstantOp>(
parentFunc.getLoc(), workFunction.getType(),
auto fnptr = builder.create<mlir::func::ConstantOp>(
parentFunc.getLoc(), workFunction.getFunctionType(),
SymbolRefAttr::get(builder.getContext(), workFunction.getName()));
builder.create<RT::RegisterTaskWorkFunctionOp>(parentFunc.getLoc(),
fnptr.getResult());
}
/// Resolves the function targeted by `callOp`.
///
/// Returns a null FuncOp when the callee is not expressed as a symbol
/// reference, or when the nearest symbol lookup starting from the call
/// does not yield a `func::FuncOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  auto callee = callOp.getCallableForCallee();
  if (SymbolRefAttr symbol = callee.dyn_cast<SymbolRefAttr>()) {
    Operation *target = SymbolTable::lookupNearestSymbolFrom(callOp, symbol);
    return dyn_cast_or_null<func::FuncOp>(target);
  }
  return nullptr;
}
/// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
void runOnOperation() override {
auto module = getOperation();
SmallVector<func::FuncOp, 4> workFunctions;
SmallVector<func::FuncOp, 1> entryPoints;
module.walk([&](mlir::func::FuncOp func) {
static int wfn_id = 0;
@@ -342,7 +385,7 @@ struct LowerDataflowTasksPass
return;
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
std::vector<std::pair<RT::DataflowTaskOp, func::FuncOp>> outliningMap;
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
func.walk([&](RT::DataflowTaskOp op) {
auto workFunctionName =
@@ -353,6 +396,7 @@ struct LowerDataflowTasksPass
outlineWorkFunction(op, workFunctionName.str());
outliningMap.push_back(
std::pair<RT::DataflowTaskOp, func::FuncOp>(op, outlinedFunc));
workFunctions.push_back(outlinedFunc);
symbolTable.insert(outlinedFunc);
return WalkResult::advance();
});
@@ -361,73 +405,72 @@ struct LowerDataflowTasksPass
for (auto mapping : outliningMap)
lowerDataflowTaskOp(mapping.first, mapping.second);
// Main is always an entry-point - otherwise check if this
// function is called within the module. TODO: we assume no
// recursion.
if (func.getName() == "main")
entryPoints.push_back(func);
else {
bool found = false;
module.walk([&](mlir::func::CallOp op) {
if (getCalledFunction(op) == func)
found = true;
});
if (!found)
entryPoints.push_back(func);
}
});
for (auto entryPoint : entryPoints) {
// If this is a JIT invocation and we're not on the root node,
// we do not need to do any computation, only register all work
// functions with the runtime system
if (!outliningMap.empty()) {
if (!workFunctions.empty()) {
if (!dfr::_dfr_is_root_node()) {
// auto regFunc = builder.create<FuncOp>(func.getLoc(),
// func.getName(), func.getType());
func.eraseBody();
entryPoint.eraseBody();
Block *b = new Block;
b->addArguments(func.getType().getInputs());
func.body().push_front(b);
for (int i = func.getType().getNumInputs() - 1; i >= 0; --i)
func.eraseArgument(i);
for (int i = func.getType().getNumResults() - 1; i >= 0; --i)
func.eraseResult(i);
OpBuilder builder(func.body());
builder.setInsertionPointToEnd(&func.body().front());
builder.create<ReturnOp>(func.getLoc());
SmallVector<Location, 4> locations;
for (auto input : entryPoint.getFunctionType().getInputs())
locations.push_back(entryPoint.getLoc());
b->addArguments(entryPoint.getFunctionType().getInputs(), locations);
entryPoint.getBody().push_front(b);
for (int i = entryPoint.getFunctionType().getNumInputs() - 1; i >= 0;
--i)
entryPoint.eraseArgument(i);
for (int i = entryPoint.getFunctionType().getNumResults() - 1; i >= 0;
--i)
entryPoint.eraseResult(i);
OpBuilder builder(entryPoint.getBody());
builder.setInsertionPointToEnd(&entryPoint.getBody().front());
builder.create<mlir::func::ReturnOp>(entryPoint.getLoc());
}
}
// Generate code to register all work-functions with the
// runtime.
for (auto mapping : outliningMap)
registerWorkFunction(func, mapping.second);
for (auto wf : workFunctions)
registerWorkFunction(entryPoint, wf);
// Issue _dfr_start/stop calls for this function
if (!outliningMap.empty()) {
OpBuilder builder(func.getBody());
builder.setInsertionPointToStart(&func.getBody().front());
if (!workFunctions.empty()) {
OpBuilder builder(entryPoint.getBody());
builder.setInsertionPointToStart(&entryPoint.getBody().front());
auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn(
func->getParentOfType<ModuleOp>(), "_dfr_start", {},
LLVM::LLVMVoidType::get(func->getContext()));
builder.create<LLVM::CallOp>(func.getLoc(), dfrStartFunOp,
module, "_dfr_start", {},
LLVM::LLVMVoidType::get(entryPoint->getContext()));
builder.create<LLVM::CallOp>(entryPoint.getLoc(), dfrStartFunOp,
mlir::ValueRange(),
ArrayRef<NamedAttribute>());
builder.setInsertionPoint(func.getBody().back().getTerminator());
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn(
func->getParentOfType<ModuleOp>(), "_dfr_stop", {},
LLVM::LLVMVoidType::get(func->getContext()));
builder.create<LLVM::CallOp>(func.getLoc(), dfrStopFunOp,
module, "_dfr_stop", {},
LLVM::LLVMVoidType::get(entryPoint->getContext()));
builder.create<LLVM::CallOp>(entryPoint.getLoc(), dfrStopFunOp,
mlir::ValueRange(),
ArrayRef<NamedAttribute>());
}
});
// Delay memref deallocations when memrefs are made into futures
module.walk([&](Operation *op) {
if (isa<RT::MakeReadyFutureOp>(*op) &&
op->getOperand(0).getType().isa<mlir::MemRefType>()) {
for (auto &use :
llvm::make_early_inc_range(op->getOperand(0).getUses())) {
if (isa<mlir::memref::DeallocOp>(use.getOwner())) {
OpBuilder builder(use.getOwner()
->getParentOfType<mlir::func::FuncOp>()
.getBody()
.back()
.getTerminator());
builder.clone(*use.getOwner());
use.getOwner()->erase();
}
}
}
return WalkResult::advance();
});
}
}
LowerDataflowTasksPass(bool debug) : debug(debug){};
@@ -440,5 +483,44 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
return std::make_unique<LowerDataflowTasksPass>(debug);
}
namespace {
/// For documentation see Autopar.td
///
/// Erases every `memref.dealloc` whose operand is also the first operand
/// of an `RT::WorkFunctionReturnOp` or of an `RT::MakeReadyFutureOp`.
struct FixupBufferDeallocationPass
    : public FixupBufferDeallocationBase<FixupBufferDeallocationPass> {

  void runOnOperation() override {
    auto module = getOperation();
    // Deallocations scheduled for removal. They are erased only after
    // both walks have finished so that no operation is destroyed while
    // the module is still being traversed.
    std::vector<Operation *> ops;

    // Records each dealloc user of `buffer` exactly once. The same
    // buffer may be reached from more than one return / make-ready
    // operation; without the containment check its dealloc would be
    // queued twice and then erased twice.
    auto collectDeallocs = [&](Value buffer) {
      for (auto &use : buffer.getUses()) {
        Operation *user = use.getOwner();
        if (isa<mlir::memref::DeallocOp>(user) &&
            !llvm::is_contained(ops, user))
          ops.push_back(user);
      }
    };

    // All buffers allocated and either made into a future, directly
    // or as a result of being returned by a task, are managed by the
    // DFR runtime system's reference counting.
    module.walk([&](RT::WorkFunctionReturnOp retOp) {
      collectDeallocs(retOp.getOperands().front());
    });
    module.walk([&](RT::MakeReadyFutureOp mrfOp) {
      collectDeallocs(mrfOp.getOperands().front());
    });

    for (auto op : ops)
      op->erase();
  }
  FixupBufferDeallocationPass(bool debug) : debug(debug){};

protected:
  // Stored from the constructor; not read anywhere in this pass body.
  bool debug;
};
} // end anonymous namespace
/// Factory for the FixupBufferDeallocation pass; `debug` is forwarded
/// unchanged to the pass constructor.
std::unique_ptr<mlir::Pass> createFixupBufferDeallocationPass(bool debug) {
  std::unique_ptr<mlir::Pass> pass =
      std::make_unique<FixupBufferDeallocationPass>(debug);
  return pass;
}
} // end namespace concretelang
} // end namespace mlir

View File

@@ -125,8 +125,9 @@ struct MakeReadyFutureOpInterfaceLowering
results[0]);
rewriter.create<LLVM::StoreOp>(mrfOp.getLoc(),
adaptor.getOperands().front(), allocatedPtr);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, allocatedPtr);
SmallVector<Value, 4> mrfOperands = {adaptor.getOperands()};
mrfOperands[0] = allocatedPtr;
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, mrfOperands);
return mlir::success();
}
};
@@ -178,16 +179,14 @@ struct RegisterTaskWorkFunctionOpInterfaceLowering
mlir::LogicalResult
matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp,
ArrayRef<Value> operands,
RT::RegisterTaskWorkFunctionOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
RT::RegisterTaskWorkFunctionOp::Adaptor transformed(operands);
auto rtwfFuncType =
LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
auto rtwfFuncOp = getOrInsertFuncOpDecl(
rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter);
rewriter.replaceOpWithNewOp<LLVM::CallOp>(rtwfOp, rtwfFuncOp,
transformed.getOperands());
adaptor.getOperands());
return success();
}
};

View File

@@ -9,6 +9,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
#include "concretelang/Dialect/RT/IR/RTOps.h"
#include "concretelang/Dialect/RT/IR/RTTypes.h"
@@ -33,3 +34,66 @@ void DataflowTaskOp::build(
/// Reports the successor regions of a dataflow task. The body is empty:
/// whatever `index` and `operands` are, nothing is appended to `regions`.
void DataflowTaskOp::getSuccessorRegions(
Optional<unsigned> index, ArrayRef<Attribute> operands,
SmallVectorImpl<RegionSuccessor> &regions) {}
/// Builds a `DeallocateFutureOp` on `alloc`, located at `alloc`'s own
/// location, and returns the newly created operation.
llvm::Optional<mlir::Operation *>
DataflowTaskOp::buildDealloc(OpBuilder &builder, Value alloc) {
  auto deallocOp = builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc);
  return deallocOp.getOperation();
}
/// Builds a `CloneFutureOp` on `alloc`, located at `alloc`'s own
/// location, and returns the result value of that operation.
llvm::Optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
                                                       Value alloc) {
  auto cloneOp = builder.create<CloneFutureOp>(alloc.getLoc(), alloc);
  return cloneOp.getResult();
}
/// Declares the memory effects of a dataflow task on the default
/// resource: a read of every input, followed by a write of every output,
/// followed by an allocation of every output (in that order).
void DataflowTaskOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  for (auto taskInput : inputs()) {
    effects.emplace_back(MemoryEffects::Read::get(), taskInput,
                         SideEffects::DefaultResource::get());
  }
  // Two separate passes over the outputs keep all Write effects ahead of
  // all Allocate effects in the resulting list.
  for (auto taskOutput : outputs()) {
    effects.emplace_back(MemoryEffects::Write::get(), taskOutput,
                         SideEffects::DefaultResource::get());
  }
  for (auto taskOutput : outputs()) {
    effects.emplace_back(MemoryEffects::Allocate::get(), taskOutput,
                         SideEffects::DefaultResource::get());
  }
}
/// Builds a `DeallocateFutureOp` on `alloc` (at `alloc`'s location) and
/// returns the created operation.
llvm::Optional<mlir::Operation *>
CloneFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
  Location loc = alloc.getLoc();
  auto release = builder.create<DeallocateFutureOp>(loc, alloc);
  return release.getOperation();
}
/// Builds a further `CloneFutureOp` on `alloc` (at `alloc`'s location)
/// and returns its result value.
llvm::Optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
                                                      Value alloc) {
  Location loc = alloc.getLoc();
  auto duplicate = builder.create<CloneFutureOp>(loc, alloc);
  return duplicate.getResult();
}
/// Declares the memory effects of cloning a future on the default
/// resource: the input is read; the output is written and allocated.
void CloneFutureOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Source future.
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());

  // Cloned future: Write is recorded before Allocate.
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}
/// Builds a `DeallocateFutureOp` on `alloc`, using `alloc`'s location,
/// and hands back the created operation.
llvm::Optional<mlir::Operation *>
MakeReadyFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
  auto freeOp = builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc);
  mlir::Operation *created = freeOp.getOperation();
  return created;
}
/// Builds a `CloneFutureOp` on `alloc`, using `alloc`'s location, and
/// hands back the value it produces.
llvm::Optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
                                                          Value alloc) {
  auto copyOp = builder.create<CloneFutureOp>(alloc.getLoc(), alloc);
  mlir::Value cloned = copyOp.getResult();
  return cloned;
}
/// Declares the memory effects of making a ready future on the default
/// resource: a read of the input, then a write and an allocation of the
/// output.
void MakeReadyFutureOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  // Value being wrapped.
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());

  // Resulting future: Write first, Allocate second.
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

View File

@@ -38,33 +38,50 @@ static size_t num_nodes = 0;
using namespace hpx;
void *_dfr_make_ready_future(void *in) {
void *future = static_cast<void *>(
new hpx::shared_future<void *>(hpx::make_ready_future(in)));
mlir::concretelang::dfr::m_allocated.push_back(in);
mlir::concretelang::dfr::fut_allocated.push_back(future);
return future;
/// A future paired with a reference count.
///
/// - `future`: pointer to the wrapped HPX shared future.
/// - `count`: atomic reference counter, initialised by the constructor.
/// - `cloned_memref_p`: flag stored at construction; the release path
///   checks it to decide whether a memref data buffer must be freed too.
struct dfr_refcounted_future {
  hpx::shared_future<void *> *future;
  std::atomic<std::size_t> count;
  bool cloned_memref_p;

  dfr_refcounted_future(hpx::shared_future<void *> *f, size_t c, bool clone_p)
      : future(f), count(c), cloned_memref_p(clone_p) {}
};
using dfr_refcounted_future_t = dfr_refcounted_future;
using dfr_refcounted_future_p = dfr_refcounted_future *;
// Wraps `in` in an already-ready shared future and returns it inside a
// freshly allocated reference-counted record whose count starts at 1.
// `memref_clone_p` is stored as the record's cloned-memref flag.
//
// Ready futures are only used as inputs to tasks (never passed to
// await_future), so we only need to track the references in task
// creation.
void *_dfr_make_ready_future(void *in, size_t memref_clone_p) {
  auto *ready = new hpx::shared_future<void *>(hpx::make_ready_future(in));
  auto *refcounted = new dfr_refcounted_future_t(ready, 1, memref_clone_p);
  return static_cast<void *>(refcounted);
}
void *_dfr_await_future(void *in) {
return static_cast<hpx::shared_future<void *> *>(in)->get();
}
void _dfr_deallocate_future_data(void *in) {
delete[] static_cast<char *>(
static_cast<hpx::shared_future<void *> *>(in)->get());
return static_cast<dfr_refcounted_future_p>(in)->future->get();
}
void _dfr_deallocate_future(void *in) {
delete (static_cast<hpx::shared_future<void *> *>(in));
auto drf = static_cast<dfr_refcounted_future_p>(in);
size_t prev_count = drf->count.fetch_sub(1);
if (prev_count == 1) {
// If this was a memref for which a clone was needed, deallocate first.
if (drf->cloned_memref_p)
free(
(void *)(static_cast<StridedMemRefType<char, 1> *>(drf->future->get())
->data));
free(drf->future->get());
delete (drf->future);
delete drf;
}
}
// No-op: `in` is ignored and nothing is released here.
// NOTE(review): the future's payload appears to be freed by
// _dfr_deallocate_future once the last reference is dropped — confirm
// that no caller still relies on this entry point to free data.
void _dfr_deallocate_future_data(void *in) {}
// Determine where new task should run. For now just round-robin
// distribution - TODO: optimise.
static inline size_t _dfr_find_next_execution_locality() {
static std::atomic<std::size_t> next_locality{0};
static std::atomic<std::size_t> next_locality{1};
size_t next_loc = ++next_locality;
size_t next_loc = next_locality.fetch_add(1);
return next_loc % mlir::concretelang::dfr::num_nodes;
}
@@ -76,7 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() {
/// the returns.
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
...) {
std::vector<void *> params;
// std::vector<void *> params;
std::vector<void *> refcounted_futures;
std::vector<size_t> param_sizes;
std::vector<uint64_t> param_types;
std::vector<void *> outputs;
@@ -86,7 +104,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
va_list args;
va_start(args, num_outputs);
for (size_t i = 0; i < num_params; ++i) {
params.push_back(va_arg(args, void *));
refcounted_futures.push_back(va_arg(args, void *));
param_sizes.push_back(va_arg(args, uint64_t));
param_types.push_back(va_arg(args, uint64_t));
}
@@ -97,13 +115,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
}
va_end(args);
for (size_t i = 0; i < num_params; ++i) {
if (mlir::concretelang::dfr::_dfr_get_arg_type(
param_types[i] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF)) {
mlir::concretelang::dfr::m_allocated.push_back(
(void *)static_cast<StridedMemRefType<char, 1> *>(params[i])->data);
}
}
// Take a reference on each future argument
for (auto rcf : refcounted_futures)
((dfr_refcounted_future_p)rcf)->count.fetch_add(1);
// We pass functions by name - which is not strictly necessary in
// shared memory as pointers suffice, but is needed in the
@@ -147,7 +161,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future));
break;
case 2:
@@ -164,8 +178,8 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future));
break;
case 3:
@@ -184,9 +198,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future));
break;
case 4:
@@ -206,10 +220,10 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future));
break;
case 5:
@@ -231,11 +245,11 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future));
break;
case 6:
@@ -258,12 +272,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future));
break;
case 7:
@@ -287,13 +301,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future));
break;
case 8:
@@ -318,14 +332,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future));
break;
case 9:
@@ -352,15 +366,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future));
break;
case 10:
@@ -388,16 +402,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future));
break;
case 11:
@@ -426,17 +440,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future));
break;
case 12:
@@ -466,18 +480,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10],
*(hpx::shared_future<void *> *)params[11]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future));
break;
case 13:
@@ -509,19 +523,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10],
*(hpx::shared_future<void *> *)params[11],
*(hpx::shared_future<void *> *)params[12]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future));
break;
case 14:
@@ -554,20 +568,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10],
*(hpx::shared_future<void *> *)params[11],
*(hpx::shared_future<void *> *)params[12],
*(hpx::shared_future<void *> *)params[13]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future));
break;
case 15:
@@ -601,21 +615,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10],
*(hpx::shared_future<void *> *)params[11],
*(hpx::shared_future<void *> *)params[12],
*(hpx::shared_future<void *> *)params[13],
*(hpx::shared_future<void *> *)params[14]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future));
break;
case 16:
@@ -650,22 +664,246 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*(hpx::shared_future<void *> *)params[0],
*(hpx::shared_future<void *> *)params[1],
*(hpx::shared_future<void *> *)params[2],
*(hpx::shared_future<void *> *)params[3],
*(hpx::shared_future<void *> *)params[4],
*(hpx::shared_future<void *> *)params[5],
*(hpx::shared_future<void *> *)params[6],
*(hpx::shared_future<void *> *)params[7],
*(hpx::shared_future<void *> *)params[8],
*(hpx::shared_future<void *> *)params[9],
*(hpx::shared_future<void *> *)params[10],
*(hpx::shared_future<void *> *)params[11],
*(hpx::shared_future<void *> *)params[12],
*(hpx::shared_future<void *> *)params[13],
*(hpx::shared_future<void *> *)params[14],
*(hpx::shared_future<void *> *)params[15]));
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
*((dfr_refcounted_future_p)refcounted_futures[15])->future));
break;
case 17:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get(), param15.get(),
param16.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
*((dfr_refcounted_future_p)refcounted_futures[16])->future));
break;
case 18:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get(), param15.get(),
param16.get(), param17.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
*((dfr_refcounted_future_p)refcounted_futures[17])->future));
break;
case 19:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get(), param15.get(),
param16.get(), param17.get(), param18.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
*((dfr_refcounted_future_p)refcounted_futures[17])->future,
*((dfr_refcounted_future_p)refcounted_futures[18])->future));
break;
case 20:
oodf = std::move(hpx::dataflow(
[wfnname, param_sizes, param_types, output_sizes,
output_types](hpx::shared_future<void *> param0,
hpx::shared_future<void *> param1,
hpx::shared_future<void *> param2,
hpx::shared_future<void *> param3,
hpx::shared_future<void *> param4,
hpx::shared_future<void *> param5,
hpx::shared_future<void *> param6,
hpx::shared_future<void *> param7,
hpx::shared_future<void *> param8,
hpx::shared_future<void *> param9,
hpx::shared_future<void *> param10,
hpx::shared_future<void *> param11,
hpx::shared_future<void *> param12,
hpx::shared_future<void *> param13,
hpx::shared_future<void *> param14,
hpx::shared_future<void *> param15,
hpx::shared_future<void *> param16,
hpx::shared_future<void *> param17,
hpx::shared_future<void *> param18,
hpx::shared_future<void *> param19)
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
std::vector<void *> params = {
param0.get(), param1.get(), param2.get(), param3.get(),
param4.get(), param5.get(), param6.get(), param7.get(),
param8.get(), param9.get(), param10.get(), param11.get(),
param12.get(), param13.get(), param14.get(), param15.get(),
param16.get(), param17.get(), param18.get(), param19.get()};
mlir::concretelang::dfr::OpaqueInputData oid(
wfnname, params, param_sizes, param_types, output_sizes,
output_types);
return mlir::concretelang::dfr::gcc
[_dfr_find_next_execution_locality()]
.execute_task(oid);
},
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
*((dfr_refcounted_future_p)refcounted_futures[17])->future,
*((dfr_refcounted_future_p)refcounted_futures[18])->future,
*((dfr_refcounted_future_p)refcounted_futures[19])->future));
break;
default:
@@ -675,53 +913,67 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
switch (num_outputs) {
case 1:
*((void **)outputs[0]) = new hpx::shared_future<void *>(hpx::dataflow(
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> void * { return oodf_in.get().outputs[0]; },
oodf));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(hpx::dataflow(
[refcounted_futures](
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> void * {
void *ret = oodf_in.get().outputs[0];
for (auto rcf : refcounted_futures)
_dfr_deallocate_future(rcf);
return ret;
},
oodf)),
1, output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
break;
case 2: {
hpx::future<hpx::tuple<void *, void *>> &&ft = hpx::dataflow(
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
[refcounted_futures](
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> hpx::tuple<void *, void *> {
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
for (auto rcf : refcounted_futures)
_dfr_deallocate_future(rcf);
return hpx::make_tuple<>(outputs[0], outputs[1]);
},
oodf);
hpx::tuple<hpx::future<void *>, hpx::future<void *>> &&tf =
hpx::split_future(std::move(ft));
*((void **)outputs[0]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
*((void **)outputs[1]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(std::move(hpx::get<0>(tf))), 1,
output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
*((void **)outputs[1]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(std::move(hpx::get<1>(tf))), 1,
output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
break;
}
case 3: {
hpx::future<hpx::tuple<void *, void *, void *>> &&ft = hpx::dataflow(
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
[refcounted_futures](
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
-> hpx::tuple<void *, void *, void *> {
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
for (auto rcf : refcounted_futures)
_dfr_deallocate_future(rcf);
return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]);
},
oodf);
hpx::tuple<hpx::future<void *>, hpx::future<void *>, hpx::future<void *>>
&&tf = hpx::split_future(std::move(ft));
*((void **)outputs[0]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
*((void **)outputs[1]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
*((void **)outputs[2]) =
(void *)new hpx::shared_future<void *>(std::move(hpx::get<2>(tf)));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[2]));
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(std::move(hpx::get<0>(tf))), 1,
output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
*((void **)outputs[1]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(std::move(hpx::get<1>(tf))), 1,
output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
*((void **)outputs[2]) = (void *)new dfr_refcounted_future_t(
new hpx::shared_future<void *>(std::move(hpx::get<2>(tf))), 1,
output_types[2] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
break;
}
default:
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task",
"Error: number of task outputs not supported.");
@@ -896,6 +1148,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
new mlir::concretelang::dfr::KeyManager<LweBootstrapKey_u64>();
new mlir::concretelang::dfr::KeyManager<LweKeyswitchKey_u64>();
new mlir::concretelang::dfr::RuntimeContextManager();
new mlir::concretelang::dfr::WorkFunctionRegistry();
mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier =
new hpx::lcos::barrier("wait_register_remote_work_functions",
@@ -985,21 +1238,8 @@ void _dfr_stop() {
// safer to drop them in-between phases.
mlir::concretelang::dfr::_dfr_node_level_bsk_manager->clear_keys();
mlir::concretelang::dfr::_dfr_node_level_ksk_manager->clear_keys();
while (!mlir::concretelang::dfr::new_allocated.empty()) {
delete[] static_cast<char *>(
mlir::concretelang::dfr::new_allocated.front());
mlir::concretelang::dfr::new_allocated.pop_front();
}
while (!mlir::concretelang::dfr::fut_allocated.empty()) {
delete static_cast<hpx::shared_future<void *> *>(
mlir::concretelang::dfr::fut_allocated.front());
mlir::concretelang::dfr::fut_allocated.pop_front();
}
while (!mlir::concretelang::dfr::m_allocated.empty()) {
free(mlir::concretelang::dfr::m_allocated.front());
mlir::concretelang::dfr::m_allocated.pop_front();
}
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
->clearContext();
}
void _dfr_try_initialize() {

View File

@@ -278,6 +278,11 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
// Use the buffer deallocation interface to insert future deallocation calls
addPotentiallyNestedPass(
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
addPotentiallyNestedPass(
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
// Convert to MLIR LLVM Dialect
addPotentiallyNestedPass(