diff --git a/compiler/include/concretelang/ClientLib/KeySet.h b/compiler/include/concretelang/ClientLib/KeySet.h index 3924d90f5..b767f19c3 100644 --- a/compiler/include/concretelang/ClientLib/KeySet.h +++ b/compiler/include/concretelang/ClientLib/KeySet.h @@ -19,7 +19,6 @@ extern "C" { #include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/ClientLib/KeySetCache.h" #include "concretelang/Common/Error.h" -#include namespace concretelang { namespace clientlib { @@ -86,9 +85,15 @@ public: } EvaluationKeys evaluationKeys() { - auto sharedKsk = std::get<1>(this->keyswitchKeys.at("ksk_v0")); - auto sharedBsk = std::get<1>(this->bootstrapKeys.at("bsk_v0")); - return EvaluationKeys(sharedKsk, sharedBsk); + auto kskIt = this->keyswitchKeys.find("ksk_v0"); + auto bskIt = this->bootstrapKeys.find("bsk_v0"); + if (kskIt != this->keyswitchKeys.end() && + bskIt != this->bootstrapKeys.end()) { + auto sharedKsk = std::get<1>(kskIt->second); + auto sharedBsk = std::get<1>(bskIt->second); + return EvaluationKeys(sharedKsk, sharedBsk); + } + return EvaluationKeys(); } const std::map createLowerDataflowTasksPass(bool debug = false); std::unique_ptr createBufferizeDataflowTaskOpsPass(bool debug = false); std::unique_ptr createFixupDataflowTaskOpsPass(bool debug = false); +std::unique_ptr +createFixupBufferDeallocationPass(bool debug = false); void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns); void populateRTBufferizePatterns(mlir::BufferizeTypeConverter &typeConverter, diff --git a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td index d1ac576ed..7901d4d4e 100644 --- a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td +++ b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td @@ -82,6 +82,15 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> { }]; } +def FixupBufferDeallocation : 
Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> { + let summary = + "Prevent deallocation of buffers returned as futures by tasks."; + let description = [{ This pass removes buffer deallocation calls on + buffers being used for dataflow communication between + tasks. These buffers cannot be deallocated directly without + synchronization as they can be needed by asynchronous + computation. Instead, these will be deallocated by the runtime + when no longer needed.}]; } #endif diff --git a/compiler/include/concretelang/Dialect/RT/IR/RTOps.h b/compiler/include/concretelang/Dialect/RT/IR/RTOps.h index a35c4d3ef..fa6ac2032 100644 --- a/compiler/include/concretelang/Dialect/RT/IR/RTOps.h +++ b/compiler/include/concretelang/Dialect/RT/IR/RTOps.h @@ -6,10 +6,12 @@ #ifndef CONCRETELANG_DIALECT_RT_IR_RTOPS_H #define CONCRETELANG_DIALECT_RT_IR_RTOPS_H +#include #include #include #include #include +#include #include #include "concretelang/Dialect/RT/IR/RTTypes.h" diff --git a/compiler/include/concretelang/Dialect/RT/IR/RTOps.td b/compiler/include/concretelang/Dialect/RT/IR/RTOps.td index 3ae30f408..c81cb3acc 100644 --- a/compiler/include/concretelang/Dialect/RT/IR/RTOps.td +++ b/compiler/include/concretelang/Dialect/RT/IR/RTOps.td @@ -1,9 +1,10 @@ #ifndef CONCRETELANG_DIALECT_RT_IR_RT_OPS #define CONCRETELANG_DIALECT_RT_IR_RT_OPS +include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" -include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" @@ -15,9 +16,13 @@ class RT_Op traits = []> : def RT_DataflowTaskOp : RT_Op<"dataflow_task", [ DeclareOpInterfaceMethods, - SingleBlockImplicitTerminator<"DataflowYieldOp">]> { + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + 
SingleBlockImplicitTerminator<"DataflowYieldOp">, + AutomaticAllocationScope] > { let arguments = (ins Variadic: $inputs); - let results = (outs Variadic:$outputs); + let results = (outs Variadic: $outputs); let regions = (region AnyRegion:$body); @@ -85,8 +90,12 @@ Example: }]; } -def RT_MakeReadyFutureOp : RT_Op<"make_ready_future"> { - let arguments = (ins AnyType: $input); +def RT_MakeReadyFutureOp : RT_Op<"make_ready_future", [ + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let arguments = (ins AnyType: $input, + AnyType: $memrefCloned); let results = (outs RT_Future: $output); let summary = "Build a ready future."; let description = [{ @@ -115,14 +124,27 @@ def RT_CreateAsyncTaskOp : RT_Op<"create_async_task"> { let summary = "Create a dataflow task."; } -def RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> { +def RT_RegisterTaskWorkFunctionOp : RT_Op<"register_task_work_function"> { let arguments = (ins Variadic:$list); let results = (outs ); let summary = "Register the task work-function with the runtime system."; } -def DeallocateFutureOp : RT_Op<"deallocate_future"> { +def RT_CloneFutureOp : RT_Op<"clone_future", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods] > { + let builders = [ + OpBuilder<(ins "Value": $input), [{ + return build($_builder, $_state, input.getType(), input); + }]>]; + let arguments = (ins RT_Future: $input); + let results = (outs RT_Future: $output); +} + +def RT_DeallocateFutureOp : RT_Op<"deallocate_future"> { + let arguments = (ins AnyType: $input); let results = (outs ); } diff --git a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp index a56022cd1..394b45858 100644 --- a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp +++ b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp @@ -8,6 +8,7 @@ #include #include +#include #include #include @@ 
-28,6 +29,7 @@ #include +#include "concretelang/ClientLib/EvaluationKeys.h" #include "concretelang/Runtime/DFRuntime.hpp" #include "concretelang/Runtime/context.h" #include "concretelang/Runtime/dfr_debug_interface.h" @@ -49,6 +51,17 @@ static inline size_t _dfr_get_memref_rank(size_t size) { (2 * sizeof(int64_t) /*size&stride/rank*/); } +static inline void _dfr_checked_aligned_alloc(void **out, size_t align, + size_t size) { + int res = posix_memalign(out, align, size); + if (res == ENOMEM) + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed", + "Error: insufficient memory available."); + if (res == EINVAL) + HPX_THROW_EXCEPTION(hpx::no_success, "DFR: memory allocation failed", + "Error: invalid memory alignment."); +} + struct OpaqueInputData { OpaqueInputData() = default; @@ -62,7 +75,7 @@ struct OpaqueInputData { param_types(std::move(_param_types)), output_sizes(std::move(_output_sizes)), output_types(std::move(_output_types)), alloc_p(_alloc_p), - source_locality(hpx::find_here()) {} + source_locality(hpx::find_here()), ksk_id(0), bsk_id(0) {} OpaqueInputData(const OpaqueInputData &oid) : wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)), @@ -70,16 +83,18 @@ struct OpaqueInputData { param_types(std::move(oid.param_types)), output_sizes(std::move(oid.output_sizes)), output_types(std::move(oid.output_types)), alloc_p(oid.alloc_p), - source_locality(oid.source_locality) {} + source_locality(oid.source_locality), ksk_id(oid.ksk_id), + bsk_id(oid.bsk_id) {} friend class hpx::serialization::access; template void load(Archive &ar, const unsigned int version) { ar >> wfn_name; ar >> param_sizes >> param_types; ar >> output_sizes >> output_types; + ar >> source_locality; for (size_t p = 0; p < param_sizes.size(); ++p) { - char *param = new char[param_sizes[p]]; - new_allocated.push_back((void *)param); + char *param; + _dfr_checked_aligned_alloc((void **)&param, 64, param_sizes[p]); ar >> hpx::serialization::make_array(param, 
param_sizes[p]); params.push_back((void *)param); @@ -95,28 +110,23 @@ struct OpaqueInputData { for (size_t r = 0; r < rank; ++r) size *= mref.sizes[r]; size_t alloc_size = (size + mref.offset) * elementSize; - char *data = new char[alloc_size]; - new_allocated.push_back((void *)data); + char *data; + _dfr_checked_aligned_alloc((void **)&data, 512, alloc_size); ar >> hpx::serialization::make_array(data + mref.offset * elementSize, size * elementSize); static_cast *>(params[p])->basePtr = nullptr; static_cast *>(params[p])->data = data; } break; case _DFR_TASK_ARG_CONTEXT: { - uint64_t bsk_id, ksk_id; - ar >> bsk_id >> ksk_id >> source_locality; + ar >> bsk_id >> ksk_id; - mlir::concretelang::RuntimeContext *context = - new mlir::concretelang::RuntimeContext; - new_allocated.push_back((void *)context); - mlir::concretelang::RuntimeContext **_context = - new mlir::concretelang::RuntimeContext *[1]; - new_allocated.push_back((void *)_context); - _context[0] = context; - - context->bsk = (LweBootstrapKey_u64 *)bsk_id; - context->ksk = (LweKeyswitchKey_u64 *)ksk_id; - params[p] = (void *)_context; + free(params[p]); // placeholder came from posix_memalign(): release with free(), not delete + // TODO: this might be relaxed with newer versions of HPX. + // Do not set the context here as remote operations are + // unstable when initiated within a HPX helper thread. + params[p] = + (void *) + _dfr_node_level_runtime_context_manager->getContextAddress(); } break; case _DFR_TASK_ARG_UNRANKED_MEMREF: default: @@ -131,6 +141,7 @@ struct OpaqueInputData { ar << wfn_name; ar << param_sizes << param_types; ar << output_sizes << output_types; + ar << source_locality; for (size_t p = 0; p < params.size(); ++p) { // Save the first level of the data structure - if the parameter // is a tensor/memref, there is a second level. 
@@ -152,18 +163,16 @@ struct OpaqueInputData { case _DFR_TASK_ARG_CONTEXT: { mlir::concretelang::RuntimeContext *context = *static_cast(params[p]); - LweKeyswitchKey_u64 *ksk = context->ksk; - LweBootstrapKey_u64 *bsk = context->bsk; - - // TODO: find better unique identifiers. This is not a - // correctness issue, but performance. - uint64_t bsk_id = (uint64_t)bsk; - uint64_t ksk_id = (uint64_t)ksk; + LweKeyswitchKey_u64 *ksk = get_keyswitch_key_u64(context); + LweBootstrapKey_u64 *bsk = get_bootstrap_key_u64(context); assert(bsk != nullptr && ksk != nullptr && "Missing context keys"); - _dfr_register_bsk(bsk, bsk_id); - _dfr_register_ksk(ksk, ksk_id); - ar << bsk_id << ksk_id << source_locality; + std::cout << "Registering Key ids " << (uint64_t)ksk << " " + << (uint64_t)bsk << "\n" + << std::flush; + _dfr_register_bsk(bsk, (uint64_t)bsk); + _dfr_register_ksk(ksk, (uint64_t)ksk); + ar << (uint64_t)bsk << (uint64_t)ksk; } break; case _DFR_TASK_ARG_UNRANKED_MEMREF: default: @@ -182,6 +191,8 @@ struct OpaqueInputData { std::vector output_types; bool alloc_p = false; hpx::naming::id_type source_locality; + uint64_t ksk_id; + uint64_t bsk_id; }; struct OpaqueOutputData { @@ -200,8 +211,9 @@ struct OpaqueOutputData { template void load(Archive &ar, const unsigned int version) { ar >> output_sizes >> output_types; for (size_t p = 0; p < output_sizes.size(); ++p) { - char *output = new char[output_sizes[p]]; - new_allocated.push_back((void *)output); + char *output; + _dfr_checked_aligned_alloc((void **)&output, 64, (output_sizes[p])); + ar >> hpx::serialization::make_array(output, output_sizes[p]); outputs.push_back((void *)output); @@ -217,8 +229,8 @@ struct OpaqueOutputData { for (size_t r = 0; r < rank; ++r) size *= mref.sizes[r]; size_t alloc_size = (size + mref.offset) * elementSize; - char *data = new char[alloc_size]; - new_allocated.push_back((void *)data); + char *data; + _dfr_checked_aligned_alloc((void **)&data, 512, alloc_size); ar >> 
hpx::serialization::make_array(data + mref.offset * elementSize, size * elementSize); static_cast *>(outputs[p])->basePtr = @@ -283,23 +295,20 @@ struct GenericComputeServer : component_base { inputs.wfn_name); std::vector outputs; - if (!_dfr_is_root_node()) { - for (size_t p = 0; p < inputs.params.size(); ++p) { - if (_dfr_get_arg_type(inputs.param_types[p] == _DFR_TASK_ARG_CONTEXT)) { - mlir::concretelang::RuntimeContext *ctx = - (*(mlir::concretelang::RuntimeContext **)inputs.params[p]); - ctx->ksk = _dfr_get_ksk(inputs.source_locality, (uint64_t)ctx->ksk); - ctx->bsk = _dfr_get_bsk(inputs.source_locality, (uint64_t)ctx->bsk); - break; - } - } + if (inputs.source_locality != hpx::find_here() && + (inputs.ksk_id || inputs.bsk_id)) { + _dfr_node_level_runtime_context_manager->getContext( + inputs.ksk_id, inputs.bsk_id, inputs.source_locality); } + _dfr_debug_print_task(inputs.wfn_name.c_str(), inputs.params.size(), inputs.output_sizes.size()); hpx::cout << std::flush; + switch (inputs.output_sizes.size()) { case 1: { - void *output = (void *)(new char[inputs.output_sizes[0]]); + void *output; + _dfr_checked_aligned_alloc(&output, 512, inputs.output_sizes[0]); switch (inputs.params.size()) { case 0: wfn(output); @@ -387,18 +396,52 @@ struct GenericComputeServer : component_base { inputs.params[12], inputs.params[13], inputs.params[14], inputs.params[15], output); break; + case 17: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], output); + break; + case 18: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + 
inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], output); + break; + case 19: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], output); + break; + case 20: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], inputs.params[19], output); + break; default: HPX_THROW_EXCEPTION(hpx::no_success, "GenericComputeServer::execute_task", "Error: number of task parameters not supported."); } outputs = {output}; - new_allocated.push_back(output); break; } case 2: { - void *output1 = (void *)(new char[inputs.output_sizes[0]]); - void *output2 = (void *)(new char[inputs.output_sizes[1]]); + void *output1, *output2; + _dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]); + _dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]); switch (inputs.params.size()) { case 0: wfn(output1, output2); @@ -491,20 +534,54 @@ struct GenericComputeServer : component_base { inputs.params[12], inputs.params[13], inputs.params[14], inputs.params[15], output1, output2); break; + case 17: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + 
inputs.params[15], inputs.params[16], output1, output2); + break; + case 18: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], output1, + output2); + break; + case 19: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], output1, output2); + break; + case 20: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], inputs.params[19], output1, output2); + break; default: HPX_THROW_EXCEPTION(hpx::no_success, "GenericComputeServer::execute_task", "Error: number of task parameters not supported."); } outputs = {output1, output2}; - new_allocated.push_back(output1); - new_allocated.push_back(output2); break; } case 3: { - void *output1 = (void *)(new char[inputs.output_sizes[0]]); - void *output2 = (void *)(new char[inputs.output_sizes[1]]); - void *output3 = (void *)(new char[inputs.output_sizes[2]]); + void *output1, *output2, *output3; + _dfr_checked_aligned_alloc(&output1, 512, inputs.output_sizes[0]); + _dfr_checked_aligned_alloc(&output2, 512, inputs.output_sizes[1]); + _dfr_checked_aligned_alloc(&output3, 512, inputs.output_sizes[2]); switch 
(inputs.params.size()) { case 0: wfn(output1, output2, output3); @@ -597,15 +674,47 @@ struct GenericComputeServer : component_base { inputs.params[12], inputs.params[13], inputs.params[14], inputs.params[15], output1, output2, output3); break; + case 17: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], output1, output2, output3); + break; + case 18: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], output1, + output2, output3); + break; + case 19: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], output1, output2, output3); + break; + case 20: + wfn(inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5], + inputs.params[6], inputs.params[7], inputs.params[8], + inputs.params[9], inputs.params[10], inputs.params[11], + inputs.params[12], inputs.params[13], inputs.params[14], + inputs.params[15], inputs.params[16], inputs.params[17], + inputs.params[18], inputs.params[19], output1, output2, output3); + break; default: HPX_THROW_EXCEPTION(hpx::no_success, "GenericComputeServer::execute_task", "Error: number of task parameters not 
supported."); } outputs = {output1, output2, output3}; - new_allocated.push_back(output1); - new_allocated.push_back(output2); - new_allocated.push_back(output3); break; } default: @@ -613,6 +722,18 @@ struct GenericComputeServer : component_base { "Error: number of task outputs not supported."); } + // Deallocate input data buffers from OID deserialization (load); they come from posix_memalign(), hence free() + if (!_dfr_is_root_node()) { + for (size_t p = 0; p < inputs.param_sizes.size(); ++p) { + if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) { + if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF) + free(static_cast *>(inputs.params[p]) + ->data); + free(inputs.params[p]); + } + } + } + return OpaqueOutputData(std::move(outputs), std::move(inputs.output_sizes), std::move(inputs.output_types), inputs.alloc_p); } diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index 86d4942fb..660b60fd9 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -15,6 +15,7 @@ #include #include "concretelang/Runtime/DFRuntime.hpp" +#include "concretelang/Runtime/context.h" extern "C" { #include "concrete-ffi.h" @@ -25,13 +26,12 @@ namespace concretelang { namespace dfr { template struct KeyManager; +struct RuntimeContextManager; namespace { static void *dl_handle; static KeyManager *_dfr_node_level_bsk_manager; static KeyManager *_dfr_node_level_ksk_manager; -static std::list new_allocated; -static std::list fut_allocated; -static std::list m_allocated; +static RuntimeContextManager *_dfr_node_level_runtime_context_manager; } // namespace void _dfr_register_bsk(LweBootstrapKey_u64 *key, uint64_t key_id); @@ -66,7 +66,6 @@ void KeyWrapper::load(Archive &ar, size_t length; ar >> length; uint8_t *pointer = new uint8_t[length]; - new_allocated.push_back((void *)pointer); ar >> hpx::serialization::make_array(pointer, length); 
BufferView buffer = {(const uint8_t *)pointer, length}; key = deserialize_lwe_bootstrap_key_u64(buffer); @@ -87,7 +86,6 @@ void KeyWrapper::load(Archive &ar, size_t length; ar >> length; uint8_t *pointer = new uint8_t[length]; - new_allocated.push_back((void *)pointer); ar >> hpx::serialization::make_array(pointer, length); BufferView buffer = {(const uint8_t *)pointer, length}; key = deserialize_lwe_keyswitching_key_u64(buffer); @@ -224,6 +222,68 @@ LweKeyswitchKey_u64 *_dfr_get_ksk(hpx::naming::id_type loc, uint64_t key_id) { return _dfr_node_level_ksk_manager->get_key(loc, key_id); } +/************************/ +/* Context management. */ +/************************/ + +struct RuntimeContextManager { + // TODO: this is only ok so long as we don't change keys. Once we + // use multiple keys, should have a map. + RuntimeContext *context; + std::mutex context_guard; + uint64_t ksk_id; + uint64_t bsk_id; + + RuntimeContextManager() { + ksk_id = 0; + bsk_id = 0; + context = nullptr; + _dfr_node_level_runtime_context_manager = this; + } + + RuntimeContext *getContext(uint64_t ksk, uint64_t bsk, + hpx::naming::id_type source_locality) { + std::cout << "GetContext on node " << hpx::get_locality_id() + << " with context " << context << " " << bsk_id << " " << ksk_id + << "\n" + << std::flush; + if (context != nullptr) { + std::cout << "simil " << ksk_id << " " << ksk << " " << bsk_id << " " + << bsk << "\n" + << std::flush; + assert(ksk == ksk_id && bsk == bsk_id && + "Context manager can only used with single keys for now."); + } else { + assert(ksk_id == 0 && bsk_id == 0 && + "Context empty but context manager has key ids."); + LweKeyswitchKey_u64 *keySwitchKey = _dfr_get_ksk(source_locality, ksk); + LweBootstrapKey_u64 *bootstrapKey = _dfr_get_bsk(source_locality, bsk); + std::lock_guard guard(context_guard); + if (context == nullptr) { + auto ctx = new RuntimeContext(); + ctx->evaluationKeys = ::concretelang::clientlib::EvaluationKeys( + 
std::shared_ptr<::concretelang::clientlib::LweKeyswitchKey>( + new ::concretelang::clientlib::LweKeyswitchKey(keySwitchKey)), + std::shared_ptr<::concretelang::clientlib::LweBootstrapKey>( + new ::concretelang::clientlib::LweBootstrapKey(bootstrapKey))); + ksk_id = ksk; + bsk_id = bsk; + context = ctx; + std::cout << "Fetching Key ids " << ksk_id << " " << bsk_id << "\n" + << std::flush; + } else { + std::cout << " GOT context after LOCK on node " + << hpx::get_locality_id() << " with context " << context + << " " << bsk_id << " " << ksk_id << "\n" + << std::flush; + } + } + return context; + } + + RuntimeContext **getContextAddress() { return &context; } +}; + } // namespace dfr } // namespace concretelang } // namespace mlir diff --git a/compiler/include/concretelang/Runtime/runtime_api.h b/compiler/include/concretelang/Runtime/runtime_api.h index 4025f9aa4..7e19bb2d4 100644 --- a/compiler/include/concretelang/Runtime/runtime_api.h +++ b/compiler/include/concretelang/Runtime/runtime_api.h @@ -14,7 +14,7 @@ extern "C" { typedef void (*wfnptr)(...); -void *_dfr_make_ready_future(void *); +void *_dfr_make_ready_future(void *, size_t); void _dfr_create_async_task(wfnptr, size_t, size_t, ...); void _dfr_register_work_function(wfnptr); void *_dfr_await_future(void *); diff --git a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp index bc0787e9a..caf2b887b 100644 --- a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -57,7 +57,7 @@ static bool isCandidateForTask(Operation *op) { /// operations must not have side-effects and not be `isCandidateForTask` static bool isSinkingBeneficiary(Operation *op) { return isa(op); + mlir::arith::CmpIOp, mlir::memref::GetGlobalOp>(op); } static bool @@ -90,6 +90,92 @@ extractBeneficiaryOps(Operation *op, SetVector existingDependencies, return true; } +static func::FuncOp 
getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast(); + if (!sym) + return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + +static void getAliasedUses(Value val, DenseSet &aliasedUses) { + for (auto &use : val.getUses()) { + aliasedUses.insert(&use); + if (isa(use.getOwner())) + getAliasedUses(use.getOwner()->getResult(0), aliasedUses); + } +} + +static bool extractOutputMemrefAllocations( + Operation *op, SetVector existingDependencies, + SetVector &beneficiaryOps, + llvm::SmallPtrSetImpl &availableValues, RT::DataflowTaskOp taskOp) { + if (beneficiaryOps.count(op)) + return true; + + if (!isa(op)) + return false; + + Value val = op->getResults().front(); + DenseSet aliasedUses; + getAliasedUses(val, aliasedUses); + + // Helper function checking if a memref use writes to memory + auto hasMemoryWriteEffect = [&](OpOperand *use) { + // Call ops targeting concrete-ffi do not have memory effects + // interface, so handle apart. + // TODO: this could be handled better in BConcrete or higher. 
+ if (isa(use->getOwner())) { + if (getCalledFunction(use->getOwner()).getName() == + "memref_expand_lut_in_trivial_glwe_ct_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_add_lwe_ciphertexts_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_add_plaintext_lwe_ciphertext_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_mul_cleartext_lwe_ciphertext_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_negate_lwe_ciphertext_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_keyswitch_lwe_u64" || + getCalledFunction(use->getOwner()).getName() == + "memref_bootstrap_lwe_u64") + if (use->getOwner()->getOperand(0) == use->get()) + return true; + + if (getCalledFunction(use->getOwner()).getName() == + "memref_copy_one_rank") + if (use->getOwner()->getOperand(1) == use->get()) + return true; + } + + // Otherwise we rely on the memory effect interface + auto effectInterface = dyn_cast(use->getOwner()); + if (!effectInterface) + return false; + SmallVector effects; + effectInterface.getEffects(effects); + for (auto eff : effects) + if (isa(eff.getEffect()) && + eff.getValue() == use->get()) + return true; + return false; + }; + + // We need to check if this allocated memref is written in this task. + // TODO: for now we'll assume that we don't do partial writes or read/writes. + for (auto use : aliasedUses) + if (hasMemoryWriteEffect(use) && + use->getOwner()->getParentOfType() == taskOp) { + // We will sink the operation, mark its results as now available. 
+ beneficiaryOps.insert(op); + for (Value result : op->getResults()) + availableValues.insert(result); + return true; + } + return false; +} + LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { Region &taskOpBody = taskOp.body(); @@ -104,6 +190,8 @@ LogicalResult sinkOperationsIntoDFTask(RT::DataflowTaskOp taskOp) { if (!operandOp) continue; extractBeneficiaryOps(operandOp, sinkCandidates, toBeSunk, availableValues); + extractOutputMemrefAllocations(operandOp, sinkCandidates, toBeSunk, + availableValues, taskOp); } // Insert operations so that the defs get cloned before uses. diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index 1886f18a4..e47951cd3 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -117,7 +118,8 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement, Region &region) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { - if (isa(use.getOwner()) && + if ((isa(use.getOwner()) || + isa(use.getOwner())) && region.isAncestor(use.getOwner()->getParentRegion())) use.set(replacement); } @@ -183,10 +185,7 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) { // bytes until we can get the actual size of the actual types. 
if (type.isa() || type.isa() || - type.isa() || - type.isa() || - type.isa() || - type.isa()) { + type.isa()) { Value result = builder.create(loc, builder.getI64IntegerAttr(8)); return std::pair(result, arg_type); @@ -204,38 +203,69 @@ getSizeInBytes(Value val, Location loc, OpBuilder builder) { return std::pair(result, arg_type); } +static void getAliasedUses(Value val, DenseSet &aliasedUses) { + for (auto &use : val.getUses()) { + aliasedUses.insert(&use); + if (isa(use.getOwner())) + getAliasedUses(use.getOwner()->getResult(0), aliasedUses); + } +} + static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, func::FuncOp workFunction) { DataLayout dataLayout = DataLayout::closest(DFTOp); Region &opBody = DFTOp->getParentOfType().getBody(); - BlockAndValueMapping map; OpBuilder builder(DFTOp); // First identify DFT operands that are not futures and are not // defined by another DFT. These need to be made into futures and // propagated to all other DFTs. We can allow PRE to eliminate the // previous definitions if there are no non-future type uses. - builder.setInsertionPoint(DFTOp); for (Value val : DFTOp.getOperands()) { if (!val.getType().isa()) { - Value newval; + OpBuilder::InsertionGuard guard(builder); + Type futType = RT::FutureType::get(val.getType()); + Value memrefCloned, newval = val; + + // Find out if this value is needed in any other task + SmallVector taskOps; + for (auto &use : val.getUses()) + if (isa(use.getOwner())) + taskOps.push_back(use.getOwner()); + Operation *first = DFTOp; + for (auto op : taskOps) + if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first)) + first = op; + builder.setInsertionPoint(first); + // If we are building a future on a MemRef, then we need to clone // the memref in order to allow the deallocation pass which does // not synchronize with task execution. 
- if (val.getType().isa() && !val.isa()) { - newval = builder - .create( - DFTOp.getLoc(), val.getType().cast()) + if (val.getType().isa()) { + // Get the type of memref that we will clone. In case this is + // a subview, we discard the mapping so we copy to a contiguous + // layout which pre-serializes this. + MemRefType mrType = val.getType().dyn_cast(); + if (!mrType.getLayout().isIdentity()) { + unsigned rank = mrType.getRank(); + mrType = MemRefType::Builder(mrType) + .setShape(mrType.getShape()) + .setLayout(AffineMapAttr::get( + builder.getMultiDimIdentityMap(rank))); + } + newval = builder.create(val.getLoc(), mrType) .getResult(); - builder.create(DFTOp.getLoc(), val, newval); + builder.create(val.getLoc(), val, newval); + memrefCloned = builder.create( + val.getLoc(), builder.getI64IntegerAttr(1)); } else { - newval = val; + memrefCloned = builder.create( + val.getLoc(), builder.getI64IntegerAttr(0)); } - Type futType = RT::FutureType::get(newval.getType()); - auto mrf = builder.create(DFTOp.getLoc(), futType, - newval); - map.map(mrf->getResult(0), val); - replaceAllUsesInDFTsInRegionWith(val, mrf->getResult(0), opBody); + + auto mrf = builder.create(val.getLoc(), futType, + newval, memrefCloned); + replaceAllUsesInDFTsInRegionWith(val, mrf, opBody); } } @@ -268,6 +298,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, // unsupported even in the LLVMIR Dialect - this needs to use two // placeholders for each output, before and after the // CreateAsyncTaskOp. 
+ BlockAndValueMapping map; for (auto result : DFTOp.getResults()) { Type futType = RT::PointerType::get(RT::FutureType::get(result.getType())); auto brpp = builder.create(DFTOp.getLoc(), @@ -297,6 +328,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, for (auto &use : llvm::make_early_inc_range(result.getUses())) { if (!isa(use.getOwner()) && + !isa(use.getOwner()) && use.getOwner()->getParentOfType() == nullptr) { // Wait for this future before its uses OpBuilder::InsertionGuard guard(builder); @@ -315,24 +347,35 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, DFTOp.erase(); } -static void registerWorkFunction(FuncOp parentFunc, FuncOp workFunction) { - OpBuilder builder(parentFunc.body()); - builder.setInsertionPointToStart(&parentFunc.body().front()); +static void registerWorkFunction(mlir::func::FuncOp parentFunc, + mlir::func::FuncOp workFunction) { + OpBuilder builder(parentFunc.getBody()); + builder.setInsertionPointToStart(&parentFunc.getBody().front()); - auto fnptr = builder.create( - parentFunc.getLoc(), workFunction.getType(), + auto fnptr = builder.create( + parentFunc.getLoc(), workFunction.getFunctionType(), SymbolRefAttr::get(builder.getContext(), workFunction.getName())); builder.create(parentFunc.getLoc(), fnptr.getResult()); } +static func::FuncOp getCalledFunction(CallOpInterface callOp) { + SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast(); + if (!sym) + return nullptr; + return dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(callOp, sym)); +} + /// For documentation see Autopar.td struct LowerDataflowTasksPass : public LowerDataflowTasksBase { void runOnOperation() override { auto module = getOperation(); + SmallVector workFunctions; + SmallVector entryPoints; module.walk([&](mlir::func::FuncOp func) { static int wfn_id = 0; @@ -342,7 +385,7 @@ struct LowerDataflowTasksPass return; SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func); - std::vector> outliningMap; + SmallVector, 
4> outliningMap; func.walk([&](RT::DataflowTaskOp op) { auto workFunctionName = @@ -353,6 +396,7 @@ struct LowerDataflowTasksPass outlineWorkFunction(op, workFunctionName.str()); outliningMap.push_back( std::pair(op, outlinedFunc)); + workFunctions.push_back(outlinedFunc); symbolTable.insert(outlinedFunc); return WalkResult::advance(); }); @@ -361,73 +405,72 @@ struct LowerDataflowTasksPass for (auto mapping : outliningMap) lowerDataflowTaskOp(mapping.first, mapping.second); + // Main is always an entry-point - otherwise check if this + // function is called within the module. TODO: we assume no + // recursion. + if (func.getName() == "main") + entryPoints.push_back(func); + else { + bool found = false; + module.walk([&](mlir::func::CallOp op) { + if (getCalledFunction(op) == func) + found = true; + }); + if (!found) + entryPoints.push_back(func); + } + }); + + for (auto entryPoint : entryPoints) { // If this is a JIT invocation and we're not on the root node, // we do not need to do any computation, only register all work // functions with the runtime system - if (!outliningMap.empty()) { + if (!workFunctions.empty()) { if (!dfr::_dfr_is_root_node()) { - // auto regFunc = builder.create(func.getLoc(), - // func.getName(), func.getType()); - - func.eraseBody(); + entryPoint.eraseBody(); Block *b = new Block; - b->addArguments(func.getType().getInputs()); - func.body().push_front(b); - for (int i = func.getType().getNumInputs() - 1; i >= 0; --i) - func.eraseArgument(i); - for (int i = func.getType().getNumResults() - 1; i >= 0; --i) - func.eraseResult(i); - OpBuilder builder(func.body()); - builder.setInsertionPointToEnd(&func.body().front()); - builder.create(func.getLoc()); + SmallVector locations; + for (auto input : entryPoint.getFunctionType().getInputs()) + locations.push_back(entryPoint.getLoc()); + b->addArguments(entryPoint.getFunctionType().getInputs(), locations); + entryPoint.getBody().push_front(b); + for (int i = 
entryPoint.getFunctionType().getNumInputs() - 1; i >= 0; + --i) + entryPoint.eraseArgument(i); + for (int i = entryPoint.getFunctionType().getNumResults() - 1; i >= 0; + --i) + entryPoint.eraseResult(i); + OpBuilder builder(entryPoint.getBody()); + builder.setInsertionPointToEnd(&entryPoint.getBody().front()); + builder.create(entryPoint.getLoc()); } } // Generate code to register all work-functions with the // runtime. - for (auto mapping : outliningMap) - registerWorkFunction(func, mapping.second); + for (auto wf : workFunctions) + registerWorkFunction(entryPoint, wf); // Issue _dfr_start/stop calls for this function - if (!outliningMap.empty()) { - OpBuilder builder(func.getBody()); - builder.setInsertionPointToStart(&func.getBody().front()); + if (!workFunctions.empty()) { + OpBuilder builder(entryPoint.getBody()); + builder.setInsertionPointToStart(&entryPoint.getBody().front()); auto dfrStartFunOp = mlir::LLVM::lookupOrCreateFn( - func->getParentOfType(), "_dfr_start", {}, - LLVM::LLVMVoidType::get(func->getContext())); - builder.create(func.getLoc(), dfrStartFunOp, + module, "_dfr_start", {}, + LLVM::LLVMVoidType::get(entryPoint->getContext())); + builder.create(entryPoint.getLoc(), dfrStartFunOp, mlir::ValueRange(), ArrayRef()); - builder.setInsertionPoint(func.getBody().back().getTerminator()); + builder.setInsertionPoint(entryPoint.getBody().back().getTerminator()); auto dfrStopFunOp = mlir::LLVM::lookupOrCreateFn( - func->getParentOfType(), "_dfr_stop", {}, - LLVM::LLVMVoidType::get(func->getContext())); - builder.create(func.getLoc(), dfrStopFunOp, + module, "_dfr_stop", {}, + LLVM::LLVMVoidType::get(entryPoint->getContext())); + builder.create(entryPoint.getLoc(), dfrStopFunOp, mlir::ValueRange(), ArrayRef()); } - }); - - // Delay memref deallocations when memrefs are made into futures - module.walk([&](Operation *op) { - if (isa(*op) && - op->getOperand(0).getType().isa()) { - for (auto &use : - llvm::make_early_inc_range(op->getOperand(0).getUses())) 
{ - if (isa(use.getOwner())) { - OpBuilder builder(use.getOwner() - ->getParentOfType() - .getBody() - .back() - .getTerminator()); - builder.clone(*use.getOwner()); - use.getOwner()->erase(); - } - } - } - return WalkResult::advance(); - }); + } } LowerDataflowTasksPass(bool debug) : debug(debug){}; @@ -440,5 +483,44 @@ std::unique_ptr createLowerDataflowTasksPass(bool debug) { return std::make_unique(debug); } +namespace { + +// For documentation see Autopar.td +struct FixupBufferDeallocationPass + : public FixupBufferDeallocationBase { + + void runOnOperation() override { + auto module = getOperation(); + std::vector ops; + + // All buffers allocated and either made into a future, directly + // or as a result of being returned by a task, are managed by the + // DFR runtime system's reference counting. + module.walk([&](RT::WorkFunctionReturnOp retOp) { + for (auto &use : + llvm::make_early_inc_range(retOp.getOperands().front().getUses())) + if (isa(use.getOwner())) + ops.push_back(use.getOwner()); + }); + module.walk([&](RT::MakeReadyFutureOp mrfOp) { + for (auto &use : + llvm::make_early_inc_range(mrfOp.getOperands().front().getUses())) + if (isa(use.getOwner())) + ops.push_back(use.getOwner()); + }); + for (auto op : ops) + op->erase(); + } + FixupBufferDeallocationPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // end anonymous namespace + +std::unique_ptr createFixupBufferDeallocationPass(bool debug) { + return std::make_unique(debug); +} + } // end namespace concretelang } // end namespace mlir diff --git a/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp index 816678de6..5f235ef9c 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerRTToLLVMDFRCallsConversionPatterns.cpp @@ -125,8 +125,9 @@ struct MakeReadyFutureOpInterfaceLowering results[0]); 
rewriter.create(mrfOp.getLoc(), adaptor.getOperands().front(), allocatedPtr); - rewriter.replaceOpWithNewOp(mrfOp, mrfFuncOp, allocatedPtr); - + SmallVector mrfOperands = {adaptor.getOperands()}; + mrfOperands[0] = allocatedPtr; + rewriter.replaceOpWithNewOp(mrfOp, mrfFuncOp, mrfOperands); return mlir::success(); } }; @@ -178,16 +179,14 @@ struct RegisterTaskWorkFunctionOpInterfaceLowering mlir::LogicalResult matchAndRewrite(RT::RegisterTaskWorkFunctionOp rtwfOp, - ArrayRef operands, + RT::RegisterTaskWorkFunctionOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RT::RegisterTaskWorkFunctionOp::Adaptor transformed(operands); - auto rtwfFuncType = LLVM::LLVMFunctionType::get(getVoidType(), {}, /*isVariadic=*/true); auto rtwfFuncOp = getOrInsertFuncOpDecl( rtwfOp, "_dfr_register_work_function", rtwfFuncType, rewriter); rewriter.replaceOpWithNewOp(rtwfOp, rtwfFuncOp, - transformed.getOperands()); + adaptor.getOperands()); return success(); } }; diff --git a/compiler/lib/Dialect/RT/IR/RTOps.cpp b/compiler/lib/Dialect/RT/IR/RTOps.cpp index 215a635b4..d4ef30634 100644 --- a/compiler/lib/Dialect/RT/IR/RTOps.cpp +++ b/compiler/lib/Dialect/RT/IR/RTOps.cpp @@ -9,6 +9,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Region.h" #include "mlir/IR/TypeUtilities.h" +#include "llvm/ADT/STLExtras.h" #include "concretelang/Dialect/RT/IR/RTOps.h" #include "concretelang/Dialect/RT/IR/RTTypes.h" @@ -33,3 +34,66 @@ void DataflowTaskOp::build( void DataflowTaskOp::getSuccessorRegions( Optional index, ArrayRef operands, SmallVectorImpl ®ions) {} + +llvm::Optional +DataflowTaskOp::buildDealloc(OpBuilder &builder, Value alloc) { + return builder.create(alloc.getLoc(), alloc) + .getOperation(); +} +llvm::Optional DataflowTaskOp::buildClone(OpBuilder &builder, + Value alloc) { + return builder.create(alloc.getLoc(), alloc).getResult(); +} +void DataflowTaskOp::getEffects( + SmallVectorImpl> + &effects) { + for (auto input : inputs()) + 
effects.emplace_back(MemoryEffects::Read::get(), input, + SideEffects::DefaultResource::get()); + for (auto output : outputs()) + effects.emplace_back(MemoryEffects::Write::get(), output, + SideEffects::DefaultResource::get()); + for (auto output : outputs()) + effects.emplace_back(MemoryEffects::Allocate::get(), output, + SideEffects::DefaultResource::get()); +} + +llvm::Optional +CloneFutureOp::buildDealloc(OpBuilder &builder, Value alloc) { + return builder.create(alloc.getLoc(), alloc) + .getOperation(); +} +llvm::Optional CloneFutureOp::buildClone(OpBuilder &builder, + Value alloc) { + return builder.create(alloc.getLoc(), alloc).getResult(); +} +void CloneFutureOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Allocate::get(), output(), + SideEffects::DefaultResource::get()); +} + +llvm::Optional +MakeReadyFutureOp::buildDealloc(OpBuilder &builder, Value alloc) { + return builder.create(alloc.getLoc(), alloc) + .getOperation(); +} +llvm::Optional MakeReadyFutureOp::buildClone(OpBuilder &builder, + Value alloc) { + return builder.create(alloc.getLoc(), alloc).getResult(); +} +void MakeReadyFutureOp::getEffects( + SmallVectorImpl> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), input(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), output(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Allocate::get(), output(), + SideEffects::DefaultResource::get()); +} diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index 56c6035c3..0a22d6696 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -38,33 +38,50 @@ static size_t num_nodes = 0; using namespace hpx; -void 
*_dfr_make_ready_future(void *in) { - void *future = static_cast( - new hpx::shared_future(hpx::make_ready_future(in))); - mlir::concretelang::dfr::m_allocated.push_back(in); - mlir::concretelang::dfr::fut_allocated.push_back(future); - return future; +typedef struct dfr_refcounted_future { + hpx::shared_future *future; + std::atomic count; + bool cloned_memref_p; + dfr_refcounted_future(hpx::shared_future *f, size_t c, bool clone_p) + : future(f), count(c), cloned_memref_p(clone_p) {} +} dfr_refcounted_future_t, *dfr_refcounted_future_p; + +// Ready futures are only used as inputs to tasks (never passed to +// await_future), so we only need to track the references in task +// creation. +void *_dfr_make_ready_future(void *in, size_t memref_clone_p) { + return (void *)new dfr_refcounted_future_t( + new hpx::shared_future(hpx::make_ready_future(in)), 1, + memref_clone_p); } void *_dfr_await_future(void *in) { - return static_cast *>(in)->get(); -} - -void _dfr_deallocate_future_data(void *in) { - delete[] static_cast( - static_cast *>(in)->get()); + return static_cast(in)->future->get(); } void _dfr_deallocate_future(void *in) { - delete (static_cast *>(in)); + auto drf = static_cast(in); + size_t prev_count = drf->count.fetch_sub(1); + if (prev_count == 1) { + // If this was a memref for which a clone was needed, deallocate first. + if (drf->cloned_memref_p) + free( + (void *)(static_cast *>(drf->future->get()) + ->data)); + free(drf->future->get()); + delete (drf->future); + delete drf; + } } +void _dfr_deallocate_future_data(void *in) {} + // Determine where new task should run. For now just round-robin // distribution - TODO: optimise. 
static inline size_t _dfr_find_next_execution_locality() { - static std::atomic next_locality{0}; + static std::atomic next_locality{1}; - size_t next_loc = ++next_locality; + size_t next_loc = next_locality.fetch_add(1); return next_loc % mlir::concretelang::dfr::num_nodes; } @@ -76,7 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() { /// the returns. void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, ...) { - std::vector params; + // std::vector params; + std::vector refcounted_futures; std::vector param_sizes; std::vector param_types; std::vector outputs; @@ -86,7 +104,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, va_list args; va_start(args, num_outputs); for (size_t i = 0; i < num_params; ++i) { - params.push_back(va_arg(args, void *)); + refcounted_futures.push_back(va_arg(args, void *)); param_sizes.push_back(va_arg(args, uint64_t)); param_types.push_back(va_arg(args, uint64_t)); } @@ -97,13 +115,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, } va_end(args); - for (size_t i = 0; i < num_params; ++i) { - if (mlir::concretelang::dfr::_dfr_get_arg_type( - param_types[i] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF)) { - mlir::concretelang::dfr::m_allocated.push_back( - (void *)static_cast *>(params[i])->data); - } - } + // Take a reference on each future argument + for (auto rcf : refcounted_futures) + ((dfr_refcounted_future_p)rcf)->count.fetch_add(1); // We pass functions by name - which is not strictly necessary in // shared memory as pointers suffice, but is needed in the @@ -147,7 +161,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future)); break; case 2: @@ -164,8 +178,8 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t 
num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future)); break; case 3: @@ -184,9 +198,9 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future)); break; case 4: @@ -206,10 +220,10 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future)); break; case 5: @@ -231,11 +245,11 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future)); break; case 6: @@ 
-258,12 +272,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future)); break; case 7: @@ -287,13 +301,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future)); break; case 8: @@ -318,14 +332,14 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - 
*(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future)); break; case 9: @@ -352,15 +366,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future)); break; case 10: @@ -388,16 +402,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future 
*)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future)); break; case 11: @@ -426,17 +440,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[10])->future)); break; case 12: @@ -466,18 +480,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10], - *(hpx::shared_future *)params[11])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future)); break; case 13: @@ -509,19 +523,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10], - 
*(hpx::shared_future *)params[11], - *(hpx::shared_future *)params[12])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future)); break; case 14: @@ -554,20 +568,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10], - *(hpx::shared_future *)params[11], - *(hpx::shared_future *)params[12], - *(hpx::shared_future *)params[13])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future)); break; case 15: @@ -601,21 +615,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10], - *(hpx::shared_future *)params[11], - *(hpx::shared_future *)params[12], - *(hpx::shared_future *)params[13], - *(hpx::shared_future *)params[14])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future)); break; case 16: @@ -650,22 +664,246 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, [_dfr_find_next_execution_locality()] .execute_task(oid); }, - *(hpx::shared_future *)params[0], - *(hpx::shared_future *)params[1], - *(hpx::shared_future *)params[2], - *(hpx::shared_future *)params[3], - *(hpx::shared_future *)params[4], - *(hpx::shared_future *)params[5], - *(hpx::shared_future *)params[6], - *(hpx::shared_future *)params[7], - *(hpx::shared_future *)params[8], - *(hpx::shared_future *)params[9], - *(hpx::shared_future *)params[10], - *(hpx::shared_future *)params[11], - *(hpx::shared_future *)params[12], - *(hpx::shared_future *)params[13], - *(hpx::shared_future *)params[14], - *(hpx::shared_future *)params[15])); + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future, + *((dfr_refcounted_future_p)refcounted_futures[15])->future)); + break; + + case 17: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, param_types, output_sizes, + 
output_types](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16) + -> hpx::future { + std::vector params = { + param0.get(), param1.get(), param2.get(), param3.get(), + param4.get(), param5.get(), param6.get(), param7.get(), + param8.get(), param9.get(), param10.get(), param11.get(), + param12.get(), param13.get(), param14.get(), param15.get(), + param16.get()}; + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); + }, + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[15])->future, + *((dfr_refcounted_future_p)refcounted_futures[16])->future)); + break; + + case 18: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17) + -> hpx::future { + std::vector params = { + param0.get(), param1.get(), param2.get(), param3.get(), + param4.get(), param5.get(), param6.get(), param7.get(), + param8.get(), param9.get(), param10.get(), param11.get(), + param12.get(), param13.get(), param14.get(), param15.get(), + param16.get(), param17.get()}; + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); + }, + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future, + *((dfr_refcounted_future_p)refcounted_futures[15])->future, + *((dfr_refcounted_future_p)refcounted_futures[16])->future, + *((dfr_refcounted_future_p)refcounted_futures[17])->future)); + break; + + case 19: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17, + hpx::shared_future param18) + -> hpx::future { + std::vector params = { + param0.get(), param1.get(), param2.get(), param3.get(), + param4.get(), param5.get(), param6.get(), param7.get(), + param8.get(), param9.get(), param10.get(), param11.get(), + param12.get(), param13.get(), param14.get(), param15.get(), + param16.get(), param17.get(), param18.get()}; + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); + }, + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + 
*((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future, + *((dfr_refcounted_future_p)refcounted_futures[15])->future, + *((dfr_refcounted_future_p)refcounted_futures[16])->future, + *((dfr_refcounted_future_p)refcounted_futures[17])->future, + *((dfr_refcounted_future_p)refcounted_futures[18])->future)); + break; + + case 20: + oodf = std::move(hpx::dataflow( + [wfnname, param_sizes, param_types, output_sizes, + output_types](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17, + hpx::shared_future param18, + hpx::shared_future param19) + -> hpx::future { + std::vector params = { + param0.get(), param1.get(), param2.get(), param3.get(), + param4.get(), param5.get(), param6.get(), param7.get(), + param8.get(), param9.get(), param10.get(), param11.get(), + param12.get(), param13.get(), param14.get(), param15.get(), + param16.get(), param17.get(), param18.get(), param19.get()}; + mlir::concretelang::dfr::OpaqueInputData oid( + wfnname, params, param_sizes, 
param_types, output_sizes, + output_types); + return mlir::concretelang::dfr::gcc + [_dfr_find_next_execution_locality()] + .execute_task(oid); + }, + *((dfr_refcounted_future_p)refcounted_futures[0])->future, + *((dfr_refcounted_future_p)refcounted_futures[1])->future, + *((dfr_refcounted_future_p)refcounted_futures[2])->future, + *((dfr_refcounted_future_p)refcounted_futures[3])->future, + *((dfr_refcounted_future_p)refcounted_futures[4])->future, + *((dfr_refcounted_future_p)refcounted_futures[5])->future, + *((dfr_refcounted_future_p)refcounted_futures[6])->future, + *((dfr_refcounted_future_p)refcounted_futures[7])->future, + *((dfr_refcounted_future_p)refcounted_futures[8])->future, + *((dfr_refcounted_future_p)refcounted_futures[9])->future, + *((dfr_refcounted_future_p)refcounted_futures[10])->future, + *((dfr_refcounted_future_p)refcounted_futures[11])->future, + *((dfr_refcounted_future_p)refcounted_futures[12])->future, + *((dfr_refcounted_future_p)refcounted_futures[13])->future, + *((dfr_refcounted_future_p)refcounted_futures[14])->future, + *((dfr_refcounted_future_p)refcounted_futures[15])->future, + *((dfr_refcounted_future_p)refcounted_futures[16])->future, + *((dfr_refcounted_future_p)refcounted_futures[17])->future, + *((dfr_refcounted_future_p)refcounted_futures[18])->future, + *((dfr_refcounted_future_p)refcounted_futures[19])->future)); break; default: @@ -675,53 +913,67 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, switch (num_outputs) { case 1: - *((void **)outputs[0]) = new hpx::shared_future(hpx::dataflow( - [](hpx::future oodf_in) - -> void * { return oodf_in.get().outputs[0]; }, - oodf)); - mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); + *((void **)outputs[0]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(hpx::dataflow( + [refcounted_futures]( + hpx::future oodf_in) + -> void * { + void *ret = oodf_in.get().outputs[0]; + for (auto rcf : refcounted_futures) 
+ _dfr_deallocate_future(rcf); + return ret; + }, + oodf)), + 1, output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); break; case 2: { hpx::future> &&ft = hpx::dataflow( - [](hpx::future oodf_in) + [refcounted_futures]( + hpx::future oodf_in) -> hpx::tuple { std::vector outputs = std::move(oodf_in.get().outputs); + for (auto rcf : refcounted_futures) + _dfr_deallocate_future(rcf); return hpx::make_tuple<>(outputs[0], outputs[1]); }, oodf); hpx::tuple, hpx::future> &&tf = hpx::split_future(std::move(ft)); - *((void **)outputs[0]) = - (void *)new hpx::shared_future(std::move(hpx::get<0>(tf))); - *((void **)outputs[1]) = - (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); - mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); - mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1])); + *((void **)outputs[0]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(std::move(hpx::get<0>(tf))), 1, + output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); + *((void **)outputs[1]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(std::move(hpx::get<1>(tf))), 1, + output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); break; } case 3: { hpx::future> &&ft = hpx::dataflow( - [](hpx::future oodf_in) + [refcounted_futures]( + hpx::future oodf_in) -> hpx::tuple { std::vector outputs = std::move(oodf_in.get().outputs); + for (auto rcf : refcounted_futures) + _dfr_deallocate_future(rcf); return hpx::make_tuple<>(outputs[0], outputs[1], outputs[2]); }, oodf); hpx::tuple, hpx::future, hpx::future> &&tf = hpx::split_future(std::move(ft)); - *((void **)outputs[0]) = - (void *)new hpx::shared_future(std::move(hpx::get<0>(tf))); - *((void **)outputs[1]) = - (void *)new hpx::shared_future(std::move(hpx::get<1>(tf))); - *((void **)outputs[2]) = - (void *)new hpx::shared_future(std::move(hpx::get<2>(tf))); - mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[0])); - 
mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[1])); - mlir::concretelang::dfr::fut_allocated.push_back(*((void **)outputs[2])); + *((void **)outputs[0]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(std::move(hpx::get<0>(tf))), 1, + output_types[0] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); + *((void **)outputs[1]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(std::move(hpx::get<1>(tf))), 1, + output_types[1] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); + *((void **)outputs[2]) = (void *)new dfr_refcounted_future_t( + new hpx::shared_future(std::move(hpx::get<2>(tf))), 1, + output_types[2] == mlir::concretelang::dfr::_DFR_TASK_ARG_MEMREF); break; } + default: HPX_THROW_EXCEPTION(hpx::no_success, "_dfr_create_async_task", "Error: number of task outputs not supported."); @@ -896,6 +1148,7 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { new mlir::concretelang::dfr::KeyManager(); new mlir::concretelang::dfr::KeyManager(); + new mlir::concretelang::dfr::RuntimeContextManager(); new mlir::concretelang::dfr::WorkFunctionRegistry(); mlir::concretelang::dfr::_dfr_jit_workfunction_registration_barrier = new hpx::lcos::barrier("wait_register_remote_work_functions", @@ -985,21 +1238,8 @@ void _dfr_stop() { // safer to drop them in-between phases. 
mlir::concretelang::dfr::_dfr_node_level_bsk_manager->clear_keys(); mlir::concretelang::dfr::_dfr_node_level_ksk_manager->clear_keys(); - - while (!mlir::concretelang::dfr::new_allocated.empty()) { - delete[] static_cast( - mlir::concretelang::dfr::new_allocated.front()); - mlir::concretelang::dfr::new_allocated.pop_front(); - } - while (!mlir::concretelang::dfr::fut_allocated.empty()) { - delete static_cast *>( - mlir::concretelang::dfr::fut_allocated.front()); - mlir::concretelang::dfr::fut_allocated.pop_front(); - } - while (!mlir::concretelang::dfr::m_allocated.empty()) { - free(mlir::concretelang::dfr::m_allocated.front()); - mlir::concretelang::dfr::m_allocated.pop_front(); - } + mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager + ->clearContext(); } void _dfr_try_initialize() { diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 0797991f2..775cab7d2 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -278,6 +278,11 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass); addPotentiallyNestedPass( pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass); + // Use the buffer deallocation interface to insert future deallocation calls + addPotentiallyNestedPass( + pm, mlir::bufferization::createBufferDeallocationPass(), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass); // Convert to MLIR LLVM Dialect addPotentiallyNestedPass(