mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
feat(dfr): add memory management for futures and associated data.
This commit is contained in:
@@ -19,7 +19,6 @@ extern "C" {
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/ClientLib/KeySetCache.h"
|
||||
#include "concretelang/Common/Error.h"
|
||||
#include <concretelang/Runtime/DFRuntime.hpp>
|
||||
|
||||
namespace concretelang {
|
||||
namespace clientlib {
|
||||
@@ -86,9 +85,15 @@ public:
|
||||
}
|
||||
|
||||
/// Build the evaluation keys (keyswitch + bootstrap) held by this object.
///
/// The "ksk_v0" and "bsk_v0" entries are looked up with find() rather than
/// at(), so a missing entry does not throw: when either key is absent a
/// default-constructed EvaluationKeys is returned instead.
EvaluationKeys evaluationKeys() {
  auto kskIt = this->keyswitchKeys.find("ksk_v0");
  auto bskIt = this->bootstrapKeys.find("bsk_v0");
  if (kskIt != this->keyswitchKeys.end() &&
      bskIt != this->bootstrapKeys.end()) {
    // Element 1 of each stored tuple is the shared key handle.
    auto sharedKsk = std::get<1>(kskIt->second);
    auto sharedBsk = std::get<1>(bskIt->second);
    return EvaluationKeys(sharedKsk, sharedBsk);
  }
  return EvaluationKeys();
}
|
||||
|
||||
const std::map<LweSecretKeyID,
|
||||
|
||||
@@ -23,6 +23,8 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createBufferizeDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false);
|
||||
std::unique_ptr<mlir::Pass>
|
||||
createFixupBufferDeallocationPass(bool debug = false);
|
||||
void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter,
|
||||
mlir::RewritePatternSet &patterns);
|
||||
void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter,
|
||||
|
||||
@@ -82,6 +82,15 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> {
|
||||
}];
|
||||
}
|
||||
|
||||
def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> {
|
||||
let summary =
|
||||
"Prevent deallocation of buffers returned as futures by tasks.";
|
||||
|
||||
let description = [{ This pass removes buffer deallocation calls on
|
||||
buffers being used for dataflow communication between
|
||||
tasks. These buffers cannot be deallocated directly without
|
||||
synchronization as they can be needed by asynchronous
|
||||
computation. Instead, these will be deallocated by the runtime
|
||||
when no longer needed.}]; }
|
||||
|
||||
#endif
|
||||
|
||||
@@ -6,10 +6,12 @@
|
||||
#ifndef CONCRETELANG_DIALECT_RT_IR_RTOPS_H
|
||||
#define CONCRETELANG_DIALECT_RT_IR_RTOPS_H
|
||||
|
||||
#include <mlir/Dialect/Bufferization/IR/AllocationOpInterface.h>
|
||||
#include <mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h>
|
||||
#include <mlir/IR/BuiltinOps.h>
|
||||
#include <mlir/IR/BuiltinTypes.h>
|
||||
#include <mlir/Interfaces/ControlFlowInterfaces.h>
|
||||
#include <mlir/Interfaces/DataLayoutInterfaces.h>
|
||||
#include <mlir/Interfaces/SideEffectInterfaces.h>
|
||||
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
#ifndef CONCRETELANG_DIALECT_RT_IR_RT_OPS
|
||||
#define CONCRETELANG_DIALECT_RT_IR_RT_OPS
|
||||
|
||||
include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td"
|
||||
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
include "mlir/Interfaces/DataLayoutInterfaces.td"
|
||||
|
||||
@@ -15,9 +16,13 @@ class RT_Op<string mnemonic, list<Trait> traits = []> :
|
||||
|
||||
def RT_DataflowTaskOp : RT_Op<"dataflow_task", [
|
||||
DeclareOpInterfaceMethods<RegionBranchOpInterface>,
|
||||
SingleBlockImplicitTerminator<"DataflowYieldOp">]> {
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<AllocationOpInterface,
|
||||
["buildDealloc", "buildClone"]>,
|
||||
SingleBlockImplicitTerminator<"DataflowYieldOp">,
|
||||
AutomaticAllocationScope] > {
|
||||
let arguments = (ins Variadic<AnyType>: $inputs);
|
||||
let results = (outs Variadic<AnyType>:$outputs);
|
||||
let results = (outs Variadic<AnyType>: $outputs);
|
||||
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
@@ -85,8 +90,12 @@ Example:
|
||||
}];
|
||||
}
|
||||
|
||||
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future"> {
|
||||
let arguments = (ins AnyType: $input);
|
||||
def RT_MakeReadyFutureOp : RT_Op<"make_ready_future", [
|
||||
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<AllocationOpInterface,
|
||||
["buildDealloc", "buildClone"]>]> {
|
||||
let arguments = (ins AnyType: $input,
|
||||
AnyType: $memrefCloned);
|
||||
let results = (outs RT_Future: $output);
|
||||
let summary = "Build a ready future.";
|
||||
let description = [{
|
||||
@@ -115,14 +124,27 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> {
|
||||
let summary = "Create a dataflow task.";
|
||||
}
|
||||
|
||||
def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
|
||||
def RT_RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> {
|
||||
let arguments = (ins Variadic<AnyType>:$list);
|
||||
let results = (outs );
|
||||
let summary = "Register the task work-function with the runtime system.";
|
||||
}
|
||||
|
||||
def DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
def RT_CloneFutureOp : RT_Op<"clone_future",
|
||||
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
||||
DeclareOpInterfaceMethods<AllocationOpInterface,
|
||||
["buildDealloc", "buildClone"]>] > {
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value": $input), [{
|
||||
return build($_builder, $_state, input.getType(), input);
|
||||
}]>];
|
||||
|
||||
let arguments = (ins RT_Future: $input);
|
||||
let results = (outs RT_Future: $output);
|
||||
}
|
||||
|
||||
def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> {
|
||||
let arguments = (ins AnyType: $input);
|
||||
let results = (outs );
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <cstdarg>
|
||||
#include <cstdlib>
|
||||
#include <malloc.h>
|
||||
#include <string>
|
||||
|
||||
#include <hpx/async_colocated/get_colocation_id.hpp>
|
||||
@@ -28,6 +29,7 @@
|
||||
|
||||
#include <mlir/ExecutionEngine/CRunnerUtils.h>
|
||||
|
||||
#include "concretelang/ClientLib/EvaluationKeys.h"
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
#include "concretelang/Runtime/dfr_debug_interface.h"
|
||||
@@ -49,6 +51,17 @@ static inline size_t _dfr_get_memref_rank(size_t size) {
|
||||
(2 * sizeof(int64_t) /*size&stride/rank*/);
|
||||
}
|
||||
|
||||
static inline void _dfr_checked_aligned_alloc(void **out, size_t align,
|
||||
size_t size) {
|
||||
int res = posix_memalign(out, align, size);
|
||||
if (res == ENOMEM)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
|
||||
"Error: insufficient memory available.");
|
||||
if (res == EINVAL)
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed",
|
||||
"Error: invalid memory alignment.");
|
||||
}
|
||||
|
||||
struct OpaqueInputData {
|
||||
OpaqueInputData() = default;
|
||||
|
||||
@@ -62,7 +75,7 @@ struct OpaqueInputData {
|
||||
param_types(std::move(_param_types)),
|
||||
output_sizes(std::move(_output_sizes)),
|
||||
output_types(std::move(_output_types)), alloc_p(_alloc_p),
|
||||
source_locality(hpx::find_here()) {}
|
||||
source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {}
|
||||
|
||||
OpaqueInputData(const OpaqueInputData &oid)
|
||||
: wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)),
|
||||
@@ -70,16 +83,18 @@ struct OpaqueInputData {
|
||||
param_types(std::move(oid.param_types)),
|
||||
output_sizes(std::move(oid.output_sizes)),
|
||||
output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p),
|
||||
source_locality(oid.source_locality) {}
|
||||
source_locality(oid.source_locality), ksk_id(oid.ksk_id),
|
||||
bsk_id(oid.bsk_id) {}
|
||||
|
||||
friend class hpx::serialization::access;
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar >> wfn_name;
|
||||
ar >> param_sizes >> param_types;
|
||||
ar >> output_sizes >> output_types;
|
||||
ar >> source_locality;
|
||||
for (size_t p = 0; p < param_sizes.size(); ++p) {
|
||||
char *param = new char[param_sizes[p]];
|
||||
new_allocated.push_back((void *)param);
|
||||
char *param;
|
||||
_dfr_checked_aligned_alloc((void **)&param, 64, param_sizes[p]);
|
||||
ar >> hpx::serialization::make_array(param, param_sizes[p]);
|
||||
params.push_back((void *)param);
|
||||
|
||||
@@ -95,28 +110,23 @@ struct OpaqueInputData {
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
char *data;
|
||||
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->basePtr = nullptr;
|
||||
static_cast<StridedMemRefType<char, 1> *>(params[p])->data = data;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
uint64_t bsk_id, ksk_id;
|
||||
ar >> bsk_id >> ksk_id >> source_locality;
|
||||
ar >> bsk_id >> ksk_id;
|
||||
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
new mlir::concretelang::RuntimeContext;
|
||||
new_allocated.push_back((void *)context);
|
||||
mlir::concretelang::RuntimeContext **_context =
|
||||
new mlir::concretelang::RuntimeContext *[1];
|
||||
new_allocated.push_back((void *)_context);
|
||||
_context[0] = context;
|
||||
|
||||
context->bsk = (LweBootstrapKey_u64 *)bsk_id;
|
||||
context->ksk = (LweKeyswitchKey_u64 *)ksk_id;
|
||||
params[p] = (void *)_context;
|
||||
delete ((char *)params[p]);
|
||||
// TODO: this might be relaxed with newer versions of HPX.
|
||||
// Do not set the context here as remote operations are
|
||||
// unstable when initiated within a HPX helper thread.
|
||||
params[p] =
|
||||
(void *)
|
||||
_dfr_node_level_runtime_context_manager->getContextAddress();
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -131,6 +141,7 @@ struct OpaqueInputData {
|
||||
ar << wfn_name;
|
||||
ar << param_sizes << param_types;
|
||||
ar << output_sizes << output_types;
|
||||
ar << source_locality;
|
||||
for (size_t p = 0; p < params.size(); ++p) {
|
||||
// Save the first level of the data structure - if the parameter
|
||||
// is a tensor/memref, there is a second level.
|
||||
@@ -152,18 +163,16 @@ struct OpaqueInputData {
|
||||
case _DFR_TASK_ARG_CONTEXT: {
|
||||
mlir::concretelang::RuntimeContext *context =
|
||||
*static_cast<mlir::concretelang::RuntimeContext **>(params[p]);
|
||||
LweKeyswitchKey_u64 *ksk = context->ksk;
|
||||
LweBootstrapKey_u64 *bsk = context->bsk;
|
||||
|
||||
// TODO: find better unique identifiers. This is not a
|
||||
// correctness issue, but performance.
|
||||
uint64_t bsk_id = (uint64_t)bsk;
|
||||
uint64_t ksk_id = (uint64_t)ksk;
|
||||
LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context);
|
||||
LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context);
|
||||
|
||||
assert(bsk != nullptr && ksk != nullptr && "Missing context keys");
|
||||
_dfr_register_bsk(bsk, bsk_id);
|
||||
_dfr_register_ksk(ksk, ksk_id);
|
||||
ar << bsk_id << ksk_id << source_locality;
|
||||
std::cout << "Registering Key ids " << (uint64_t)ksk << " "
|
||||
<< (uint64_t)bsk << "\n"
|
||||
<< std::flush;
|
||||
_dfr_register_bsk(bsk, (uint64_t)bsk);
|
||||
_dfr_register_ksk(ksk, (uint64_t)ksk);
|
||||
ar << (uint64_t)bsk << (uint64_t)ksk;
|
||||
} break;
|
||||
case _DFR_TASK_ARG_UNRANKED_MEMREF:
|
||||
default:
|
||||
@@ -182,6 +191,8 @@ struct OpaqueInputData {
|
||||
std::vector<uint64_t> output_types;
|
||||
bool alloc_p = false;
|
||||
hpx::naming::id_type source_locality;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
};
|
||||
|
||||
struct OpaqueOutputData {
|
||||
@@ -200,8 +211,9 @@ struct OpaqueOutputData {
|
||||
template <class Archive> void load(Archive &ar, const unsigned int version) {
|
||||
ar >> output_sizes >> output_types;
|
||||
for (size_t p = 0; p < output_sizes.size(); ++p) {
|
||||
char *output = new char[output_sizes[p]];
|
||||
new_allocated.push_back((void *)output);
|
||||
char *output;
|
||||
_dfr_checked_aligned_alloc((void **)&output, 64, (output_sizes[p]));
|
||||
|
||||
ar >> hpx::serialization::make_array(output, output_sizes[p]);
|
||||
outputs.push_back((void *)output);
|
||||
|
||||
@@ -217,8 +229,8 @@ struct OpaqueOutputData {
|
||||
for (size_t r = 0; r < rank; ++r)
|
||||
size *= mref.sizes[r];
|
||||
size_t alloc_size = (size + mref.offset) * elementSize;
|
||||
char *data = new char[alloc_size];
|
||||
new_allocated.push_back((void *)data);
|
||||
char *data;
|
||||
_dfr_checked_aligned_alloc((void **)&data, 512, alloc_size);
|
||||
ar >> hpx::serialization::make_array(data + mref.offset * elementSize,
|
||||
size * elementSize);
|
||||
static_cast<StridedMemRefType<char, 1> *>(outputs[p])->basePtr =
|
||||
@@ -283,23 +295,20 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.wfn_name);
|
||||
std::vector<void *> outputs;
|
||||
|
||||
if (!_dfr_is_root_node()) {
|
||||
for (size_t p = 0; p < inputs.params.size(); ++p) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) {
|
||||
mlir::concretelang::RuntimeContext *ctx =
|
||||
(*(mlir::concretelang::RuntimeContext **)inputs.params[p]);
|
||||
ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk);
|
||||
ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (inputs.source_locality != hpx::find_here() &&
|
||||
(inputs.ksk_id || inputs.bsk_id)) {
|
||||
_dfr_node_level_runtime_context_manager->getContext(
|
||||
inputs.ksk_id, inputs.bsk_id, inputs.source_locality);
|
||||
}
|
||||
|
||||
_dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(),
|
||||
inputs.output_sizes.size());
|
||||
hpx::cout << std::flush;
|
||||
|
||||
switch (inputs.output_sizes.size()) {
|
||||
case 1: {
|
||||
void *output = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output;
|
||||
_dfr_checked_aligned_alloc(&output, 512, inputs.output_sizes[0]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output);
|
||||
@@ -387,18 +396,52 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output};
|
||||
new_allocated.push_back(output);
|
||||
break;
|
||||
}
|
||||
case 2: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output1, *output2;
|
||||
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2);
|
||||
@@ -491,20 +534,54 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2};
|
||||
new_allocated.push_back(output1);
|
||||
new_allocated.push_back(output2);
|
||||
break;
|
||||
}
|
||||
case 3: {
|
||||
void *output1 = (void *)(new char[inputs.output_sizes[0]]);
|
||||
void *output2 = (void *)(new char[inputs.output_sizes[1]]);
|
||||
void *output3 = (void *)(new char[inputs.output_sizes[2]]);
|
||||
void *output1, *output2, *output3;
|
||||
_dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]);
|
||||
_dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[2]);
|
||||
switch (inputs.params.size()) {
|
||||
case 0:
|
||||
wfn(output1, output2, output3);
|
||||
@@ -597,15 +674,47 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], output1, output2, output3);
|
||||
break;
|
||||
case 17:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], output1, output2, output3);
|
||||
break;
|
||||
case 18:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17], output1,
|
||||
output2, output3);
|
||||
break;
|
||||
case 19:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], output1, output2, output3);
|
||||
break;
|
||||
case 20:
|
||||
wfn(inputs.params[0], inputs.params[1], inputs.params[2],
|
||||
inputs.params[3], inputs.params[4], inputs.params[5],
|
||||
inputs.params[6], inputs.params[7], inputs.params[8],
|
||||
inputs.params[9], inputs.params[10], inputs.params[11],
|
||||
inputs.params[12], inputs.params[13], inputs.params[14],
|
||||
inputs.params[15], inputs.params[16], inputs.params[17],
|
||||
inputs.params[18], inputs.params[19], output1, output2, output3);
|
||||
break;
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success,
|
||||
"GenericComputeServer::execute_task",
|
||||
"Error: number of task parameters not supported.");
|
||||
}
|
||||
outputs = {output1, output2, output3};
|
||||
new_allocated.push_back(output1);
|
||||
new_allocated.push_back(output2);
|
||||
new_allocated.push_back(output3);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
@@ -613,6 +722,18 @@ struct GenericComputeServer : component_base<GenericComputeServer> {
|
||||
"Error: number of task outputs not supported.");
|
||||
}
|
||||
|
||||
// Deallocate input data buffers from OID deserialization (load)
|
||||
if (!_dfr_is_root_node()) {
|
||||
for (size_t p = 0; p < inputs.param_sizes.size(); ++p) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) {
|
||||
if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF)
|
||||
delete (static_cast<StridedMemRefType<char, 1> *>(inputs.params[p])
|
||||
->data);
|
||||
delete ((char *)inputs.params[p]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes),
|
||||
std::move(inputs.output_types), inputs.alloc_p);
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include <hpx/modules/serialization.hpp>
|
||||
|
||||
#include "concretelang/Runtime/DFRuntime.hpp"
|
||||
#include "concretelang/Runtime/context.h"
|
||||
|
||||
extern "C" {
|
||||
#include "concrete-ffi.h"
|
||||
@@ -25,13 +26,12 @@ namespace concretelang {
|
||||
namespace dfr {
|
||||
|
||||
template <typename T> struct KeyManager;
|
||||
struct RuntimeContextManager;
|
||||
namespace {
|
||||
static void *dl_handle;
|
||||
static KeyManager<LweBootstrapKey_u64> *_dfr_node_level_bsk_manager;
|
||||
static KeyManager<LweKeyswitchKey_u64> *_dfr_node_level_ksk_manager;
|
||||
static std::list<void *> new_allocated;
|
||||
static std::list<void *> fut_allocated;
|
||||
static std::list<void *> m_allocated;
|
||||
static RuntimeContextManager *_dfr_node_level_runtime_context_manager;
|
||||
} // namespace
|
||||
|
||||
void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id);
|
||||
@@ -66,7 +66,6 @@ void KeyWrapper<LweBootstrapKey_u64>::load(Archive &ar,
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_bootstrap_key_u64(buffer);
|
||||
@@ -87,7 +86,6 @@ void KeyWrapper<LweKeyswitchKey_u64>::load(Archive &ar,
|
||||
size_t length;
|
||||
ar >> length;
|
||||
uint8_t *pointer = new uint8_t[length];
|
||||
new_allocated.push_back((void *)pointer);
|
||||
ar >> hpx::serialization::make_array(pointer, length);
|
||||
BufferView buffer = {(const uint8_t *)pointer, length};
|
||||
key = deserialize_lwe_keyswitching_key_u64(buffer);
|
||||
@@ -224,6 +222,68 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) {
|
||||
return _dfr_node_level_ksk_manager->get_key(loc, key_id);
|
||||
}
|
||||
|
||||
/************************/
|
||||
/* Context management. */
|
||||
/************************/
|
||||
|
||||
struct RuntimeContextManager {
|
||||
// TODO: this is only ok so long as we don't change keys. Once we
|
||||
// use multiple keys, should have a map.
|
||||
RuntimeContext *context;
|
||||
std::mutex context_guard;
|
||||
uint64_t ksk_id;
|
||||
uint64_t bsk_id;
|
||||
|
||||
RuntimeContextManager() {
|
||||
ksk_id = 0;
|
||||
bsk_id = 0;
|
||||
context = nullptr;
|
||||
_dfr_node_level_runtime_context_manager = this;
|
||||
}
|
||||
|
||||
RuntimeContext *getContext(uint64_t ksk, uint64_t bsk,
|
||||
hpx::naming::id_type source_locality) {
|
||||
std::cout << "GetContext on node " << hpx::get_locality_id()
|
||||
<< " with context " << context << " " << bsk_id << " " << ksk_id
|
||||
<< "\n"
|
||||
<< std::flush;
|
||||
if (context != nullptr) {
|
||||
std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " "
|
||||
<< bsk << "\n"
|
||||
<< std::flush;
|
||||
assert(ksk == ksk_id && bsk == bsk_id &&
|
||||
"Context manager can only used with single keys for now.");
|
||||
} else {
|
||||
assert(ksk_id == 0 && bsk_id == 0 &&
|
||||
"Context empty but context manager has key ids.");
|
||||
LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk);
|
||||
LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk);
|
||||
std::lock_guard<std::mutex> guard(context_guard);
|
||||
if (context == nullptr) {
|
||||
auto ctx = new RuntimeContext();
|
||||
ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys(
|
||||
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>(
|
||||
new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)),
|
||||
std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>(
|
||||
new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey)));
|
||||
ksk_id = ksk;
|
||||
bsk_id = bsk;
|
||||
context = ctx;
|
||||
std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n"
|
||||
<< std::flush;
|
||||
} else {
|
||||
std::cout << " GOT context after LOCK on node "
|
||||
<< hpx::get_locality_id() << " with context " << context
|
||||
<< " " << bsk_id << " " << ksk_id << "\n"
|
||||
<< std::flush;
|
||||
}
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
RuntimeContext **getContextAddress() { return &context; }
|
||||
};
|
||||
|
||||
} // namespace dfr
|
||||
} // namespace concretelang
|
||||
} // namespace mlir
|
||||
|
||||
@@ -14,7 +14,7 @@ extern "C" {
|
||||
|
||||
typedef void (*wfnptr)(...);
|
||||
|
||||
void *_dfr_make_ready_future(void *);
|
||||
void *_dfr_make_ready_future(void *, size_t);
|
||||
void _dfr_create_async_task(wfnptr, size_t, size_t, ...);
|
||||
void _dfr_register_work_function(wfnptr);
|
||||
void *_dfr_await_future(void *);
|
||||
|
||||
@@ -57,7 +57,7 @@ static bool isCandidateForTask(Operation *op) {
|
||||
/// operations must not have side-effects and not be `isCandidateForTask`
|
||||
static bool isSinkingBeneficiary(Operation *op) {
  // Operation kinds that may be sunk into a dataflow task body;
  // memref::GetGlobalOp is part of the list.
  return isa<FHE::ZeroEintOp, arith::ConstantOp, memref::DimOp, arith::SelectOp,
             mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op);
}
|
||||
|
||||
static bool
|
||||
@@ -90,6 +90,92 @@ extractBeneficiaryOps(Operation *op, SetVector<Value> existingDependencies,
|
||||
return true;
|
||||
}
|
||||
|
||||
/// Resolve the func::FuncOp targeted by `callOp`.
///
/// Returns a null FuncOp when the callee is not a symbol reference, or when
/// the nearest symbol lookup does not yield a func::FuncOp.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  if (auto calleeSym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>())
    return dyn_cast_or_null<func::FuncOp>(
        SymbolTable::lookupNearestSymbolFrom(callOp, calleeSym));
  return nullptr;
}
|
||||
|
||||
/// Collect into `aliasedUses` every use of `val`, following the result of
/// memref cast / view / subview operations so that uses of those derived
/// values are gathered as well.
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
  // Explicit worklist instead of recursion; the collected set is the same.
  SmallVector<Value, 4> pending;
  pending.push_back(val);
  while (!pending.empty()) {
    Value current = pending.pop_back_val();
    for (OpOperand &operand : current.getUses()) {
      aliasedUses.insert(&operand);
      Operation *user = operand.getOwner();
      if (isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(user))
        pending.push_back(user->getResult(0));
    }
  }
}
|
||||
|
||||
/// Decide whether the memref allocation `op` should be sunk into `taskOp`.
///
/// Returns true when `op` is already in `beneficiaryOps`, or when it is a
/// memref::AllocOp whose buffer (or a cast/view/subview of it) is written
/// by an operation nested in `taskOp`. In the latter case `op` is added to
/// `beneficiaryOps` and its results to `availableValues`. Returns false
/// otherwise. `existingDependencies` is not used by this function.
static bool extractOutputMemrefAllocations(
    Operation *op, SetVector<Value> existingDependencies,
    SetVector<Operation *> &beneficiaryOps,
    llvm::SmallPtrSetImpl<Value> &availableValues, RT::DataflowTaskOp taskOp) {
  if (beneficiaryOps.count(op))
    return true;

  if (!isa<mlir::memref::AllocOp>(op))
    return false;

  Value val = op->getResults().front();
  DenseSet<OpOperand *> aliasedUses;
  getAliasedUses(val, aliasedUses);

  // Helper function checking if a memref use writes to memory
  auto hasMemoryWriteEffect = [&](OpOperand *use) {
    Operation *owner = use->getOwner();

    // Call ops targeting concrete-ffi do not have memory effects
    // interface, so handle apart.
    // TODO: this could be handled better in BConcrete or higher.
    if (isa<mlir::func::CallOp>(owner)) {
      // Resolve the callee once; it is null when the symbol cannot be
      // resolved to a func::FuncOp, in which case no name is inspected.
      if (func::FuncOp callee = getCalledFunction(owner)) {
        llvm::StringRef calleeName = callee.getName();

        // Functions for which operand 0 is treated as written.
        static const char *const writesFirstOperand[] = {
            "memref_expand_lut_in_trivial_glwe_ct_u64",
            "memref_add_lwe_ciphertexts_u64",
            "memref_add_plaintext_lwe_ciphertext_u64",
            "memref_mul_cleartext_lwe_ciphertext_u64",
            "memref_negate_lwe_ciphertext_u64",
            "memref_keyswitch_lwe_u64",
            "memref_bootstrap_lwe_u64"};
        for (const char *candidate : writesFirstOperand) {
          if (calleeName == candidate) {
            if (owner->getOperand(0) == use->get())
              return true;
            break;
          }
        }

        // For the copy helper, operand 1 is treated as written.
        if (calleeName == "memref_copy_one_rank" &&
            owner->getNumOperands() > 1 &&
            owner->getOperand(1) == use->get())
          return true;
      }
    }

    // Otherwise we rely on the memory effect interface
    auto effectInterface = dyn_cast<MemoryEffectOpInterface>(owner);
    if (!effectInterface)
      return false;
    SmallVector<MemoryEffects::EffectInstance, 2> effects;
    effectInterface.getEffects(effects);
    for (const auto &eff : effects)
      if (isa<MemoryEffects::Write>(eff.getEffect()) &&
          eff.getValue() == use->get())
        return true;
    return false;
  };

  // We need to check if this allocated memref is written in this task.
  // TODO: for now we'll assume that we don't do partial writes or read/writes.
  for (OpOperand *use : aliasedUses)
    if (hasMemoryWriteEffect(use) &&
        use->getOwner()->getParentOfType<RT::DataflowTaskOp>() == taskOp) {
      // We will sink the operation, mark its results as now available.
      beneficiaryOps.insert(op);
      for (Value result : op->getResults())
        availableValues.insert(result);
      return true;
    }
  return false;
}
|
||||
|
||||
LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
Region &taskOpBody = taskOp.body();
|
||||
|
||||
@@ -104,6 +190,8 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) {
|
||||
if (!operandOp)
|
||||
continue;
|
||||
extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues);
|
||||
extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk,
|
||||
availableValues, taskOp);
|
||||
}
|
||||
|
||||
// Insert operations so that the defs get cloned before uses.
|
||||
|
||||
@@ -25,6 +25,7 @@
|
||||
#include <mlir/Conversion/LLVMCommon/ConversionTarget.h>
|
||||
#include <mlir/Conversion/LLVMCommon/Pattern.h>
|
||||
#include <mlir/Conversion/LLVMCommon/VectorPattern.h>
|
||||
#include <mlir/Dialect/Affine/Utils.h>
|
||||
#include <mlir/Dialect/Arithmetic/IR/Arithmetic.h>
|
||||
#include <mlir/Dialect/Bufferization/Transforms/Passes.h>
|
||||
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
|
||||
@@ -117,7 +118,8 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp,
|
||||
static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement,
|
||||
Region ®ion) {
|
||||
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
|
||||
if (isa<RT::DataflowTaskOp>(use.getOwner()) &&
|
||||
if ((isa<RT::DataflowTaskOp>(use.getOwner()) ||
|
||||
isa<RT::DeallocateFutureOp>(use.getOwner())) &&
|
||||
region.isAncestor(use.getOwner()->getParentRegion()))
|
||||
use.set(replacement);
|
||||
}
|
||||
@@ -183,10 +185,7 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
// bytes until we can get the actual size of the actual types.
|
||||
if (type.isa<mlir::concretelang::Concrete::LweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::GlweCiphertextType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweKeySwitchKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::LweBootstrapKeyType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::ForeignPlaintextListType>() ||
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextListType>()) {
|
||||
type.isa<mlir::concretelang::Concrete::PlaintextType>()) {
|
||||
Value result =
|
||||
builder.create<arith::ConstantOp>(loc, builder.getI64IntegerAttr(8));
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
@@ -204,38 +203,69 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) {
|
||||
return std::pair<mlir::Value, mlir::Value>(result, arg_type);
|
||||
}
|
||||
|
||||
/// Gathers into `aliasedUses` every use of `val` together with the uses of
/// any value derived from it through a chain of `memref.cast`,
/// `memref.view` or `memref.subview` operations.
static void getAliasedUses(Value val, DenseSet<OpOperand *> &aliasedUses) {
  // Explicit worklist of values whose uses still have to be visited.
  SmallVector<Value, 4> pending;
  pending.push_back(val);

  while (!pending.empty()) {
    Value current = pending.pop_back_val();
    for (OpOperand &operand : current.getUses()) {
      aliasedUses.insert(&operand);
      Operation *user = operand.getOwner();
      // The result of these ops is treated as an alias of their operand,
      // so its uses are collected as well.
      if (isa<memref::CastOp, memref::ViewOp, memref::SubViewOp>(user))
        pending.push_back(user->getResult(0));
    }
  }
}
|
||||
|
||||
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
func::FuncOp workFunction) {
|
||||
DataLayout dataLayout = DataLayout::closest(DFTOp);
|
||||
Region &opBody = DFTOp->getParentOfType<func::FuncOp>().getBody();
|
||||
BlockAndValueMapping map;
|
||||
OpBuilder builder(DFTOp);
|
||||
|
||||
// First identify DFT operands that are not futures and are not
|
||||
// defined by another DFT. These need to be made into futures and
|
||||
// propagated to all other DFTs. We can allow PRE to eliminate the
|
||||
// previous definitions if there are no non-future type uses.
|
||||
builder.setInsertionPoint(DFTOp);
|
||||
for (Value val : DFTOp.getOperands()) {
|
||||
if (!val.getType().isa<RT::FutureType>()) {
|
||||
Value newval;
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Type futType = RT::FutureType::get(val.getType());
|
||||
Value memrefCloned, newval = val;
|
||||
|
||||
// Find out if this value is needed in any other task
|
||||
SmallVector<Operation *, 2> taskOps;
|
||||
for (auto &use : val.getUses())
|
||||
if (isa<RT::DataflowTaskOp>(use.getOwner()))
|
||||
taskOps.push_back(use.getOwner());
|
||||
Operation *first = DFTOp;
|
||||
for (auto op : taskOps)
|
||||
if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first))
|
||||
first = op;
|
||||
builder.setInsertionPoint(first);
|
||||
|
||||
// If we are building a future on a MemRef, then we need to clone
|
||||
// the memref in order to allow the deallocation pass which does
|
||||
// not synchronize with task execution.
|
||||
if (val.getType().isa<mlir::MemRefType>() && !val.isa<BlockArgument>()) {
|
||||
newval = builder
|
||||
.create<memref::AllocOp>(
|
||||
DFTOp.getLoc(), val.getType().cast<mlir::MemRefType>())
|
||||
if (val.getType().isa<mlir::MemRefType>()) {
|
||||
// Get the type of memref that we will clone. In case this is
|
||||
// a subview, we discard the mapping so we copy to a contiguous
|
||||
// layout which pre-serializes this.
|
||||
MemRefType mrType = val.getType().dyn_cast<mlir::MemRefType>();
|
||||
if (!mrType.getLayout().isIdentity()) {
|
||||
unsigned rank = mrType.getRank();
|
||||
mrType = MemRefType::Builder(mrType)
|
||||
.setShape(mrType.getShape())
|
||||
.setLayout(AffineMapAttr::get(
|
||||
builder.getMultiDimIdentityMap(rank)));
|
||||
}
|
||||
newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
|
||||
.getResult();
|
||||
builder.create<memref::CopyOp>(DFTOp.getLoc(), val, newval);
|
||||
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
|
||||
memrefCloned = builder.create<arith::ConstantOp>(
|
||||
val.getLoc(), builder.getI64IntegerAttr(1));
|
||||
} else {
|
||||
newval = val;
|
||||
memrefCloned = builder.create<arith::ConstantOp>(
|
||||
val.getLoc(), builder.getI64IntegerAttr(0));
|
||||
}
|
||||
Type futType = RT::FutureType::get(newval.getType());
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(DFTOp.getLoc(), futType,
|
||||
newval);
|
||||
map.map(mrf->getResult(0), val);
|
||||
replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody);
|
||||
|
||||
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
|
||||
newval, memrefCloned);
|
||||
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -268,6 +298,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
// unsupported even in the LLVMIR Dialect - this needs to use two
|
||||
// placeholders for each output, before and after the
|
||||
// CreateAsyncTaskOp.
|
||||
BlockAndValueMapping map;
|
||||
for (auto result : DFTOp.getResults()) {
|
||||
Type futType = RT::PointerType::get(RT::FutureType::get(result.getType()));
|
||||
auto brpp = builder.create<RT::BuildReturnPtrPlaceholderOp>(DFTOp.getLoc(),
|
||||
@@ -297,6 +328,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
|
||||
for (auto &use : llvm::make_early_inc_range(result.getUses())) {
|
||||
if (!isa<RT::DataflowTaskOp>(use.getOwner()) &&
|
||||
!isa<RT::DeallocateFutureOp>(use.getOwner()) &&
|
||||
use.getOwner()->getParentOfType<RT::DataflowTaskOp>() == nullptr) {
|
||||
// Wait for this future before its uses
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
@@ -315,24 +347,35 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
|
||||
DFTOp.erase();
|
||||
}
|
||||
|
||||
static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) {
|
||||
OpBuilder builder(parentFunc.body());
|
||||
builder.setInsertionPointToStart(&parentFunc.body().front());
|
||||
static void registerWorkFunction(mlir::func::FuncOp parentFunc,
|
||||
mlir::func::FuncOp workFunction) {
|
||||
OpBuilder builder(parentFunc.getBody());
|
||||
builder.setInsertionPointToStart(&parentFunc.getBody().front());
|
||||
|
||||
auto fnptr = builder.create<mlir::ConstantOp>(
|
||||
parentFunc.getLoc(), workFunction.getType(),
|
||||
auto fnptr = builder.create<mlir::func::ConstantOp>(
|
||||
parentFunc.getLoc(), workFunction.getFunctionType(),
|
||||
SymbolRefAttr::get(builder.getContext(), workFunction.getName()));
|
||||
|
||||
builder.create<RT::RegisterTaskWorkFunctionOp>(parentFunc.getLoc(),
|
||||
fnptr.getResult());
|
||||
}
|
||||
|
||||
/// Resolves the callee of `callOp` to a `func::FuncOp`.
///
/// Returns a null op when the callee is not given as a symbol reference, or
/// when the nearest symbol lookup does not yield a `func::FuncOp`.
static func::FuncOp getCalledFunction(CallOpInterface callOp) {
  if (auto calleeSym =
          callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
    Operation *target =
        SymbolTable::lookupNearestSymbolFrom(callOp, calleeSym);
    return dyn_cast_or_null<func::FuncOp>(target);
  }
  return nullptr;
}
|
||||
|
||||
/// For documentation see Autopar.td
|
||||
struct LowerDataflowTasksPass
|
||||
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
SmallVector<func::FuncOp, 4> workFunctions;
|
||||
SmallVector<func::FuncOp, 1> entryPoints;
|
||||
|
||||
module.walk([&](mlir::func::FuncOp func) {
|
||||
static int wfn_id = 0;
|
||||
@@ -342,7 +385,7 @@ struct LowerDataflowTasksPass
|
||||
return;
|
||||
|
||||
SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func);
|
||||
std::vector<std::pair<RT::DataflowTaskOp, func::FuncOp>> outliningMap;
|
||||
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
|
||||
|
||||
func.walk([&](RT::DataflowTaskOp op) {
|
||||
auto workFunctionName =
|
||||
@@ -353,6 +396,7 @@ struct LowerDataflowTasksPass
|
||||
outlineWorkFunction(op, workFunctionName.str());
|
||||
outliningMap.push_back(
|
||||
std::pair<RT::DataflowTaskOp, func::FuncOp>(op, outlinedFunc));
|
||||
workFunctions.push_back(outlinedFunc);
|
||||
symbolTable.insert(outlinedFunc);
|
||||
return WalkResult::advance();
|
||||
});
|
||||
@@ -361,73 +405,72 @@ struct LowerDataflowTasksPass
|
||||
for (auto mapping : outliningMap)
|
||||
lowerDataflowTaskOp(mapping.first, mapping.second);
|
||||
|
||||
// Main is always an entry-point - otherwise check if this
|
||||
// function is called within the module. TODO: we assume no
|
||||
// recursion.
|
||||
if (func.getName() == "main")
|
||||
entryPoints.push_back(func);
|
||||
else {
|
||||
bool found = false;
|
||||
module.walk([&](mlir::func::CallOp op) {
|
||||
if (getCalledFunction(op) == func)
|
||||
found = true;
|
||||
});
|
||||
if (!found)
|
||||
entryPoints.push_back(func);
|
||||
}
|
||||
});
|
||||
|
||||
for (auto entryPoint : entryPoints) {
|
||||
// If this is a JIT invocation and we're not on the root node,
|
||||
// we do not need to do any computation, only register all work
|
||||
// functions with the runtime system
|
||||
if (!outliningMap.empty()) {
|
||||
if (!workFunctions.empty()) {
|
||||
if (!dfr::_dfr_is_root_node()) {
|
||||
// auto regFunc = builder.create<FuncOp>(func.getLoc(),
|
||||
// func.getName(), func.getType());
|
||||
|
||||
func.eraseBody();
|
||||
entryPoint.eraseBody();
|
||||
Block *b = new Block;
|
||||
b->addArguments(func.getType().getInputs());
|
||||
func.body().push_front(b);
|
||||
for (int i = func.getType().getNumInputs() - 1; i >= 0; --i)
|
||||
func.eraseArgument(i);
|
||||
for (int i = func.getType().getNumResults() - 1; i >= 0; --i)
|
||||
func.eraseResult(i);
|
||||
OpBuilder builder(func.body());
|
||||
builder.setInsertionPointToEnd(&func.body().front());
|
||||
builder.create<ReturnOp>(func.getLoc());
|
||||
SmallVector<Location, 4> locations;
|
||||
for (auto input : entryPoint.getFunctionType().getInputs())
|
||||
locations.push_back(entryPoint.getLoc());
|
||||
b->addArguments(entryPoint.getFunctionType().getInputs(), locations);
|
||||
entryPoint.getBody().push_front(b);
|
||||
for (int i = entryPoint.getFunctionType().getNumInputs() - 1; i >= 0;
|
||||
--i)
|
||||
entryPoint.eraseArgument(i);
|
||||
for (int i = entryPoint.getFunctionType().getNumResults() - 1; i >= 0;
|
||||
--i)
|
||||
entryPoint.eraseResult(i);
|
||||
OpBuilder builder(entryPoint.getBody());
|
||||
builder.setInsertionPointToEnd(&entryPoint.getBody().front());
|
||||
builder.create<mlir::func::ReturnOp>(entryPoint.getLoc());
|
||||
}
|
||||
}
|
||||
|
||||
// Generate code to register all work-functions with the
|
||||
// runtime.
|
||||
for (auto mapping : outliningMap)
|
||||
registerWorkFunction(func, mapping.second);
|
||||
for (auto wf : workFunctions)
|
||||
registerWorkFunction(entryPoint, wf);
|
||||
|
||||
// Issue _dfr_start/stop calls for this function
|
||||
if (!outliningMap.empty()) {
|
||||
OpBuilder builder(func.getBody());
|
||||
builder.setInsertionPointToStart(&func.getBody().front());
|
||||
if (!workFunctions.empty()) {
|
||||
OpBuilder builder(entryPoint.getBody());
|
||||
builder.setInsertionPointToStart(&entryPoint.getBody().front());
|
||||
auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn(
|
||||
func->getParentOfType<ModuleOp>(), "_dfr_start", {},
|
||||
LLVM::LLVMVoidType::get(func->getContext()));
|
||||
builder.create<LLVM::CallOp>(func.getLoc(), dfrStartFunOp,
|
||||
module, "_dfr_start", {},
|
||||
LLVM::LLVMVoidType::get(entryPoint->getContext()));
|
||||
builder.create<LLVM::CallOp>(entryPoint.getLoc(), dfrStartFunOp,
|
||||
mlir::ValueRange(),
|
||||
ArrayRef<NamedAttribute>());
|
||||
|
||||
builder.setInsertionPoint(func.getBody().back().getTerminator());
|
||||
builder.setInsertionPoint(entryPoint.getBody().back().getTerminator());
|
||||
auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn(
|
||||
func->getParentOfType<ModuleOp>(), "_dfr_stop", {},
|
||||
LLVM::LLVMVoidType::get(func->getContext()));
|
||||
builder.create<LLVM::CallOp>(func.getLoc(), dfrStopFunOp,
|
||||
module, "_dfr_stop", {},
|
||||
LLVM::LLVMVoidType::get(entryPoint->getContext()));
|
||||
builder.create<LLVM::CallOp>(entryPoint.getLoc(), dfrStopFunOp,
|
||||
mlir::ValueRange(),
|
||||
ArrayRef<NamedAttribute>());
|
||||
}
|
||||
});
|
||||
|
||||
// Delay memref deallocations when memrefs are made into futures
|
||||
module.walk([&](Operation *op) {
|
||||
if (isa<RT::MakeReadyFutureOp>(*op) &&
|
||||
op->getOperand(0).getType().isa<mlir::MemRefType>()) {
|
||||
for (auto &use :
|
||||
llvm::make_early_inc_range(op->getOperand(0).getUses())) {
|
||||
if (isa<mlir::memref::DeallocOp>(use.getOwner())) {
|
||||
OpBuilder builder(use.getOwner()
|
||||
->getParentOfType<mlir::func::FuncOp>()
|
||||
.getBody()
|
||||
.back()
|
||||
.getTerminator());
|
||||
builder.clone(*use.getOwner());
|
||||
use.getOwner()->erase();
|
||||
}
|
||||
}
|
||||
}
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
}
|
||||
LowerDataflowTasksPass(bool debug) : debug(debug){};
|
||||
|
||||
@@ -440,5 +483,44 @@ std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug) {
|
||||
return std::make_unique<LowerDataflowTasksPass>(debug);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// For documentation see Autopar.td
|
||||
struct FixupBufferDeallocationPass
|
||||
: public FixupBufferDeallocationBase<FixupBufferDeallocationPass> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
std::vector<Operation *> ops;
|
||||
|
||||
// All buffers allocated and either made into a future, directly
|
||||
// or as a result of being returned by a task, are managed by the
|
||||
// DFR runtime system's reference counting.
|
||||
module.walk([&](RT::WorkFunctionReturnOp retOp) {
|
||||
for (auto &use :
|
||||
llvm::make_early_inc_range(retOp.getOperands().front().getUses()))
|
||||
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
|
||||
ops.push_back(use.getOwner());
|
||||
});
|
||||
module.walk([&](RT::MakeReadyFutureOp mrfOp) {
|
||||
for (auto &use :
|
||||
llvm::make_early_inc_range(mrfOp.getOperands().front().getUses()))
|
||||
if (isa<mlir::memref::DeallocOp>(use.getOwner()))
|
||||
ops.push_back(use.getOwner());
|
||||
});
|
||||
for (auto op : ops)
|
||||
op->erase();
|
||||
}
|
||||
FixupBufferDeallocationPass(bool debug) : debug(debug){};
|
||||
|
||||
protected:
|
||||
bool debug;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
std::unique_ptr<mlir::Pass> createFixupBufferDeallocationPass(bool debug) {
|
||||
return std::make_unique<FixupBufferDeallocationPass>(debug);
|
||||
}
|
||||
|
||||
} // end namespace concretelang
|
||||
} // end namespace mlir
|
||||
|
||||
@@ -125,8 +125,9 @@ struct MakeReadyFutureOpInterfaceLowering
|
||||
results[0]);
|
||||
rewriter.create<LLVM::StoreOp>(mrfOp.getLoc(),
|
||||
adaptor.getOperands().front(), allocatedPtr);
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, allocatedPtr);
|
||||
|
||||
SmallVector<Value, 4> mrfOperands = {adaptor.getOperands()};
|
||||
mrfOperands[0] = allocatedPtr;
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(mrfOp, mrfFuncOp, mrfOperands);
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
@@ -178,16 +179,14 @@ struct RegisterTaskWorkFunctionOpInterfaceLowering
|
||||
|
||||
/// Replaces `rtwfOp` with an LLVM call to the variadic, void-returning
/// runtime function `_dfr_register_work_function`, forwarding the
/// already-converted operands supplied by `adaptor`. The declaration of
/// the runtime function is obtained through `getOrInsertFuncOpDecl`.
mlir::LogicalResult
matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp,
                RT::RegisterTaskWorkFunctionOp::Adaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto rtwfFuncType =
      LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true);
  auto rtwfFuncOp = getOrInsertFuncOpDecl(
      rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter);
  rewriter.replaceOpWithNewOp<LLVM::CallOp>(rtwfOp, rtwfFuncOp,
                                            adaptor.getOperands());
  return success();
}
|
||||
};
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Region.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
#include "concretelang/Dialect/RT/IR/RTOps.h"
|
||||
#include "concretelang/Dialect/RT/IR/RTTypes.h"
|
||||
@@ -33,3 +34,66 @@ void DataflowTaskOp::build(
|
||||
void DataflowTaskOp::getSuccessorRegions(
|
||||
Optional<unsigned> index, ArrayRef<Attribute> operands,
|
||||
SmallVectorImpl<RegionSuccessor> ®ions) {}
|
||||
|
||||
/// Creates a `DeallocateFutureOp` for `alloc` at the builder's insertion
/// point and returns the newly created operation.
llvm::Optional<mlir::Operation *>
DataflowTaskOp::buildDealloc(OpBuilder &builder, Value alloc) {
  auto deallocOp = builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc);
  return deallocOp.getOperation();
}
|
||||
/// Creates a `CloneFutureOp` of `alloc` at the builder's insertion point
/// and returns the value it produces.
llvm::Optional<mlir::Value> DataflowTaskOp::buildClone(OpBuilder &builder,
                                                       Value alloc) {
  auto cloneOp = builder.create<CloneFutureOp>(alloc.getLoc(), alloc);
  return cloneOp.getResult();
}
|
||||
/// Reports the memory effects of a dataflow task: a read of every input,
/// then a write of every output, then an allocation of every output, all
/// on the default resource.
void DataflowTaskOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *defaultResource = SideEffects::DefaultResource::get();

  for (auto in : inputs())
    effects.emplace_back(MemoryEffects::Read::get(), in, defaultResource);

  // Two separate passes over the outputs keep all Write effects ahead of
  // the Allocate effects in the resulting list.
  for (auto out : outputs())
    effects.emplace_back(MemoryEffects::Write::get(), out, defaultResource);
  for (auto out : outputs())
    effects.emplace_back(MemoryEffects::Allocate::get(), out,
                         defaultResource);
}
|
||||
|
||||
/// Creates a `DeallocateFutureOp` for `alloc` at the builder's insertion
/// point and returns the newly created operation.
llvm::Optional<mlir::Operation *>
CloneFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
  auto deallocOp = builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc);
  return deallocOp.getOperation();
}
|
||||
/// Creates a further `CloneFutureOp` of `alloc` at the builder's insertion
/// point and returns the value it produces.
llvm::Optional<mlir::Value> CloneFutureOp::buildClone(OpBuilder &builder,
                                                      Value alloc) {
  auto cloneOp = builder.create<CloneFutureOp>(alloc.getLoc(), alloc);
  return cloneOp.getResult();
}
|
||||
/// Reports the memory effects of cloning a future: the input is read, and
/// the output is both written and allocated, all on the default resource.
void CloneFutureOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *defaultResource = SideEffects::DefaultResource::get();

  effects.emplace_back(MemoryEffects::Read::get(), input(), defaultResource);
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       defaultResource);
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       defaultResource);
}
|
||||
|
||||
/// Creates a `DeallocateFutureOp` for `alloc` at the builder's insertion
/// point and returns the newly created operation.
llvm::Optional<mlir::Operation *>
MakeReadyFutureOp::buildDealloc(OpBuilder &builder, Value alloc) {
  auto deallocOp = builder.create<DeallocateFutureOp>(alloc.getLoc(), alloc);
  return deallocOp.getOperation();
}
|
||||
/// Creates a `CloneFutureOp` of `alloc` at the builder's insertion point
/// and returns the value it produces.
llvm::Optional<mlir::Value> MakeReadyFutureOp::buildClone(OpBuilder &builder,
                                                          Value alloc) {
  auto cloneOp = builder.create<CloneFutureOp>(alloc.getLoc(), alloc);
  return cloneOp.getResult();
}
|
||||
/// Reports the memory effects of building a ready future: the input is
/// read, and the output is both written and allocated, all on the default
/// resource.
void MakeReadyFutureOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  auto *defaultResource = SideEffects::DefaultResource::get();

  effects.emplace_back(MemoryEffects::Read::get(), input(), defaultResource);
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       defaultResource);
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       defaultResource);
}
|
||||
|
||||
@@ -38,33 +38,50 @@ static size_t num_nodes = 0;
|
||||
|
||||
using namespace hpx;
|
||||
|
||||
void *_dfr_make_ready_future(void *in) {
|
||||
void *future = static_cast<void *>(
|
||||
new hpx::shared_future<void *>(hpx::make_ready_future(in)));
|
||||
mlir::concretelang::dfr::m_allocated.push_back(in);
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(future);
|
||||
return future;
|
||||
typedef struct dfr_refcounted_future {
|
||||
hpx::shared_future<void *> *future;
|
||||
std::atomic<std::size_t> count;
|
||||
bool cloned_memref_p;
|
||||
dfr_refcounted_future(hpx::shared_future<void *> *f, size_t c, bool clone_p)
|
||||
: future(f), count(c), cloned_memref_p(clone_p) {}
|
||||
} dfr_refcounted_future_t, *dfr_refcounted_future_p;
|
||||
|
||||
// Ready futures are only used as inputs to tasks (never passed to
|
||||
// await_future), so we only need to track the references in task
|
||||
// creation.
|
||||
void *_dfr_make_ready_future(void *in, size_t memref_clone_p) {
|
||||
return (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(hpx::make_ready_future(in)), 1,
|
||||
memref_clone_p);
|
||||
}
|
||||
|
||||
void *_dfr_await_future(void *in) {
|
||||
return static_cast<hpx::shared_future<void *> *>(in)->get();
|
||||
}
|
||||
|
||||
void _dfr_deallocate_future_data(void *in) {
|
||||
delete[] static_cast<char *>(
|
||||
static_cast<hpx::shared_future<void *> *>(in)->get());
|
||||
return static_cast<dfr_refcounted_future_p>(in)->future->get();
|
||||
}
|
||||
|
||||
void _dfr_deallocate_future(void *in) {
|
||||
delete (static_cast<hpx::shared_future<void *> *>(in));
|
||||
auto drf = static_cast<dfr_refcounted_future_p>(in);
|
||||
size_t prev_count = drf->count.fetch_sub(1);
|
||||
if (prev_count == 1) {
|
||||
// If this was a memref for which a clone was needed, deallocate first.
|
||||
if (drf->cloned_memref_p)
|
||||
free(
|
||||
(void *)(static_cast<StridedMemRefType<char, 1> *>(drf->future->get())
|
||||
->data));
|
||||
free(drf->future->get());
|
||||
delete (drf->future);
|
||||
delete drf;
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally a no-op: the body is empty and `in` is ignored.
void _dfr_deallocate_future_data(void *in) {}
|
||||
|
||||
// Determine where new task should run. For now just round-robin
|
||||
// distribution - TODO: optimise.
|
||||
static inline size_t _dfr_find_next_execution_locality() {
|
||||
static std::atomic<std::size_t> next_locality{0};
|
||||
static std::atomic<std::size_t> next_locality{1};
|
||||
|
||||
size_t next_loc = ++next_locality;
|
||||
size_t next_loc = next_locality.fetch_add(1);
|
||||
|
||||
return next_loc % mlir::concretelang::dfr::num_nodes;
|
||||
}
|
||||
@@ -76,7 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() {
|
||||
/// the returns.
|
||||
void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
...) {
|
||||
std::vector<void *> params;
|
||||
// std::vector<void *> params;
|
||||
std::vector<void *> refcounted_futures;
|
||||
std::vector<size_t> param_sizes;
|
||||
std::vector<uint64_t> param_types;
|
||||
std::vector<void *> outputs;
|
||||
@@ -86,7 +104,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
va_list args;
|
||||
va_start(args, num_outputs);
|
||||
for (size_t i = 0; i < num_params; ++i) {
|
||||
params.push_back(va_arg(args, void *));
|
||||
refcounted_futures.push_back(va_arg(args, void *));
|
||||
param_sizes.push_back(va_arg(args, uint64_t));
|
||||
param_types.push_back(va_arg(args, uint64_t));
|
||||
}
|
||||
@@ -97,13 +115,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
}
|
||||
va_end(args);
|
||||
|
||||
for (size_t i = 0; i < num_params; ++i) {
|
||||
if (mlir::concretelang::dfr::_dfr_get_arg_type(
|
||||
param_types[i] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF)) {
|
||||
mlir::concretelang::dfr::m_allocated.push_back(
|
||||
(void *)static_cast<StridedMemRefType<char, 1> *>(params[i])->data);
|
||||
}
|
||||
}
|
||||
// Take a reference on each future argument
|
||||
for (auto rcf : refcounted_futures)
|
||||
((dfr_refcounted_future_p)rcf)->count.fetch_add(1);
|
||||
|
||||
// We pass functions by name - which is not strictly necessary in
|
||||
// shared memory as pointers suffice, but is needed in the
|
||||
@@ -147,7 +161,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future));
|
||||
break;
|
||||
|
||||
case 2:
|
||||
@@ -164,8 +178,8 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future));
|
||||
break;
|
||||
|
||||
case 3:
|
||||
@@ -184,9 +198,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future));
|
||||
break;
|
||||
|
||||
case 4:
|
||||
@@ -206,10 +220,10 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future));
|
||||
break;
|
||||
|
||||
case 5:
|
||||
@@ -231,11 +245,11 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future));
|
||||
break;
|
||||
|
||||
case 6:
|
||||
@@ -258,12 +272,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future));
|
||||
break;
|
||||
|
||||
case 7:
|
||||
@@ -287,13 +301,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future));
|
||||
break;
|
||||
|
||||
case 8:
|
||||
@@ -318,14 +332,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future));
|
||||
break;
|
||||
|
||||
case 9:
|
||||
@@ -352,15 +366,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future));
|
||||
break;
|
||||
|
||||
case 10:
|
||||
@@ -388,16 +402,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future));
|
||||
break;
|
||||
|
||||
case 11:
|
||||
@@ -426,17 +440,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future));
|
||||
break;
|
||||
|
||||
case 12:
|
||||
@@ -466,18 +480,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10],
|
||||
*(hpx::shared_future<void *> *)params[11]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future));
|
||||
break;
|
||||
|
||||
case 13:
|
||||
@@ -509,19 +523,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10],
|
||||
*(hpx::shared_future<void *> *)params[11],
|
||||
*(hpx::shared_future<void *> *)params[12]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future));
|
||||
break;
|
||||
|
||||
case 14:
|
||||
@@ -554,20 +568,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10],
|
||||
*(hpx::shared_future<void *> *)params[11],
|
||||
*(hpx::shared_future<void *> *)params[12],
|
||||
*(hpx::shared_future<void *> *)params[13]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future));
|
||||
break;
|
||||
|
||||
case 15:
|
||||
@@ -601,21 +615,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10],
|
||||
*(hpx::shared_future<void *> *)params[11],
|
||||
*(hpx::shared_future<void *> *)params[12],
|
||||
*(hpx::shared_future<void *> *)params[13],
|
||||
*(hpx::shared_future<void *> *)params[14]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future));
|
||||
break;
|
||||
|
||||
case 16:
|
||||
@@ -650,22 +664,246 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*(hpx::shared_future<void *> *)params[0],
|
||||
*(hpx::shared_future<void *> *)params[1],
|
||||
*(hpx::shared_future<void *> *)params[2],
|
||||
*(hpx::shared_future<void *> *)params[3],
|
||||
*(hpx::shared_future<void *> *)params[4],
|
||||
*(hpx::shared_future<void *> *)params[5],
|
||||
*(hpx::shared_future<void *> *)params[6],
|
||||
*(hpx::shared_future<void *> *)params[7],
|
||||
*(hpx::shared_future<void *> *)params[8],
|
||||
*(hpx::shared_future<void *> *)params[9],
|
||||
*(hpx::shared_future<void *> *)params[10],
|
||||
*(hpx::shared_future<void *> *)params[11],
|
||||
*(hpx::shared_future<void *> *)params[12],
|
||||
*(hpx::shared_future<void *> *)params[13],
|
||||
*(hpx::shared_future<void *> *)params[14],
|
||||
*(hpx::shared_future<void *> *)params[15]));
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[15])->future));
|
||||
break;
|
||||
|
||||
case 17:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get(), param7.get(),
|
||||
param8.get(), param9.get(), param10.get(), param11.get(),
|
||||
param12.get(), param13.get(), param14.get(), param15.get(),
|
||||
param16.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
return mlir::concretelang::dfr::gcc
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[16])->future));
|
||||
break;
|
||||
|
||||
case 18:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get(), param7.get(),
|
||||
param8.get(), param9.get(), param10.get(), param11.get(),
|
||||
param12.get(), param13.get(), param14.get(), param15.get(),
|
||||
param16.get(), param17.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
return mlir::concretelang::dfr::gcc
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[17])->future));
|
||||
break;
|
||||
|
||||
case 19:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get(), param7.get(),
|
||||
param8.get(), param9.get(), param10.get(), param11.get(),
|
||||
param12.get(), param13.get(), param14.get(), param15.get(),
|
||||
param16.get(), param17.get(), param18.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
return mlir::concretelang::dfr::gcc
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[17])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[18])->future));
|
||||
break;
|
||||
|
||||
case 20:
|
||||
oodf = std::move(hpx::dataflow(
|
||||
[wfnname, param_sizes, param_types, output_sizes,
|
||||
output_types](hpx::shared_future<void *> param0,
|
||||
hpx::shared_future<void *> param1,
|
||||
hpx::shared_future<void *> param2,
|
||||
hpx::shared_future<void *> param3,
|
||||
hpx::shared_future<void *> param4,
|
||||
hpx::shared_future<void *> param5,
|
||||
hpx::shared_future<void *> param6,
|
||||
hpx::shared_future<void *> param7,
|
||||
hpx::shared_future<void *> param8,
|
||||
hpx::shared_future<void *> param9,
|
||||
hpx::shared_future<void *> param10,
|
||||
hpx::shared_future<void *> param11,
|
||||
hpx::shared_future<void *> param12,
|
||||
hpx::shared_future<void *> param13,
|
||||
hpx::shared_future<void *> param14,
|
||||
hpx::shared_future<void *> param15,
|
||||
hpx::shared_future<void *> param16,
|
||||
hpx::shared_future<void *> param17,
|
||||
hpx::shared_future<void *> param18,
|
||||
hpx::shared_future<void *> param19)
|
||||
-> hpx::future<mlir::concretelang::dfr::OpaqueOutputData> {
|
||||
std::vector<void *> params = {
|
||||
param0.get(), param1.get(), param2.get(), param3.get(),
|
||||
param4.get(), param5.get(), param6.get(), param7.get(),
|
||||
param8.get(), param9.get(), param10.get(), param11.get(),
|
||||
param12.get(), param13.get(), param14.get(), param15.get(),
|
||||
param16.get(), param17.get(), param18.get(), param19.get()};
|
||||
mlir::concretelang::dfr::OpaqueInputData oid(
|
||||
wfnname, params, param_sizes, param_types, output_sizes,
|
||||
output_types);
|
||||
return mlir::concretelang::dfr::gcc
|
||||
[_dfr_find_next_execution_locality()]
|
||||
.execute_task(oid);
|
||||
},
|
||||
*((dfr_refcounted_future_p)refcounted_futures[0])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[1])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[2])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[3])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[4])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[5])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[6])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[7])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[8])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[9])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[10])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[11])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[12])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[13])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[14])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[15])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[16])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[17])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[18])->future,
|
||||
*((dfr_refcounted_future_p)refcounted_futures[19])->future));
|
||||
break;
|
||||
|
||||
default:
|
||||
@@ -675,53 +913,67 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs,
|
||||
|
||||
switch (num_outputs) {
|
||||
case 1:
|
||||
*((void **)outputs[0]) = new hpx::shared_future<void *>(hpx::dataflow(
|
||||
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
-> void * { return oodf_in.get().outputs[0]; },
|
||||
oodf));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
|
||||
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(hpx::dataflow(
|
||||
[refcounted_futures](
|
||||
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
-> void * {
|
||||
void *ret = oodf_in.get().outputs[0];
|
||||
for (auto rcf : refcounted_futures)
|
||||
_dfr_deallocate_future(rcf);
|
||||
return ret;
|
||||
},
|
||||
oodf)),
|
||||
1, output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
break;
|
||||
|
||||
case 2: {
|
||||
hpx::future<hpx::tuple<void *, void *>> &&ft = hpx::dataflow(
|
||||
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
[refcounted_futures](
|
||||
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
-> hpx::tuple<void *, void *> {
|
||||
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
|
||||
for (auto rcf : refcounted_futures)
|
||||
_dfr_deallocate_future(rcf);
|
||||
return hpx::make_tuple<>(outputs[0], outputs[1]);
|
||||
},
|
||||
oodf);
|
||||
hpx::tuple<hpx::future<void *>, hpx::future<void *>> &&tf =
|
||||
hpx::split_future(std::move(ft));
|
||||
*((void **)outputs[0]) =
|
||||
(void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
|
||||
*((void **)outputs[1]) =
|
||||
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
|
||||
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(std::move(hpx::get<0>(tf))), 1,
|
||||
output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
*((void **)outputs[1]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(std::move(hpx::get<1>(tf))), 1,
|
||||
output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
break;
|
||||
}
|
||||
|
||||
case 3: {
|
||||
hpx::future<hpx::tuple<void *, void *, void *>> &&ft = hpx::dataflow(
|
||||
[](hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
[refcounted_futures](
|
||||
hpx::future<mlir::concretelang::dfr::OpaqueOutputData> oodf_in)
|
||||
-> hpx::tuple<void *, void *, void *> {
|
||||
std::vector<void *> outputs = std::move(oodf_in.get().outputs);
|
||||
for (auto rcf : refcounted_futures)
|
||||
_dfr_deallocate_future(rcf);
|
||||
return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]);
|
||||
},
|
||||
oodf);
|
||||
hpx::tuple<hpx::future<void *>, hpx::future<void *>, hpx::future<void *>>
|
||||
&&tf = hpx::split_future(std::move(ft));
|
||||
*((void **)outputs[0]) =
|
||||
(void *)new hpx::shared_future<void *>(std::move(hpx::get<0>(tf)));
|
||||
*((void **)outputs[1]) =
|
||||
(void *)new hpx::shared_future<void *>(std::move(hpx::get<1>(tf)));
|
||||
*((void **)outputs[2]) =
|
||||
(void *)new hpx::shared_future<void *>(std::move(hpx::get<2>(tf)));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0]));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1]));
|
||||
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[2]));
|
||||
*((void **)outputs[0]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(std::move(hpx::get<0>(tf))), 1,
|
||||
output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
*((void **)outputs[1]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(std::move(hpx::get<1>(tf))), 1,
|
||||
output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
*((void **)outputs[2]) = (void *)new dfr_refcounted_future_t(
|
||||
new hpx::shared_future<void *>(std::move(hpx::get<2>(tf))), 1,
|
||||
output_types[2] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF);
|
||||
break;
|
||||
}
|
||||
|
||||
default:
|
||||
HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task",
|
||||
"Error: number of task outputs not supported.");
|
||||
@@ -896,6 +1148,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) {
|
||||
|
||||
new mlir::concretelang::dfr::KeyManager<LweBootstrapKey_u64>();
|
||||
new mlir::concretelang::dfr::KeyManager<LweKeyswitchKey_u64>();
|
||||
new mlir::concretelang::dfr::RuntimeContextManager();
|
||||
new mlir::concretelang::dfr::WorkFunctionRegistry();
|
||||
mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier =
|
||||
new hpx::lcos::barrier("wait_register_remote_work_functions",
|
||||
@@ -985,21 +1238,8 @@ void _dfr_stop() {
|
||||
// safer to drop them in-between phases.
|
||||
mlir::concretelang::dfr::_dfr_node_level_bsk_manager->clear_keys();
|
||||
mlir::concretelang::dfr::_dfr_node_level_ksk_manager->clear_keys();
|
||||
|
||||
while (!mlir::concretelang::dfr::new_allocated.empty()) {
|
||||
delete[] static_cast<char *>(
|
||||
mlir::concretelang::dfr::new_allocated.front());
|
||||
mlir::concretelang::dfr::new_allocated.pop_front();
|
||||
}
|
||||
while (!mlir::concretelang::dfr::fut_allocated.empty()) {
|
||||
delete static_cast<hpx::shared_future<void *> *>(
|
||||
mlir::concretelang::dfr::fut_allocated.front());
|
||||
mlir::concretelang::dfr::fut_allocated.pop_front();
|
||||
}
|
||||
while (!mlir::concretelang::dfr::m_allocated.empty()) {
|
||||
free(mlir::concretelang::dfr::m_allocated.front());
|
||||
mlir::concretelang::dfr::m_allocated.pop_front();
|
||||
}
|
||||
mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager
|
||||
->clearContext();
|
||||
}
|
||||
|
||||
void _dfr_try_initialize() {
|
||||
|
||||
@@ -278,6 +278,11 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module,
|
||||
pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass);
|
||||
// Use the buffer deallocation interface to insert future deallocation calls
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::bufferization::createBufferDeallocationPass(), enablePass);
|
||||
addPotentiallyNestedPass(
|
||||
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass);
|
||||
|
||||
// Convert to MLIR LLVM Dialect
|
||||
addPotentiallyNestedPass(
|
||||
|
||||
Reference in New Issue
Block a user