diff --git a/compiler/include/concretelang/Conversion/Utils/FuncConstOpConversion.h b/compiler/include/concretelang/Conversion/Utils/FuncConstOpConversion.h new file mode 100644 index 000000000..6538ebdab --- /dev/null +++ b/compiler/include/concretelang/Conversion/Utils/FuncConstOpConversion.h @@ -0,0 +1,66 @@ +// Part of the Concrete Compiler Project, under the BSD3 License with Zama +// Exceptions. See +// https://github.com/zama-ai/concrete-compiler-internal/blob/main/LICENSE.txt +// for license information. + +#include +#include +#include + +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +template <typename TypeConverterType> +struct FunctionConstantOpConversion + : public mlir::OpRewritePattern<mlir::func::ConstantOp> { + FunctionConstantOpConversion(mlir::MLIRContext *ctx, + TypeConverterType &converter, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern<mlir::func::ConstantOp>(ctx, benefit), + converter(converter) {} + + mlir::LogicalResult + matchAndRewrite(mlir::func::ConstantOp op, + mlir::PatternRewriter &rewriter) const override { + auto symTab = mlir::SymbolTable::getNearestSymbolTable(op); + auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue()); + assert(funcOp && + "Function symbol missing in symbol table for function constant op."); + mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp) + .getFunctionType() + .cast<mlir::FunctionType>(); + typename TypeConverterType::SignatureConversion result( + funType.getNumInputs()); + mlir::SmallVector<mlir::Type> newResults; + if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) || + failed(converter.convertTypes(funType.getResults(), newResults))) + return mlir::failure(); + auto newType = mlir::FunctionType::get( + rewriter.getContext(), result.getConvertedTypes(), newResults); + rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newType); }); + return mlir::success(); + } + + static bool isLegal(mlir::func::ConstantOp fun, + TypeConverterType &converter) { + auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun); + auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue()); + assert(funcOp && + "Function symbol missing in symbol table for function constant op."); + mlir::FunctionType funType = mlir::cast<mlir::func::FuncOp>(funcOp) + .getFunctionType() + .cast<mlir::FunctionType>(); + typename TypeConverterType::SignatureConversion result( + funType.getNumInputs()); + mlir::SmallVector<mlir::Type> newResults; + if (failed(converter.convertSignatureArgs(funType.getInputs(), result)) || + failed(converter.convertTypes(funType.getResults(), newResults))) + return false; + auto newType = mlir::FunctionType::get( + fun.getContext(), result.getConvertedTypes(), newResults); + return newType == fun.getType(); + } + +private: + TypeConverterType &converter; +}; diff --git a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.h b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.h index c6b10aeeb..e13540cb3 100644 --- a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.h +++ b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.h @@ -22,7 +22,8 @@ createBuildDataflowTaskGraphPass(bool debug = false); std::unique_ptr<mlir::Pass> createLowerDataflowTasksPass(bool debug = false); std::unique_ptr<mlir::Pass> createBufferizeDataflowTaskOpsPass(bool debug = false); -std::unique_ptr<mlir::Pass> createFixupDataflowTaskOpsPass(bool debug = false); +std::unique_ptr<mlir::Pass> createFinalizeTaskCreationPass(bool debug = false); +std::unique_ptr<mlir::Pass> createStartStopPass(bool debug = false); std::unique_ptr<mlir::Pass> createFixupBufferDeallocationPass(bool debug = false); void populateRTToLLVMConversionPatterns(mlir::LLVMTypeConverter
&converter, diff --git a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td index 7901d4d4e..adae3a2f3 100644 --- a/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td +++ b/compiler/include/concretelang/Dialect/RT/Analysis/Autopar.td @@ -53,23 +53,6 @@ def BufferizeDataflowTaskOps : Pass<"BufferizeDataflowTaskOps", "mlir::ModuleOp" }]; } -def FixupDataflowTaskOps : Pass<"FixupDataflowTaskOps", "mlir::ModuleOp"> { - let summary = - "Fix DataflowTaskOp(s) before lowering."; - - let description = [{ - This pass fixes up code changes that intervene between the - BuildDataflowTaskGraph pass and the lowering of the taskgraph to - LLVMIR and calls to the DFR runtime system. - - In particular, some operations (e.g., constants, dimension - operations, etc.) can be used within the task while only defined - outside. In most cases cloning and sinking these operations in the - task is the simplest to avoid adding dependences. - - }]; -} - def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> { let summary = "Outline the body of a DataflowTaskOp into a separate function which will serve as a task work function and lower the task graph to RT."; @@ -82,6 +65,30 @@ def LowerDataflowTasks : Pass<"LowerDataflowTasks", "mlir::ModuleOp"> { }]; } +def FinalizeTaskCreation : Pass<"FinalizeTaskCreation", "mlir::ModuleOp"> { + let summary = + "Finalize the CreateAsyncTaskOp ops."; + + let description = [{ + This pass adds the lower-level information still missing from + CreateAsyncTaskOp, in particular the type sizes and, when required, + the passing of the runtime context. + }]; +} + +def StartStop : Pass<"StartStop", "mlir::ModuleOp"> { + let summary = + "Issue calls to start/stop the runtime system."; + + let description = [{ + This pass adds calls to _dfr_start and _dfr_stop, which + respectively initialize/start and pause the runtime system. When + required, the start function also distributes the evaluation keys + to the compute nodes, and the stop function clears the + execution context.
+ }]; +} + def FixupBufferDeallocation : Pass<"FixupBufferDeallocation", "mlir::ModuleOp"> { let summary = "Prevent deallocation of buffers returned as futures by tasks."; diff --git a/compiler/include/concretelang/Runtime/DFRuntime.hpp b/compiler/include/concretelang/Runtime/DFRuntime.hpp index d6b9103cc..7a94596df 100644 --- a/compiler/include/concretelang/Runtime/DFRuntime.hpp +++ b/compiler/include/concretelang/Runtime/DFRuntime.hpp @@ -24,6 +24,7 @@ void _dfr_set_use_omp(bool); bool _dfr_is_jit(); bool _dfr_is_root_node(); bool _dfr_use_omp(); +bool _dfr_is_distributed(); typedef enum _dfr_task_arg_type { _DFR_TASK_ARG_BASE = 0, diff --git a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp index 8f02782ce..2d860adf8 100644 --- a/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp +++ b/compiler/include/concretelang/Runtime/distributed_generic_task_server.hpp @@ -69,23 +69,27 @@ struct OpaqueInputData { std::vector _param_sizes, std::vector _param_types, std::vector _output_sizes, - std::vector _output_types) + std::vector _output_types, void *_context = nullptr) : wfn_name(_wfn_name), params(std::move(_params)), param_sizes(std::move(_param_sizes)), param_types(std::move(_param_types)), output_sizes(std::move(_output_sizes)), - output_types(std::move(_output_types)) {} + output_types(std::move(_output_types)), context(_context) { + if (_context) + params.push_back(_context); + } OpaqueInputData(const OpaqueInputData &oid) : wfn_name(std::move(oid.wfn_name)), params(std::move(oid.params)), param_sizes(std::move(oid.param_sizes)), param_types(std::move(oid.param_types)), output_sizes(std::move(oid.output_sizes)), - output_types(std::move(oid.output_types)) {} + output_types(std::move(oid.output_types)), context(oid.context) {} friend class hpx::serialization::access; template void load(Archive &ar, const unsigned int version) { - ar >> wfn_name; + bool has_context; + ar >> wfn_name >> has_context; ar >> param_sizes >> param_types; ar >> output_sizes >> output_types; for (size_t p = 0; p < param_sizes.size(); ++p) { @@ -114,27 +118,22 @@ struct OpaqueInputData { static_cast *>(params[p])->basePtr = nullptr; static_cast *>(params[p])->data = data; } break; - case _DFR_TASK_ARG_CONTEXT: { - // The copied pointer is meaningless - TODO: if the context - // can change dynamically (e.g., different evaluation keys) - // then this needs updating by passing key ids and retrieving - // adequate keys for the context. - delete ((char *)params[p]); - params[p] = - (void *)_dfr_node_level_runtime_context_manager->getContext(); - } break; default: HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", "Error: invalid task argument type."); } } + if (has_context) + params.push_back( + (void *)_dfr_node_level_runtime_context_manager->getContext()); } template void save(Archive &ar, const unsigned int version) const { - ar << wfn_name; + bool has_context = (bool)(context != nullptr); + ar << wfn_name << has_context; ar << param_sizes << param_types; ar << output_sizes << output_types; - for (size_t p = 0; p < params.size(); ++p) { + for (size_t p = 0; p < param_sizes.size(); ++p) { // Save the first level of the data structure - if the parameter // is a tensor/memref, there is a second level. 
ar << hpx::serialization::make_array((char *)params[p], param_sizes[p]); @@ -152,10 +151,6 @@ struct OpaqueInputData { ar << hpx::serialization::make_array( mref.data + mref.offset * elementSize, size * elementSize); } break; - case _DFR_TASK_ARG_CONTEXT: { - // Nothing to do now - TODO: pass key ids if these are not - // unique for a computation. - } break; default: HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", "Error: invalid task argument type."); @@ -170,6 +165,7 @@ struct OpaqueInputData { std::vector param_types; std::vector output_sizes; std::vector output_types; + void *context; }; struct OpaqueOutputData { @@ -214,9 +210,6 @@ struct OpaqueOutputData { static_cast *>(outputs[p])->basePtr = nullptr; static_cast *>(outputs[p])->data = data; - } break; - case _DFR_TASK_ARG_CONTEXT: { - } break; default: HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", @@ -243,9 +236,6 @@ struct OpaqueOutputData { size *= mref.sizes[r]; ar << hpx::serialization::make_array( mref.data + mref.offset * elementSize, size * elementSize); - } break; - case _DFR_TASK_ARG_CONTEXT: { - } break; default: HPX_THROW_EXCEPTION(hpx::no_success, "DFR: OpaqueInputData save", @@ -282,121 +272,121 @@ struct GenericComputeServer : component_base { wfn(output); break; case 1: - wfn(inputs.params[0], output); + wfn(output, inputs.params[0]); break; case 2: - wfn(inputs.params[0], inputs.params[1], output); + wfn(output, inputs.params[0], inputs.params[1]); break; case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output); + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2]); break; case 4: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], output); + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3]); break; case 5: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], output); + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4]); break; case 6: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], output); + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], + inputs.params[3], inputs.params[4], inputs.params[5]); break; case 7: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], output); + inputs.params[6]); break; case 8: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], output); + inputs.params[6], inputs.params[7]); break; case 9: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], output); + inputs.params[6], inputs.params[7], inputs.params[8]); break; case 10: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], output); + inputs.params[9]); break; case 11: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + 
wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], output); + inputs.params[9], inputs.params[10]); break; case 12: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], output); + inputs.params[9], inputs.params[10], inputs.params[11]); break; case 13: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], output); + inputs.params[12]); break; case 14: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], output); + inputs.params[12], inputs.params[13]); break; case 15: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], output); + inputs.params[12], inputs.params[13], inputs.params[14]); break; case 16: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], output); + inputs.params[15]); break; case 17: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], output); + inputs.params[15], inputs.params[16]); break; case 18: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], output); + inputs.params[15], inputs.params[16], inputs.params[17]); break; case 19: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], inputs.params[12], inputs.params[13], 
inputs.params[14], inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], output); + inputs.params[18]); break; case 20: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], + wfn(output, inputs.params[0], inputs.params[1], inputs.params[2], inputs.params[3], inputs.params[4], inputs.params[5], inputs.params[6], inputs.params[7], inputs.params[8], inputs.params[9], inputs.params[10], inputs.params[11], inputs.params[12], inputs.params[13], inputs.params[14], inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], inputs.params[19], output); + inputs.params[18], inputs.params[19]); break; default: HPX_THROW_EXCEPTION(hpx::no_success, @@ -415,127 +405,127 @@ struct GenericComputeServer : component_base { wfn(output1, output2); break; case 1: - wfn(inputs.params[0], output1, output2); + wfn(output1, output2, inputs.params[0]); break; case 2: - wfn(inputs.params[0], inputs.params[1], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1]); break; case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2]); break; case 4: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3]); break; case 5: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4]); break; case 6: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5]); break; case 7: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6]); break; case 8: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7]); break; case 9: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8]); break; case 10: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9]); break; case 11: -
wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10]); break; case 12: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11]); break; case 13: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12]); break; case 14: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13]); break; case 15: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14]); break; case 16: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15]); break; case 17: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], -
inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16]); break; case 18: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], output1, - output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17]); break; case 19: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17], inputs.params[18]); break; case 20: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], inputs.params[19], output1, output2); + wfn(output1, output2, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17], inputs.params[18], inputs.params[19]); break; default: HPX_THROW_EXCEPTION(hpx::no_success, @@ -555,127 +545,127 @@ struct GenericComputeServer : component_base { wfn(output1, output2, output3); break; case 1: - wfn(inputs.params[0], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0]); break; case 2: - wfn(inputs.params[0], inputs.params[1], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1]); break; case 3: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], output1, - output2,
output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2]); break; case 4: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3]); break; case 5: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4]); break; case 6: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], output1, - output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5]); break; case 7: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6]); break; case 8: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7]); break; case 9: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], output1, - output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8]); break; case 10: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9]); break; case 11: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10]); break; case 12: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], output1, - output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8],
inputs.params[9], inputs.params[10], + inputs.params[11]); break; case 13: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12]); break; case 14: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13]); break; case 15: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], output1, - output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14]); break; case 16: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15]); break; case 17: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16]); break; case 18: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7],
inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], output1, - output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17]); break; case 19: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17], inputs.params[18]); break; case 20: - wfn(inputs.params[0], inputs.params[1], inputs.params[2], - inputs.params[3], inputs.params[4], inputs.params[5], - inputs.params[6], inputs.params[7], inputs.params[8], - inputs.params[9], inputs.params[10], inputs.params[11], - inputs.params[12], inputs.params[13], inputs.params[14], - inputs.params[15], inputs.params[16], inputs.params[17], - inputs.params[18], inputs.params[19], output1, output2, output3); + wfn(output1, output2, output3, inputs.params[0], inputs.params[1], + inputs.params[2], inputs.params[3], inputs.params[4], + inputs.params[5], inputs.params[6], inputs.params[7], + inputs.params[8], inputs.params[9], inputs.params[10], + inputs.params[11], inputs.params[12], inputs.params[13], + inputs.params[14], inputs.params[15], inputs.params[16], + inputs.params[17], inputs.params[18], inputs.params[19]); break; default: HPX_THROW_EXCEPTION(hpx::no_success, @@ -693,12 +683,10 @@ struct GenericComputeServer : component_base { // Deallocate input data buffers from OID deserialization (load) if (!_dfr_is_root_node()) { for (size_t p = 0; p < inputs.param_sizes.size(); ++p) { - if (_dfr_get_arg_type(inputs.param_types[p]) != _DFR_TASK_ARG_CONTEXT) { - if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF) - delete (static_cast *>(inputs.params[p]) - ->data); - delete ((char *)inputs.params[p]); - } + if (_dfr_get_arg_type(inputs.param_types[p]) == _DFR_TASK_ARG_MEMREF) + delete (static_cast *>(inputs.params[p]) + ->data); + delete ((char *)inputs.params[p]); } } diff --git a/compiler/include/concretelang/Runtime/key_manager.hpp b/compiler/include/concretelang/Runtime/key_manager.hpp index e80dc1773..01859a4d2 100644 --- a/compiler/include/concretelang/Runtime/key_manager.hpp +++ b/compiler/include/concretelang/Runtime/key_manager.hpp @@ -180,7 +180,7 @@ struct RuntimeContextManager { } } - RuntimeContext **getContext() { return &context; } + RuntimeContext *getContext() { return context; } void clearContext() { if (context != nullptr) diff --git a/compiler/include/concretelang/Runtime/runtime_api.h
b/compiler/include/concretelang/Runtime/runtime_api.h index 4c1d2d6fd..32ba131c4 100644 --- a/compiler/include/concretelang/Runtime/runtime_api.h +++ b/compiler/include/concretelang/Runtime/runtime_api.h @@ -15,7 +15,7 @@ extern "C" { typedef void (*wfnptr)(...); void *_dfr_make_ready_future(void *, size_t); -void _dfr_create_async_task(wfnptr, size_t, size_t, ...); +void _dfr_create_async_task(wfnptr, void *, size_t, size_t, ...); void _dfr_register_work_function(wfnptr); void *_dfr_await_future(void *); @@ -26,8 +26,7 @@ void _dfr_deallocate_future(void *); void _dfr_deallocate_future_data(void *); /* Initialisation & termination. */ -void _dfr_start_c(int64_t, void *); -void _dfr_start(int64_t); +void _dfr_start(int64_t, void *); void _dfr_stop(int64_t); void _dfr_terminate(); diff --git a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp index 86865b07b..361196184 100644 --- a/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp +++ b/compiler/lib/Conversion/ConcreteToBConcrete/ConcreteToBConcrete.cpp @@ -25,6 +25,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/BConcrete/IR/BConcreteDialect.h" @@ -108,6 +109,16 @@ public: newShape, mlir::IntegerType::get(type.getContext(), 64)); return r; }); + addConversion([&](mlir::concretelang::RT::FutureType type) { + return mlir::concretelang::RT::FutureType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::concretelang::RT::PointerType type) { + return mlir::concretelang::RT::PointerType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); } }; @@ -956,6 +967,14 @@ void ConcreteToBConcretePass::runOnOperation() { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); + target.addDynamicallyLegalOp( + [&](mlir::func::ConstantOp op) { + return FunctionConstantOpConversion< + ConcreteToBConcreteTypeConverter>::isLegal(op, converter); + }); + patterns + .insert>( + &getContext(), converter); target.addDynamicallyLegalOp([&](mlir::scf::ForOp forOp) { return converter.isLegal(forOp.getInitArgs().getTypes()) && @@ -969,19 +988,44 @@ void ConcreteToBConcretePass::runOnOperation() { converter.isLegal(op->getOperandTypes()); }); + // Conversion of RT Dialect Ops patterns.add< mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern, mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::DataflowTaskOp>, + mlir::concretelang::RT::MakeReadyFutureOp>, mlir::concretelang::GenericTypeConverterPattern< - mlir::concretelang::RT::DataflowYieldOp>>(&getContext(), converter); - - // Conversion of RT Dialect Ops + mlir::concretelang::RT::AwaitFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::CreateAsyncTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, + 
mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::WorkFunctionReturnOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); + mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowYieldOp>(target, converter); + mlir::concretelang::RT::AwaitFutureOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( + target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp index bb1f3ae8d..43b390a36 100644 --- a/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp +++ b/compiler/lib/Conversion/FHEToTFHE/FHEToTFHE.cpp @@ -14,11 +14,14 @@ #include "concretelang/Conversion/FHEToTFHE/Patterns.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/FHE/IR/FHEDialect.h" #include "concretelang/Dialect/FHE/IR/FHETypes.h" +#include "concretelang/Dialect/RT/IR/RTDialect.h" #include "concretelang/Dialect/RT/IR/RTOps.h" +#include "concretelang/Dialect/RT/IR/RTTypes.h" #include "concretelang/Dialect/TFHE/IR/TFHEDialect.h" #include "concretelang/Dialect/TFHE/IR/TFHETypes.h" @@ -53,6 +56,16 @@ public: eint.getContext(), eint)); return r; }); + addConversion([&](mlir::concretelang::RT::FutureType type) { + return mlir::concretelang::RT::FutureType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::concretelang::RT::PointerType type) { + return mlir::concretelang::RT::PointerType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); } }; @@ -269,6 +282,11 @@ struct FHEToTFHEPass : public FHEToTFHEBase { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); + target.addDynamicallyLegalOp( + [&](mlir::func::ConstantOp op) { + return FunctionConstantOpConversion::isLegal( + op, converter); + }); // Add all patterns required to lower all ops from `FHE` to // `TFHE` @@ -292,6 +310,8 @@ struct FHEToTFHEPass : public FHEToTFHEBase { patterns.add(&getContext()); patterns.add(&getContext()); + patterns.add>( + &getContext(), converter); patterns.add>( @@ -319,16 +339,43 @@ struct FHEToTFHEPass : public FHEToTFHEBase { patterns, converter); // Conversion of RT Dialect Ops - patterns.add>(patterns.getContext(), - converter); + patterns.add< + 
mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::MakeReadyFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::AwaitFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::CreateAsyncTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::WorkFunctionReturnOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); - patterns.add>(patterns.getContext(), - converter); + mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowYieldOp>(target, converter); + mlir::concretelang::RT::AwaitFutureOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( + target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp index 129800157..e0fde7bbc 100644 --- a/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp +++ b/compiler/lib/Conversion/TFHEGlobalParametrization/TFHEGlobalParametrization.cpp @@ -7,6 +7,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "concretelang/Conversion/Passes.h" +#include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/GenericOpTypeConversionPattern.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" @@ -52,6 +53,16 @@ public: this->glweInterPBSType(glwe)); return r; }); + addConversion([&](mlir::concretelang::RT::FutureType type) { + return mlir::concretelang::RT::FutureType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::concretelang::RT::PointerType type) { + return mlir::concretelang::RT::PointerType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); } TFHE::GLWECipherTextType glweInterPBSType(GLWECipherTextType &type) { @@ -293,6 +304,14 @@ void TFHEGlobalParametrizationPass::runOnOperation() { return 
converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); + target.addDynamicallyLegalOp( + [&](mlir::func::ConstantOp op) { + return FunctionConstantOpConversion< + TFHEGlobalParametrizationTypeConverter>::isLegal(op, converter); + }); + patterns.add< + FunctionConstantOpConversion>( + &getContext(), converter); mlir::populateFunctionOpInterfaceTypeConversionPattern( patterns, converter); @@ -354,16 +373,43 @@ void TFHEGlobalParametrizationPass::runOnOperation() { patterns, target, converter); // Conversion of RT Dialect Ops - patterns.add>(patterns.getContext(), - converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::MakeReadyFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::AwaitFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::CreateAsyncTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::WorkFunctionReturnOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); - patterns.add>(patterns.getContext(), - converter); + mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowYieldOp>(target, converter); + mlir::concretelang::RT::AwaitFutureOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( + target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); // Apply conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp index 53fc4f7a1..ccce14a7c 100644 --- a/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp +++ b/compiler/lib/Conversion/TFHEToConcrete/TFHEToConcrete.cpp @@ -10,6 +10,7 @@ #include "concretelang/Conversion/Passes.h" #include "concretelang/Conversion/TFHEToConcrete/Patterns.h" +#include "concretelang/Conversion/Utils/FuncConstOpConversion.h" #include "concretelang/Conversion/Utils/RegionOpTypeConverterPattern.h" #include "concretelang/Conversion/Utils/TensorOpTypeConversion.h" #include "concretelang/Dialect/Concrete/IR/ConcreteDialect.h" @@ -50,6 +51,16 @@ public: 
mlir::concretelang::convertTypeToLWE(glwe.getContext(), glwe)); return r; }); + addConversion([&](mlir::concretelang::RT::FutureType type) { + return mlir::concretelang::RT::FutureType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::concretelang::RT::PointerType type) { + return mlir::concretelang::RT::PointerType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); } }; @@ -174,10 +185,17 @@ void TFHEToConcretePass::runOnOperation() { return converter.isSignatureLegal(funcOp.getFunctionType()) && converter.isLegal(&funcOp.getBody()); }); + target.addDynamicallyLegalOp( + [&](mlir::func::ConstantOp op) { + return FunctionConstantOpConversion< + TFHEToConcreteTypeConverter>::isLegal(op, converter); + }); // Add all patterns required to lower all ops from `TFHE` to // `Concrete` mlir::RewritePatternSet patterns(&getContext()); + patterns.add>( + &getContext(), converter); populateWithGeneratedTFHEToConcrete(patterns); patterns.add>(patterns.getContext(), - converter); + patterns.add< + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::MakeReadyFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::AwaitFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::CreateAsyncTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::WorkFunctionReturnOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowTaskOp>(target, converter); - patterns.add>(patterns.getContext(), - converter); + mlir::concretelang::RT::MakeReadyFutureOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp< - mlir::concretelang::RT::DataflowYieldOp>(target, converter); + mlir::concretelang::RT::AwaitFutureOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::CreateAsyncTaskOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( + target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::WorkFunctionReturnOp>(target, converter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, converter); mlir::concretelang::addDynamicallyLegalTypeOp( target, converter); diff --git a/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp b/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp index a96d01078..2b7bbb440 100644 --- a/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/AddRuntimeContext.cpp @@ 
-32,28 +32,13 @@ struct AddRuntimeContextToFuncOpPattern rewriter.getType()); mlir::FunctionType newFuncTy = rewriter.getType( newInputs, oldFuncType.getResults()); - // Create the new func - mlir::func::FuncOp newFuncOp = rewriter.create( - oldFuncOp.getLoc(), oldFuncOp.getName(), newFuncTy); - // Create the arguments of the new func - mlir::Region &newFuncBody = newFuncOp.getBody(); - mlir::Block *newFuncEntryBlock = new mlir::Block(); - llvm::SmallVector locations(newFuncTy.getInputs().size(), - oldFuncOp.getLoc()); + rewriter.updateRootInPlace(oldFuncOp, + [&] { oldFuncOp.setType(newFuncTy); }); + oldFuncOp.getBody().front().addArgument( + rewriter.getType(), + oldFuncOp.getLoc()); - newFuncEntryBlock->addArguments(newFuncTy.getInputs(), locations); - newFuncBody.push_back(newFuncEntryBlock); - - // Clone the old body to the new one - mlir::BlockAndValueMapping map; - for (auto arg : llvm::enumerate(oldFuncOp.getArguments())) { - map.map(arg.value(), newFuncEntryBlock->getArgument(arg.index())); - } - for (auto &op : oldFuncOp.getBody().front()) { - newFuncEntryBlock->push_back(op.clone(map)); - } - rewriter.eraseOp(oldFuncOp); return mlir::success(); } @@ -72,6 +57,50 @@ struct AddRuntimeContextToFuncOpPattern } }; +namespace { +struct FunctionConstantOpConversion + : public mlir::OpRewritePattern { + FunctionConstantOpConversion(mlir::MLIRContext *ctx, + mlir::PatternBenefit benefit = 1) + : ::mlir::OpRewritePattern(ctx, benefit) {} + ::mlir::LogicalResult + matchAndRewrite(mlir::func::ConstantOp op, + mlir::PatternRewriter &rewriter) const override { + auto symTab = mlir::SymbolTable::getNearestSymbolTable(op); + auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, op.getValue()); + assert(funcOp && + "Function symbol missing in symbol table for function constant op."); + mlir::FunctionType funType = mlir::cast(funcOp) + .getFunctionType() + .cast(); + mlir::SmallVector newInputs(funType.getInputs().begin(), + funType.getInputs().end()); + newInputs.push_back( + rewriter.getType()); + mlir::FunctionType newFuncTy = + rewriter.getType(newInputs, funType.getResults()); + + rewriter.updateRootInPlace(op, [&] { op.getResult().setType(newFuncTy); }); + return mlir::success(); + } + static bool isLegal(mlir::func::ConstantOp fun) { + auto symTab = mlir::SymbolTable::getNearestSymbolTable(fun); + auto funcOp = mlir::SymbolTable::lookupSymbolIn(symTab, fun.getValue()); + assert(funcOp && + "Function symbol missing in symbol table for function constant op."); + mlir::FunctionType funType = mlir::cast(funcOp) + .getFunctionType() + .cast(); + if ((AddRuntimeContextToFuncOpPattern::isLegal( + mlir::cast(funcOp)) && + fun.getType() == funType) || + fun.getType() != funType) + return true; + return false; + } +}; +} // namespace + struct AddRuntimeContextPass : public AddRuntimeContextBase { void runOnOperation() final; @@ -90,8 +119,13 @@ void AddRuntimeContextPass::runOnOperation() { [&](mlir::func::FuncOp funcOp) { return AddRuntimeContextToFuncOpPattern::isLegal(funcOp); }); + target.addDynamicallyLegalOp( + [&](mlir::func::ConstantOp op) { + return FunctionConstantOpConversion::isLegal(op); + }); patterns.add(patterns.getContext()); + patterns.add(patterns.getContext()); // Apply the conversion if (mlir::applyPartialConversion(op, target, std::move(patterns)) diff --git a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp index f5680051b..b95fb3fac 100644 --- 
a/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/BConcrete/Transforms/BufferizableOpInterfaceImpl.cpp @@ -153,6 +153,9 @@ mlir::Value getContextArgument(mlir::Operation *op) { mlir::Block *block = op->getBlock(); while (block != nullptr) { if (llvm::isa(block->getParentOp())) { + block = &mlir::cast(block->getParentOp()) + .getBody() + .front(); auto context = std::find_if(block->getArguments().rbegin(), @@ -160,7 +163,6 @@ mlir::Value getContextArgument(mlir::Operation *op) { return arg.getType() .isa(); }); - assert(context != block->getArguments().rend() && "Cannot find the Concrete.context"); diff --git a/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp index 47ec79760..dfab45cf4 100644 --- a/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BufferizeDataflowTaskOps.cpp @@ -10,11 +10,16 @@ #include #include +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Transforms/DialectConversion.h" +#include #include #include #include #include #include +#include #include #include #include @@ -27,68 +32,35 @@ namespace mlir { namespace concretelang { namespace { -class BufferizeDataflowYieldOp - : public OpConversionPattern { + +class BufferizeRTTypesConverter + : public mlir::bufferization::BufferizeTypeConverter { public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(RT::DataflowYieldOp op, RT::DataflowYieldOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, mlir::TypeRange(), - adaptor.getOperands()); - return success(); + BufferizeRTTypesConverter() { + addConversion([&](mlir::concretelang::RT::FutureType type) { + return mlir::concretelang::RT::FutureType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::concretelang::RT::PointerType type) { + return mlir::concretelang::RT::PointerType::get( + this->convertType(type.dyn_cast() + .getElementType())); + }); + addConversion([&](mlir::FunctionType type) { + SignatureConversion result(type.getNumInputs()); + mlir::SmallVector newResults; + if (failed(this->convertSignatureArgs(type.getInputs(), result)) || + failed(this->convertTypes(type.getResults(), newResults))) + return type; + return mlir::FunctionType::get(type.getContext(), + result.getConvertedTypes(), newResults); + }); } }; + } // namespace -namespace { -class BufferizeDataflowTaskOp : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(RT::DataflowTaskOp op, RT::DataflowTaskOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - mlir::OpBuilder::InsertionGuard guard(rewriter); - - SmallVector newResults; - (void)getTypeConverter()->convertTypes(op.getResultTypes(), newResults); - auto newop = rewriter.create(op.getLoc(), newResults, - adaptor.getOperands()); - // We cannot clone here as cloned ops must be legalized (so this - // would break on the YieldOp). Instead use mergeBlocks which - // moves the ops instead of cloning. - rewriter.mergeBlocks(op.getBody(), newop.getBody(), - newop.getBody()->getArguments()); - // Because of previous bufferization there are buffer cast ops - // that have been generated for the previously tensor results of - // some tasks. 
These cannot just be replaced directly as the - // task's results would still be live. - for (auto res : llvm::enumerate(op.getResults())) { - // If this result is getting bufferized ... - if (res.value().getType() != - getTypeConverter()->convertType(res.value().getType())) { - for (auto &use : llvm::make_early_inc_range(res.value().getUses())) { - // ... and its uses are in `ToMemrefOp`s, then we - // replace further uses of the buffer cast. - if (isa(use.getOwner())) { - rewriter.replaceOp(use.getOwner(), {newop.getResult(res.index())}); - } - } - } - } - rewriter.replaceOp(op, {newop.getResults()}); - return success(); - } -}; -} // namespace - -void populateRTBufferizePatterns( - mlir::bufferization::BufferizeTypeConverter &typeConverter, - RewritePatternSet &patterns) { - patterns.add( - typeConverter, patterns.getContext()); -} - namespace { /// For documentation see Autopar.td struct BufferizeDataflowTaskOpsPass @@ -97,15 +69,75 @@ struct BufferizeDataflowTaskOpsPass void runOnOperation() override { auto module = getOperation(); auto *context = &getContext(); - mlir::bufferization::BufferizeTypeConverter typeConverter; + BufferizeRTTypesConverter typeConverter; + RewritePatternSet patterns(context); ConversionTarget target(*context); - populateRTBufferizePatterns(typeConverter, patterns); + populateFunctionOpInterfaceTypeConversionPattern( + patterns, typeConverter); + patterns.add>( + context, typeConverter); - // Forbid all RT ops that still use/return tensors - target.addDynamicallyLegalDialect( - [&](Operation *op) { return typeConverter.isLegal(op); }); + target.addDynamicallyLegalDialect([&](Operation + *op) { + if (auto fun = dyn_cast_or_null(op)) + return typeConverter.isSignatureLegal(fun.getFunctionType()) && + typeConverter.isLegal(&fun.getBody()); + if (auto fun = dyn_cast_or_null(op)) + return FunctionConstantOpConversion::isLegal( + fun, typeConverter); + return typeConverter.isLegal(op); + }); + + patterns.add< + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DataflowTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DataflowYieldOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::MakeReadyFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::AwaitFutureOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::CreateAsyncTaskOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::WorkFunctionReturnOp>, + mlir::concretelang::GenericTypeConverterPattern< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>>(&getContext(), + typeConverter); + + // Conversion of RT Dialect Ops + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DataflowTaskOp>(target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DataflowYieldOp>(target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::MakeReadyFutureOp>(target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::AwaitFutureOp>(target, typeConverter); + 
mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::CreateAsyncTaskOp>(target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::BuildReturnPtrPlaceholderOp>(target, + typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefWorkFunctionArgumentPtrPlaceholderOp>( + target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::DerefReturnPtrPlaceholderOp>(target, + typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::WorkFunctionReturnOp>(target, typeConverter); + mlir::concretelang::addDynamicallyLegalTypeOp< + mlir::concretelang::RT::RegisterTaskWorkFunctionOp>(target, + typeConverter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); diff --git a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp index 6a67426da..26c148aec 100644 --- a/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp +++ b/compiler/lib/Dialect/RT/Analysis/BuildDataflowTaskGraph.cpp @@ -18,7 +18,6 @@ #include #include -#include #include #include #include @@ -61,9 +60,8 @@ static bool isAggregatingBeneficiary(Operation *op) { return isa(op); + FHELinalg::FromElementOp, arith::ConstantOp, arith::SelectOp, + mlir::arith::CmpIOp>(op); } static bool @@ -95,87 +93,6 @@ aggregateBeneficiaryOps(Operation *op, SetVector &beneficiaryOps, return true; } -static bool isFunctionCallName(OpOperand *use, StringRef name) { - func::CallOp call = dyn_cast_or_null(use->getOwner()); - if (!call) - return false; - SymbolRefAttr sym = call.getCallableForCallee().dyn_cast(); - if (!sym) - return false; - func::FuncOp called = dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(call, sym)); - if (!called) - return false; - return called.getName() == name; -} - -static void getAliasedUses(Value val, DenseSet &aliasedUses) { - for (auto &use : val.getUses()) { - aliasedUses.insert(&use); - if (dyn_cast(use.getOwner())) - getAliasedUses(use.getOwner()->getResult(0), aliasedUses); - } -} - -static bool aggregateOutputMemrefAllocations( - Operation *op, SetVector &beneficiaryOps, - llvm::SmallPtrSetImpl &availableValues, RT::DataflowTaskOp taskOp) { - if (beneficiaryOps.count(op)) - return true; - - if (!isa(op)) - return false; - - Value val = op->getResults().front(); - DenseSet aliasedUses; - getAliasedUses(val, aliasedUses); - - // Helper function checking if a memref use writes to memory - auto hasMemoryWriteEffect = [&](OpOperand *use) { - // Call ops targeting concrete-ffi do not have memory effects - // interface, so handle apart. - // TODO: this could be handled better in BConcrete or higher. 
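Note: the legality registrations above all lean on the same converter idiom — RT wrapper types are handled by recursively converting their element type and re-wrapping it. A minimal sketch of that idiom, with WrapperType standing in for RT::FutureType or RT::PointerType (the name is illustrative, not from the patch):

  // Because the lambda goes back through convertType(), any element
  // conversion already registered on the converter (GLWE -> LWE in the
  // TFHE-to-Concrete converter, tensor -> memref during bufferization)
  // is applied before the element is wrapped again.
  addConversion([&](WrapperType type) {
    return WrapperType::get(this->convertType(type.getElementType()));
  });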
- if (isFunctionCallName(use, "memref_expand_lut_in_trivial_glwe_ct_u64") || - isFunctionCallName(use, "memref_add_lwe_ciphertexts_u64") || - isFunctionCallName(use, "memref_add_plaintext_lwe_ciphertext_u64") || - isFunctionCallName(use, "memref_mul_cleartext_lwe_ciphertext_u64") || - isFunctionCallName(use, "memref_negate_lwe_ciphertext_u64") || - isFunctionCallName(use, "memref_keyswitch_lwe_u64") || - isFunctionCallName(use, "memref_bootstrap_lwe_u64")) - if (use->getOwner()->getOperand(0) == use->get()) - return true; - - if (isFunctionCallName(use, "memref_copy_one_rank")) - if (use->getOwner()->getOperand(1) == use->get()) - return true; - - // Otherwise we rely on the memory effect interface - auto effectInterface = dyn_cast(use->getOwner()); - if (!effectInterface) - return false; - SmallVector effects; - effectInterface.getEffects(effects); - for (auto eff : effects) - if (isa(eff.getEffect()) && - eff.getValue() == use->get()) - return true; - return false; - }; - - // We need to check if this allocated memref is written in this task. - // TODO: for now we'll assume that we don't do partial writes or read/writes. - for (auto use : aliasedUses) - if (hasMemoryWriteEffect(use) && - use->getOwner()->getParentOfType() == taskOp) { - // We will sink the operation, mark its results as now available. - beneficiaryOps.insert(op); - for (Value result : op->getResults()) - availableValues.insert(result); - return true; - } - return false; -} - LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) { Region &taskOpBody = taskOp.body(); @@ -191,8 +108,6 @@ LogicalResult coarsenDFTask(RT::DataflowTaskOp taskOp) { if (!operandOp) continue; aggregateBeneficiaryOps(operandOp, toBeSunk, availableValues); - aggregateOutputMemrefAllocations(operandOp, toBeSunk, availableValues, - taskOp); } // Insert operations so that the defs get cloned before uses. @@ -283,45 +198,5 @@ std::unique_ptr createBuildDataflowTaskGraphPass(bool debug) { return std::make_unique(debug); } -namespace { -/// For documentation see Autopar.td -struct FixupDataflowTaskOpsPass - : public FixupDataflowTaskOpsBase { - - void runOnOperation() override { - auto module = getOperation(); - - module->walk([](RT::DataflowTaskOp op) { - assert(!failed(coarsenDFTask(op)) && - "Failing to sink operations into DFT"); - }); - - // Finally clear up any remaining alloc/dealloc ops that are - // meaningless - SetVector eraseOps; - module->walk([&](memref::AllocOp op) { - // If this memref.alloc's only use left is the - // dealloc, erase both. 
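For reference, the sinking performed by coarsenDFTask/aggregateBeneficiaryOps boils down to cloning a cheap defining op at the top of the task body and redirecting the task-internal uses to the clone, so the value never becomes a task dependence. A minimal sketch, assuming a helper name sinkIntoTask and the standard replaceAllUsesInRegionWith from mlir/Transforms/RegionUtils.h:

  #include <llvm/ADT/STLExtras.h>
  #include <mlir/IR/Builders.h>
  #include <mlir/Transforms/RegionUtils.h>

  // Sketch only: clone e.g. an arith::ConstantOp into the task region and
  // rewire the uses that live inside that region to the clone.
  static void sinkIntoTask(mlir::Operation *def, mlir::Region &taskBody) {
    mlir::OpBuilder builder(&taskBody.front(), taskBody.front().begin());
    mlir::Operation *clone = builder.clone(*def);
    for (auto [oldRes, newRes] :
         llvm::zip(def->getResults(), clone->getResults()))
      mlir::replaceAllUsesInRegionWith(oldRes, newRes, taskBody);
  }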
- if (op->hasOneUse() && - isa(op->use_begin()->getOwner())) { - eraseOps.insert(op->use_begin()->getOwner()); - eraseOps.insert(op); - } - }); - for (auto op : eraseOps) - op->erase(); - } - - FixupDataflowTaskOpsPass(bool debug) : debug(debug){}; - -protected: - bool debug; -}; -} // end anonymous namespace - -std::unique_ptr createFixupDataflowTaskOpsPass(bool debug) { - return std::make_unique(debug); -} - } // end namespace concretelang } // end namespace mlir diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index 2355905b4..fffcc74b2 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -18,9 +19,7 @@ #include #include #include -#include -#include #include #include #include @@ -39,7 +38,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -67,10 +68,10 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, // types, which will be changed to use an indirection when lowering. SmallVector operandTypes; operandTypes.reserve(DFTOp.getNumOperands() + DFTOp.getNumResults()); - for (Value operand : DFTOp.getOperands()) - operandTypes.push_back(RT::PointerType::get(operand.getType())); for (Value res : DFTOp.getResults()) operandTypes.push_back(RT::PointerType::get(res.getType())); + for (Value operand : DFTOp.getOperands()) + operandTypes.push_back(RT::PointerType::get(operand.getType())); FunctionType type = FunctionType::get(DFTOp.getContext(), operandTypes, {}); auto outlinedFunc = builder.create(loc, workFunctionName, type); @@ -82,6 +83,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, outlinedFuncBody.push_back(outlinedEntryBlock); BlockAndValueMapping map; + int input_offset = DFTOp.getNumResults(); Block &entryBlock = outlinedFuncBody.front(); builder.setInsertionPointToStart(&entryBlock); for (auto operand : llvm::enumerate(DFTOp.getOperands())) { @@ -89,7 +91,7 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, auto derefdop = builder.create( DFTOp.getLoc(), operand.value().getType(), - entryBlock.getArgument(operand.index())); + entryBlock.getArgument(operand.index() + input_offset)); map.map(operand.value(), derefdop->getResult(0)); } DFTOpBody.cloneInto(&outlinedFuncBody, map); @@ -99,17 +101,14 @@ static func::FuncOp outlineWorkFunction(RT::DataflowTaskOp DFTOp, builder.setInsertionPointToEnd(&entryBlock); builder.create(loc, clonedDFTOpEntry); - // TODO: we use a WorkFunctionReturnOp to tie return to the - // corresponding argument. This can be lowered to a copy/deref for - // shared memory and pointers, but needs to be handled for - // distributed memory. + // WorkFunctionReturnOp ties return to the corresponding argument. + // This is lowered to a copy/deref for shared memory and pointers, + // and handled in the serializer for distributed memory. 
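With the results-first ordering introduced above, the outlined work function takes one RT::PointerType argument per task result followed by one per task operand, so the DataflowYieldOp rewrite below can tie result i directly to argument i, while operand dereferences are shifted by DFTOp.getNumResults() (the input_offset above). The index math, spelled out as a sketch for a task with 2 results and 3 operands:

  // arg0, arg1        -> pointers for the two results
  // arg2, arg3, arg4  -> pointers for the three operands
  unsigned resultArgIndex(unsigned i) { return i; }                     // 0, 1
  unsigned operandArgIndex(unsigned j) { return /*numResults*/ 2 + j; } // 2..4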
outlinedFunc.walk([&](RT::DataflowYieldOp op) { OpBuilder replacer(op); - int output_offset = DFTOp.getNumOperands(); for (auto ret : llvm::enumerate(op.getOperands())) replacer.create( - op.getLoc(), ret.value(), - outlinedFunc.getArgument(ret.index() + output_offset)); + op.getLoc(), ret.value(), outlinedFunc.getArgument(ret.index())); replacer.create(op.getLoc()); op.erase(); }); @@ -126,33 +125,37 @@ static void replaceAllUsesInDFTsInRegionWith(Value orig, Value replacement, } } +static mlir::Type stripType(mlir::Type type) { + if (type.isa()) + return stripType(type.dyn_cast().getElementType()); + if (type.isa()) + return stripType(type.dyn_cast().getElementType()); + return type; +} + // TODO: Fix type sizes. For now we're using some default values. -static std::pair +static std::pair getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) { DataLayout dataLayout = DataLayout::closest(val.getDefiningOp()); - Type type = (val.getType().isa()) - ? val.getType().dyn_cast().getElementType() - : val.getType(); + Type type = stripType(val.getType()); // In the case of memref, we need to determine how much space // (conservatively) we need to store the memref itself. Overshooting // by a few bytes should not be an issue, so the main thing is to // properly account for the rank. if (type.isa()) { - // Space for the allocated and aligned pointers, and offset - Value ptrs_offset = - builder.create(loc, builder.getI64IntegerAttr(24)); - // For the sizes and shapes arrays, we need 2*8 = 16 times the rank in bytes - Value multiplier = - builder.create(loc, builder.getI64IntegerAttr(16)); - unsigned _rank = type.dyn_cast().getRank(); - Value rank = builder.create( - loc, builder.getI64IntegerAttr(_rank)); - Value sizes_shapes = builder.create(loc, rank, multiplier); - Value typeSize = - builder.create(loc, ptrs_offset, sizes_shapes); - + // Space for the allocated and aligned pointers, and offset plus + // rank * sizes and strides + size_t element_size; + unsigned rank = type.dyn_cast().getRank(); Type elementType = type.dyn_cast().getElementType(); + + element_size = dataLayout.getTypeSize(elementType); + + size_t size = 24 + 16 * rank; + Value typeSize = + builder.create(loc, builder.getI64IntegerAttr(size)); + // Assume here that the base type is a simple scalar-type or at // least its size can be determined. // size_t elementAttr = dataLayout.getTypeSize(elementType); @@ -160,7 +163,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) { // elementAttr <<= 8; // elementAttr |= _DFR_TASK_ARG_MEMREF; uint64_t elementAttr = 0; - size_t element_size = dataLayout.getTypeSize(elementType); elementAttr = dfr::_dfr_set_arg_type(elementAttr, dfr::_DFR_TASK_ARG_MEMREF); elementAttr = dfr::_dfr_set_memref_element_size(elementAttr, element_size); @@ -169,34 +171,19 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) { return std::pair(typeSize, arg_type); } - // Unranked memrefs should be lowered to just pointer + size, so we need 16 - // bytes. 
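The memref argument size is now a closed-form constant: 24 bytes for the allocated pointer, aligned pointer and offset, plus 16 bytes (one size word and one stride word) per dimension, while the element size is carried separately in the argument-type word via _dfr_set_memref_element_size. Worked example (sketch) for a rank-2 memref of i64:

  size_t rank = 2;
  size_t descriptorSize = 24 + 16 * rank; // 3*8 + 2*rank*8 = 56 bytes
  // the element size (8 for i64) goes into the task-argument type word,
  // not into descriptorSize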
- assert(!type.isa() && - "UnrankedMemRefType not currently supported"); + if (type.isa()) { + Value arg_type = builder.create( + loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT)); + Value typeSize = + builder.create(loc, builder.getI64IntegerAttr(8)); + return std::pair(typeSize, arg_type); + } Value arg_type = builder.create( loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_BASE)); - - // FHE types are converted to pointers, so we take their size as 8 - // bytes until we can get the actual size of the actual types. - if (type.isa() || - type.isa() || - type.isa()) { - Value result = - builder.create(loc, builder.getI64IntegerAttr(8)); - return std::pair(result, arg_type); - } else if (type.isa()) { - Value arg_type = builder.create( - loc, builder.getI64IntegerAttr(dfr::_DFR_TASK_ARG_CONTEXT)); - Value result = - builder.create(loc, builder.getI64IntegerAttr(8)); - return std::pair(result, arg_type); - } - - // For all other types, get type size. - Value result = builder.create( + Value typeSize = builder.create( loc, builder.getI64IntegerAttr(dataLayout.getTypeSize(type))); - return std::pair(result, arg_type); + return std::pair(typeSize, arg_type); } static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, @@ -212,7 +199,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, if (!val.getType().isa()) { OpBuilder::InsertionGuard guard(builder); Type futType = RT::FutureType::get(val.getType()); - Value memrefCloned; // Find out if this value is needed in any other task SmallVector taskOps; @@ -224,20 +210,10 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first)) first = op; builder.setInsertionPoint(first); - - // If we are building a future on a MemRef, then we need to clone - // the memref in order to allow the deallocation pass which does - // not synchronize with task execution. - if (val.getType().isa()) { - memrefCloned = builder.create( - val.getLoc(), builder.getI64IntegerAttr(1)); - } else { - memrefCloned = builder.create( - val.getLoc(), builder.getI64IntegerAttr(0)); - } - - auto mrf = builder.create(val.getLoc(), futType, - val, memrefCloned); + auto mrf = builder.create( + val.getLoc(), futType, val, + builder.create(val.getLoc(), + builder.getI64IntegerAttr(0))); replaceAllUsesInDFTsInRegionWith(val, mrf, opBody); } } @@ -246,7 +222,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, // DataflowTaskOp. This also includes the necessary handling of // operands and results (conversion to/from futures and propagation). 
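Concretely, the CreateAsyncTaskOp built below now carries only a compact operand list; the per-argument sizes, type words and the runtime context are filled in later by FinalizeTaskCreation. Sketch of the layout for a task with 2 results and 3 operands (indices illustrative):

  // [0]    fnptr            constant holding the work-function reference
  // [1]    numIns  = 3      arith constant (number of task operands)
  // [2]    numOuts = 2      arith constant (number of task results)
  // [3..4]                  BuildReturnPtrPlaceholderOp values, one per result
  // [5..7]                  the task operands themselves (already futures)
  // total = 3 + numOuts + numIns = 8 operands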
SmallVector catOperands; - int size = 3 + DFTOp.getNumResults() * 3 + DFTOp.getNumOperands() * 3; + int size = 3 + DFTOp.getNumResults() + DFTOp.getNumOperands(); catOperands.reserve(size); auto fnptr = builder.create( DFTOp.getLoc(), workFunction.getFunctionType(), @@ -258,12 +234,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, catOperands.push_back(fnptr.getResult()); catOperands.push_back(numIns.getResult()); catOperands.push_back(numOuts.getResult()); - for (auto operand : DFTOp.getOperands()) { - auto op_size = getTaskArgumentSizeAndType(operand, DFTOp.getLoc(), builder); - catOperands.push_back(operand); - catOperands.push_back(op_size.first); - catOperands.push_back(op_size.second); - } // We need to adjust the results for the CreateAsyncTaskOp which // are the work function's returns through pointers passed as @@ -276,11 +246,11 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, Type futType = RT::PointerType::get(RT::FutureType::get(result.getType())); auto brpp = builder.create(DFTOp.getLoc(), futType); - auto op_size = getTaskArgumentSizeAndType(result, DFTOp.getLoc(), builder); map.map(result, brpp->getResult(0)); catOperands.push_back(brpp->getResult(0)); - catOperands.push_back(op_size.first); - catOperands.push_back(op_size.second); + } + for (auto operand : DFTOp.getOperands()) { + catOperands.push_back(operand); } builder.create( DFTOp.getLoc(), @@ -341,59 +311,6 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) { SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } -static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val, - Value newval) { - for (auto &use : llvm::make_early_inc_range(val.getUses())) - if (use.getOwner()->getParentOfType() != nullptr) { - OpBuilder builder(use.getOwner()); - Value cast_newval = builder.create( - val.getLoc(), val.getType(), newval); - use.set(cast_newval); - } -} - -static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) { - OpBuilder builder(op); - for (Value val : op.getOperands()) { - if (val.getType().isa()) { - OpBuilder::InsertionGuard guard(builder); - - // Find out if this memref is needed in any other task to clone - // before all uses - SmallVector taskOps; - for (auto &use : val.getUses()) - if (isa(use.getOwner())) - taskOps.push_back(use.getOwner()); - Operation *first = op; - for (auto op : taskOps) - if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first)) - first = op; - builder.setInsertionPoint(first); - - // Get the type of memref that we will clone. In case this is - // a subview, we discard the mapping so we copy to a contiguous - // layout which pre-serializes this. 
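The same results-before-operands ordering is mirrored in the runtime entry point: _dfr_create_async_task now takes the context as an explicit second parameter and reads the NUM_OUTPUTS (future*, size, type) triples from the variadic list before the NUM_PARAMS (future, size, type) triples. As a sketch, the finalized call for a task with one output and two inputs would look roughly like:

  _dfr_create_async_task(workfn, ctx, /*num_params=*/2, /*num_outputs=*/1,
      out0_future_ptr, out0_size, out0_type,   // outputs first
      in0_future,      in0_size,  in0_type,    // then the inputs
      in1_future,      in1_size,  in1_type);
  // the sizes/types are the values produced by getTaskArgumentSizeAndType in
  // FinalizeTaskCreation; the argument names here are illustrative only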
- MemRefType mrType_base = val.getType().dyn_cast(); - MemRefType mrType = mrType_base; - if (!mrType_base.getLayout().isIdentity()) { - unsigned rank = mrType_base.getRank(); - mrType = MemRefType::Builder(mrType_base) - .setShape(mrType_base.getShape()) - .setLayout(AffineMapAttr::get( - builder.getMultiDimIdentityMap(rank))); - } - Value newval = builder.create(val.getLoc(), mrType) - .getResult(); - builder.create(val.getLoc(), val, newval); - // Value cast_newval = builder.create(val.getLoc(), - // mrType_base, newval); - replaceAllUsesInDFTsInRegionWith( - val, newval, op->getParentOfType().getBody()); - propagateMemRefLayoutInDFTs(op, val, newval); - } - } -} - /// For documentation see Autopar.td struct LowerDataflowTasksPass : public LowerDataflowTasksBase { @@ -413,8 +330,8 @@ struct LowerDataflowTasksPass SymbolTable symbolTable = mlir::SymbolTable::getNearestSymbolTable(func); SmallVector, 4> outliningMap; + // Outline DataflowTaskOp bodies to work functions func.walk([&](RT::DataflowTaskOp op) { - cloneMemRefTaskArgumentsWithIdentityMaps(op); auto workFunctionName = Twine("_dfr_DFT_work_function__") + Twine(op->getParentOfType().getName()) + @@ -423,15 +340,15 @@ struct LowerDataflowTasksPass outlineWorkFunction(op, workFunctionName.str()); outliningMap.push_back( std::pair(op, outlinedFunc)); - workFunctions.push_back(outlinedFunc); symbolTable.insert(outlinedFunc); + workFunctions.push_back(outlinedFunc); return WalkResult::advance(); }); - // Lower the DF task ops to RT dialect ops. for (auto mapping : outliningMap) lowerDataflowTaskOp(mapping.first, mapping.second); + // Gather all entry points (assuming no recursive calls to entry points) // Main is always an entry-point - otherwise check if this // function is called within the module. TODO: we assume no // recursion. @@ -449,17 +366,6 @@ struct LowerDataflowTasksPass }); for (auto entryPoint : entryPoints) { - // Check if this entry point uses a context - do this before we - // remove arguments in remote nodes - int ctxIndex = -1; - for (auto arg : llvm::enumerate(entryPoint.getArguments())) - if (arg.value() - .getType() - .isa()) { - ctxIndex = arg.index(); - break; - } - // If this is a JIT invocation and we're not on the root node, // we do not need to do any computation, only register all work // functions with the runtime system @@ -486,49 +392,6 @@ struct LowerDataflowTasksPass // runtime. for (auto wf : workFunctions) registerWorkFunction(entryPoint, wf); - - // Issue _dfr_start/stop calls for this function - OpBuilder builder(entryPoint.getBody()); - builder.setInsertionPointToStart(&entryPoint.getBody().front()); - int useDFR = (workFunctions.empty()) ? 0 : 1; - Value useDFRVal = builder.create( - entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR)); - - if (ctxIndex >= 0) { - auto startFunTy = - (dfr::_dfr_is_root_node()) - ? mlir::FunctionType::get( - entryPoint->getContext(), - {useDFRVal.getType(), - entryPoint.getArgument(ctxIndex).getType()}, - {}) - : mlir::FunctionType::get(entryPoint->getContext(), - {useDFRVal.getType()}, {}); - (void)insertForwardDeclaration(entryPoint, builder, "_dfr_start_c", - startFunTy); - (dfr::_dfr_is_root_node()) - ? 
builder.create( - entryPoint.getLoc(), "_dfr_start_c", mlir::TypeRange(), - mlir::ValueRange( - {useDFRVal, entryPoint.getArgument(ctxIndex)})) - : builder.create(entryPoint.getLoc(), - "_dfr_start_c", - mlir::TypeRange(), useDFRVal); - } else { - auto startFunTy = mlir::FunctionType::get(entryPoint->getContext(), - {useDFRVal.getType()}, {}); - (void)insertForwardDeclaration(entryPoint, builder, "_dfr_start", - startFunTy); - builder.create(entryPoint.getLoc(), "_dfr_start", - mlir::TypeRange(), useDFRVal); - } - builder.setInsertionPoint(entryPoint.getBody().back().getTerminator()); - auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(), - {useDFRVal.getType()}, {}); - (void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop", - stopFunTy); - builder.create(entryPoint.getLoc(), "_dfr_stop", - mlir::TypeRange(), useDFRVal); } } LowerDataflowTasksPass(bool debug) : debug(debug){}; @@ -544,6 +407,188 @@ std::unique_ptr createLowerDataflowTasksPass(bool debug) { namespace { +// For documentation see Autopar.td +struct StartStopPass : public StartStopBase { + + void runOnOperation() override { + auto module = getOperation(); + int useDFR = 0; + SmallVector entryPoints; + + module.walk([&](mlir::func::FuncOp func) { + // Do not add start/stop to work functions - but if any are + // present, then we need to activate the runtime + if (func->getAttr("_dfr_work_function_attribute")) { + useDFR = 1; + } else { + // Main is always an entry-point - otherwise check if this + // function is called within the module. TODO: we assume no + // recursion. + if (func.getName() == "main") + entryPoints.push_back(func); + else { + bool found = false; + module.walk([&](mlir::func::CallOp op) { + if (getCalledFunction(op) == func) + found = true; + }); + if (!found) + entryPoints.push_back(func); + } + } + }); + + for (auto entryPoint : entryPoints) { + // Issue _dfr_start/stop calls for this function + OpBuilder builder(entryPoint.getBody()); + builder.setInsertionPointToStart(&entryPoint.getBody().front()); + Value useDFRVal = builder.create( + entryPoint.getLoc(), builder.getI64IntegerAttr(useDFR)); + + // Check if this entry point uses a context + Value ctx = nullptr; + if (dfr::_dfr_is_root_node()) + for (auto arg : llvm::enumerate(entryPoint.getArguments())) + if (arg.value() + .getType() + .isa()) { + ctx = arg.value(); + break; + } + if (!ctx) + ctx = builder.create(entryPoint.getLoc(), + builder.getI64IntegerAttr(0)); + + auto startFunTy = mlir::FunctionType::get( + entryPoint->getContext(), {useDFRVal.getType(), ctx.getType()}, {}); + (void)insertForwardDeclaration(entryPoint, builder, "_dfr_start", + startFunTy); + builder.create(entryPoint.getLoc(), "_dfr_start", + mlir::TypeRange(), + mlir::ValueRange({useDFRVal, ctx})); + builder.setInsertionPoint(entryPoint.getBody().back().getTerminator()); + auto stopFunTy = mlir::FunctionType::get(entryPoint->getContext(), + {useDFRVal.getType()}, {}); + (void)insertForwardDeclaration(entryPoint, builder, "_dfr_stop", + stopFunTy); + builder.create(entryPoint.getLoc(), "_dfr_stop", + mlir::TypeRange(), useDFRVal); + } + } + StartStopPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // namespace + +std::unique_ptr createStartStopPass(bool debug) { + return std::make_unique(debug); +} + +namespace { + +// For documentation see Autopar.td +struct FinalizeTaskCreationPass + : public FinalizeTaskCreationBase { + + void runOnOperation() override { + auto module = getOperation(); + std::vector ops; + + 
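The walk below first has to decide whether the outlined work function expects a runtime context. At this point a CreateAsyncTaskOp still carries exactly 3 + numOuts + numIns operands (function pointer, the two counts, and one value per work-function argument), so catOp.getNumOperands() - 3 equals the number of value arguments, and workfn.getNumArguments() exceeding it by one signals the extra context argument. Worked example (sketch):

  // task with 2 inputs, 1 output:
  //   catOp operands: fnptr, numIns, numOuts, 1 return placeholder, 2 inputs = 6
  //   catOp.getNumOperands() - 3            = 3
  //   work function without context:  1 + 2 = 3 arguments -> no ctx needed
  //   work function with context:     3 + 1 = 4 arguments -> 4 > 3, fetch ctx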
module.walk([&](RT::CreateAsyncTaskOp catOp) { + OpBuilder builder(catOp); + SmallVector operands; + + // Determine if this task needs a runtime context + Value ctx = nullptr; + SymbolRefAttr sym = + catOp->getAttr("workfn").dyn_cast_or_null(); + assert(sym && "Work function symbol attribute missing."); + func::FuncOp workfn = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom(catOp, sym)); + assert(workfn && "Task work function missing."); + if (workfn.getNumArguments() > catOp.getNumOperands() - 3) + ctx = *catOp->getParentOfType().getArguments().rbegin(); + else + ctx = builder.create(catOp.getLoc(), + builder.getI64IntegerAttr(0)); + int index = 0; + for (auto op : catOp.getOperands()) { + operands.push_back(op); + // Add index in second position - in all cases to avoid + // checking if needed. It can be null when not relevant. + if (index == 0) + operands.push_back(ctx); + // First three operands are the function pointer, number inputs + // and number outputs - nothing further to do. + if (++index <= 3) + continue; + auto op_size = getTaskArgumentSizeAndType(op, catOp.getLoc(), builder); + operands.push_back(op_size.first); + operands.push_back(op_size.second); + } + + builder.create(catOp.getLoc(), sym, operands); + ops.push_back(catOp); + }); + for (auto op : ops) { + op->erase(); + } + + // If we are building a future on a MemRef, we need to flatten it. + + // TODO: the performance of shared memory can be improved by + // allowing view-like access instead of cloning, but memory + // deallocation needs to be synchronized appropriately + module.walk([&](RT::MakeReadyFutureOp op) { + OpBuilder builder(op); + + Value val = op.getOperand(0); + Value clone = op.getOperand(1); + if (val.getType().isa()) { + MemRefType mrType_base = val.getType().dyn_cast(); + MemRefType mrType = mrType_base; + if (!mrType_base.getLayout().isIdentity()) { + unsigned rank = mrType_base.getRank(); + mrType = MemRefType::Builder(mrType_base) + .setShape(mrType_base.getShape()) + .setLayout(AffineMapAttr::get( + builder.getMultiDimIdentityMap(rank))); + } + // We need to make a copy of this MemRef to allow deallocation + // based on refcounting + Value newval = + builder.create(val.getLoc(), mrType) + .getResult(); + builder.create(val.getLoc(), val, newval); + clone = builder.create(op.getLoc(), + builder.getI64IntegerAttr(1)); + op->setOperand(0, newval); + op->setOperand(1, clone); + } + }); + } + FinalizeTaskCreationPass(bool debug) : debug(debug){}; + +protected: + bool debug; +}; +} // namespace + +std::unique_ptr createFinalizeTaskCreationPass(bool debug) { + return std::make_unique(debug); +} + +namespace { +static void getAliasedUses(Value val, DenseSet &aliasedUses) { + for (auto &use : val.getUses()) { + aliasedUses.insert(&use); + if (dyn_cast(use.getOwner())) + getAliasedUses(use.getOwner()->getResult(0), aliasedUses); + } +} + // For documentation see Autopar.td struct FixupBufferDeallocationPass : public FixupBufferDeallocationBase { @@ -552,23 +597,21 @@ struct FixupBufferDeallocationPass auto module = getOperation(); std::vector ops; - // All buffers allocated and either made into a future, directly - // or as a result of being returned by a task, are managed by the - // DFR runtime system's reference counting. 
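The rewritten FixupBufferDeallocation below no longer keys off the futures themselves; it walks memref.dealloc ops, follows aliases through memref.cast via getAliasedUses, and drops the dealloc whenever the (possibly casted) buffer escapes into a MakeReadyFutureOp or a WorkFunctionReturnOp, since from that point on the DFR runtime's reference counting owns the buffer. Illustrative shape of the IR this targets (op spellings are a sketch, not exact assembly syntax):

  // %buf  = memref.alloc() ...
  // %cast = memref.cast %buf ...
  // ...   %cast feeds an RT MakeReadyFutureOp / WorkFunctionReturnOp ...
  // memref.dealloc %buf   // <- erased: the runtime releases the buffer
  //                       //    once its reference count is dropped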
- module.walk([&](RT::WorkFunctionReturnOp retOp) { - for (auto &use : - llvm::make_early_inc_range(retOp.getOperands().front().getUses())) - if (isa(use.getOwner())) - ops.push_back(use.getOwner()); + module.walk([&](mlir::memref::DeallocOp op) { + Value alloc = op.getOperand(); + DenseSet aliasedUses; + getAliasedUses(alloc, aliasedUses); + + for (auto use : aliasedUses) + if (isa( + use->getOwner())) { + ops.push_back(op); + return; + } }); - module.walk([&](RT::MakeReadyFutureOp mrfOp) { - for (auto &use : - llvm::make_early_inc_range(mrfOp.getOperands().front().getUses())) - if (isa(use.getOwner())) - ops.push_back(use.getOwner()); - }); - for (auto op : ops) + for (auto op : ops) { op->erase(); + } } FixupBufferDeallocationPass(bool debug) : debug(debug){}; diff --git a/compiler/lib/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.cpp b/compiler/lib/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.cpp index 737f2b797..4cfcca057 100644 --- a/compiler/lib/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/compiler/lib/Dialect/RT/Transforms/BufferizableOpInterfaceImpl.cpp @@ -9,6 +9,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" #include +#include #include #include @@ -21,9 +22,10 @@ using namespace mlir::concretelang::RT; // using namespace mlir::tensor; namespace { -struct DataflowTaskOpBufferizationInterface +struct DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface : public BufferizableOpInterface::ExternalModel< - DataflowTaskOpBufferizationInterface, DataflowTaskOp> { + DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface, + DerefWorkFunctionArgumentPtrPlaceholderOp> { bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, const AnalysisState &state) const { return false; @@ -44,22 +46,20 @@ struct DataflowTaskOpBufferizationInterface return BufferRelation::None; } - LogicalResult bufferize(Operation *op, RewriterBase &rewriter, + LogicalResult bufferize(Operation *bop, RewriterBase &rewriter, const BufferizationOptions &options) const { - DataflowTaskOp taskOp = cast(op); + DerefWorkFunctionArgumentPtrPlaceholderOp op = + cast(bop); auto isTensorType = [](Type t) { return t.isa(); }; - bool hasTensorResult = llvm::any_of(taskOp.getResultTypes(), isTensorType); - bool hasTensorOperand = - llvm::any_of(taskOp.getOperandTypes(), isTensorType); + bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType); + bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType); if (!hasTensorResult && !hasTensorOperand) return success(); SmallVector newOperands; - rewriter.setInsertionPoint(taskOp.getBody(), taskOp.getBody()->begin()); - for (OpOperand &opOperand : op->getOpOperands()) { Value oldOperandValue = opOperand.get(); @@ -72,43 +72,11 @@ struct DataflowTaskOpBufferizationInterface Value buffer = bufferOrErr.getValue(); newOperands.push_back(buffer); - - Value tensor = - rewriter.create(buffer.getLoc(), buffer); - - replaceAllUsesInRegionWith(oldOperandValue, tensor, - taskOp.getBodyRegion()); + } else { + newOperands.push_back(opOperand.get()); } } - if (hasTensorResult) { - WalkResult wr = taskOp.walk([&](DataflowYieldOp yield) { - SmallVector yieldValues; - - for (OpOperand &yieldOperand : yield.getOperation()->getOpOperands()) - if (yieldOperand.get().getType().isa()) { - FailureOr bufferOrErr = - bufferization::getBuffer(rewriter, yieldOperand.get(), options); - - if (failed(bufferOrErr)) - return WalkResult::interrupt(); - - yieldValues.push_back(bufferOrErr.getValue()); - } else { - 
yieldValues.push_back(yieldOperand.get()); - } - - rewriter.setInsertionPointAfter(yield); - rewriter.replaceOpWithNewOp(yield.getOperation(), - yieldValues); - - return WalkResult::advance(); - }); - - if (wr.wasInterrupted()) - return failure(); - } - SmallVector newResultTypes; for (OpResult res : op->getResults()) { @@ -120,17 +88,239 @@ struct DataflowTaskOpBufferizationInterface } } - rewriter.setInsertionPoint(taskOp); - DataflowTaskOp newTaskOp = rewriter.create( - taskOp.getLoc(), newResultTypes, newOperands); + rewriter.setInsertionPoint(op); + DerefWorkFunctionArgumentPtrPlaceholderOp newOp = + rewriter.create( + op.getLoc(), newResultTypes, newOperands); - newTaskOp.getRegion().takeBody(taskOp.getRegion()); - - replaceOpWithBufferizedValues(rewriter, op, newTaskOp->getResults()); + replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); return success(); } }; + +struct MakeReadyFutureOpBufferizationInterface + : public BufferizableOpInterface::ExternalModel< + MakeReadyFutureOpBufferizationInterface, MakeReadyFutureOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *bop, RewriterBase &rewriter, + const BufferizationOptions &options) const { + MakeReadyFutureOp op = cast(bop); + + auto isTensorType = [](Type t) { return t.isa(); }; + bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType); + bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType); + + if (!hasTensorResult && !hasTensorOperand) + return success(); + + SmallVector newOperands; + + for (OpOperand &opOperand : op->getOpOperands()) { + Value oldOperandValue = opOperand.get(); + + if (oldOperandValue.getType().isa()) { + FailureOr bufferOrErr = + bufferization::getBuffer(rewriter, opOperand.get(), options); + + if (failed(bufferOrErr)) + return failure(); + + Value buffer = bufferOrErr.getValue(); + newOperands.push_back(buffer); + } else { + newOperands.push_back(opOperand.get()); + } + } + + SmallVector newResultTypes; + + for (OpResult res : op->getResults()) { + if (TensorType t = res.getType().dyn_cast()) { + BaseMemRefType memrefType = getMemRefType(t, options); + newResultTypes.push_back(memrefType); + } else { + newResultTypes.push_back(res.getType()); + } + } + + rewriter.setInsertionPoint(op); + MakeReadyFutureOp newOp = rewriter.create( + op.getLoc(), newResultTypes, newOperands); + + replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); + + return success(); + } +}; + +struct WorkFunctionReturnOpBufferizationInterface + : public BufferizableOpInterface::ExternalModel< + WorkFunctionReturnOpBufferizationInterface, WorkFunctionReturnOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation 
bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *bop, RewriterBase &rewriter, + const BufferizationOptions &options) const { + WorkFunctionReturnOp op = cast(bop); + + auto isTensorType = [](Type t) { return t.isa(); }; + bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType); + bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType); + + if (!hasTensorResult && !hasTensorOperand) + return success(); + + SmallVector newOperands; + + for (OpOperand &opOperand : op->getOpOperands()) { + Value oldOperandValue = opOperand.get(); + + if (oldOperandValue.getType().isa()) { + FailureOr bufferOrErr = + bufferization::getBuffer(rewriter, opOperand.get(), options); + + if (failed(bufferOrErr)) + return failure(); + + Value buffer = bufferOrErr.getValue(); + newOperands.push_back(buffer); + } else { + newOperands.push_back(opOperand.get()); + } + } + + SmallVector newResultTypes; + + for (OpResult res : op->getResults()) { + if (TensorType t = res.getType().dyn_cast()) { + BaseMemRefType memrefType = getMemRefType(t, options); + newResultTypes.push_back(memrefType); + } else { + newResultTypes.push_back(res.getType()); + } + } + + rewriter.setInsertionPoint(op); + WorkFunctionReturnOp newOp = rewriter.create( + op.getLoc(), newResultTypes, newOperands); + + replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); + + return success(); + } +}; + +struct AwaitFutureOpBufferizationInterface + : public BufferizableOpInterface::ExternalModel< + AwaitFutureOpBufferizationInterface, AwaitFutureOp> { + bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return false; + } + + SmallVector getAliasingOpResult(Operation *op, OpOperand &opOperand, + const AnalysisState &state) const { + return {}; + } + + BufferRelation bufferRelation(Operation *op, OpResult opResult, + const AnalysisState &state) const { + return BufferRelation::None; + } + + LogicalResult bufferize(Operation *bop, RewriterBase &rewriter, + const BufferizationOptions &options) const { + AwaitFutureOp op = cast(bop); + + auto isTensorType = [](Type t) { return t.isa(); }; + bool hasTensorResult = llvm::any_of(op->getResultTypes(), isTensorType); + bool hasTensorOperand = llvm::any_of(op->getOperandTypes(), isTensorType); + + if (!hasTensorResult && !hasTensorOperand) + return success(); + + SmallVector newOperands; + + for (OpOperand &opOperand : op->getOpOperands()) { + Value oldOperandValue = opOperand.get(); + + if (oldOperandValue.getType().isa()) { + FailureOr bufferOrErr = + bufferization::getBuffer(rewriter, opOperand.get(), options); + + if (failed(bufferOrErr)) + return failure(); + + Value buffer = bufferOrErr.getValue(); + newOperands.push_back(buffer); + } else { + newOperands.push_back(opOperand.get()); + } + } + + SmallVector newResultTypes; + + for (OpResult res : op->getResults()) { + if (TensorType t = res.getType().dyn_cast()) { + BaseMemRefType memrefType = getMemRefType(t, options); + newResultTypes.push_back(memrefType); + } else { + newResultTypes.push_back(res.getType()); + } + } + + rewriter.setInsertionPoint(op); + AwaitFutureOp newOp = rewriter.create( + op.getLoc(), newResultTypes, newOperands); + + replaceOpWithBufferizedValues(rewriter, op, newOp->getResults()); + + return success(); + } 
+}; + } // namespace namespace mlir { @@ -138,7 +328,13 @@ namespace concretelang { namespace RT { void registerBufferizableOpInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, RTDialect *dialect) { - DataflowTaskOp::attachInterface(*ctx); + DerefWorkFunctionArgumentPtrPlaceholderOp::attachInterface< + DerefWorkFunctionArgumentPtrPlaceholderOpBufferizationInterface>(*ctx); + AwaitFutureOp::attachInterface(*ctx); + MakeReadyFutureOp::attachInterface( + *ctx); + WorkFunctionReturnOp::attachInterface< + WorkFunctionReturnOpBufferizationInterface>(*ctx); }); } } // namespace RT diff --git a/compiler/lib/Runtime/DFRuntime.cpp b/compiler/lib/Runtime/DFRuntime.cpp index 321ac561b..332f41692 100644 --- a/compiler/lib/Runtime/DFRuntime.cpp +++ b/compiler/lib/Runtime/DFRuntime.cpp @@ -93,8 +93,8 @@ static inline size_t _dfr_find_next_execution_locality() { /// hpx::future and the size of data within the future. After /// that come NUM_OUTPUTS pairs of hpx::future* and size_t for /// the returns. -void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, - ...) { +void _dfr_create_async_task(wfnptr wfn, void *ctx, size_t num_params, + size_t num_outputs, ...) { std::vector refcounted_futures; std::vector param_sizes; std::vector param_types; @@ -104,16 +104,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, va_list args; va_start(args, num_outputs); - for (size_t i = 0; i < num_params; ++i) { - refcounted_futures.push_back(va_arg(args, void *)); - param_sizes.push_back(va_arg(args, uint64_t)); - param_types.push_back(va_arg(args, uint64_t)); - } for (size_t i = 0; i < num_outputs; ++i) { outputs.push_back(va_arg(args, void *)); output_sizes.push_back(va_arg(args, uint64_t)); output_types.push_back(va_arg(args, uint64_t)); } + for (size_t i = 0; i < num_params; ++i) { + refcounted_futures.push_back(va_arg(args, void *)); + param_sizes.push_back(va_arg(args, uint64_t)); + param_types.push_back(va_arg(args, uint64_t)); + } va_end(args); // Take a reference on each future argument @@ -140,12 +140,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 0: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target]() - -> hpx::future { + gcc_target, + ctx]() -> hpx::future { std::vector params = {}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); })); break; @@ -153,12 +153,12 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 1: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0) + gcc_target, ctx](hpx::shared_future param0) -> hpx::future { std::vector params = {param0.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future)); @@ -167,13 +167,13 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 2: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1) -> hpx::future { 
std::vector params = {param0.get(), param1.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -183,15 +183,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 3: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2) -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -202,16 +202,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 4: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3) -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -223,18 +223,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 5: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4) -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get(), param4.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -247,19 +247,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 6: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5) -> hpx::future { std::vector params = {param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); 
return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -273,20 +273,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 7: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -301,21 +301,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 8: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), param4.get(), param5.get(), param6.get(), param7.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -331,15 +331,15 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 9: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), @@ -347,7 +347,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param6.get(), param7.get(), param8.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -364,16 +364,16 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 10: oodf = 
std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -381,7 +381,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param8.get(), param9.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -399,17 +399,17 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 11: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -417,7 +417,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param8.get(), param9.get(), param10.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -436,18 +436,18 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 12: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ 
-455,7 +455,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param8.get(), param9.get(), param10.get(), param11.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -475,19 +475,19 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 13: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -496,7 +496,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param12.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -517,20 +517,20 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 14: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -539,7 +539,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param12.get(), param13.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -561,21 +561,21 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 15: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, 
output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -584,7 +584,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param12.get(), param13.get(), param14.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -607,22 +607,22 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 16: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14, - hpx::shared_future param15) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -631,7 +631,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param12.get(), param13.get(), param14.get(), param15.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -655,23 +655,23 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 17: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - 
hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14, - hpx::shared_future param15, - hpx::shared_future param16) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -681,7 +681,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param16.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -706,24 +706,24 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 18: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14, - hpx::shared_future param15, - hpx::shared_future param16, - hpx::shared_future param17) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -733,7 +733,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param16.get(), param17.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -759,25 +759,25 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 19: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - 
hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14, - hpx::shared_future param15, - hpx::shared_future param16, - hpx::shared_future param17, - hpx::shared_future param18) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17, + hpx::shared_future param18) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -787,7 +787,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param16.get(), param17.get(), param18.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -814,26 +814,26 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, case 20: oodf = std::move(hpx::dataflow( [wfnname, param_sizes, param_types, output_sizes, output_types, - gcc_target](hpx::shared_future param0, - hpx::shared_future param1, - hpx::shared_future param2, - hpx::shared_future param3, - hpx::shared_future param4, - hpx::shared_future param5, - hpx::shared_future param6, - hpx::shared_future param7, - hpx::shared_future param8, - hpx::shared_future param9, - hpx::shared_future param10, - hpx::shared_future param11, - hpx::shared_future param12, - hpx::shared_future param13, - hpx::shared_future param14, - hpx::shared_future param15, - hpx::shared_future param16, - hpx::shared_future param17, - hpx::shared_future param18, - hpx::shared_future param19) + gcc_target, ctx](hpx::shared_future param0, + hpx::shared_future param1, + hpx::shared_future param2, + hpx::shared_future param3, + hpx::shared_future param4, + hpx::shared_future param5, + hpx::shared_future param6, + hpx::shared_future param7, + hpx::shared_future param8, + hpx::shared_future param9, + hpx::shared_future param10, + hpx::shared_future param11, + hpx::shared_future param12, + hpx::shared_future param13, + hpx::shared_future param14, + hpx::shared_future param15, + hpx::shared_future param16, + hpx::shared_future param17, + hpx::shared_future param18, + hpx::shared_future param19) -> hpx::future { std::vector params = { param0.get(), param1.get(), param2.get(), param3.get(), @@ -843,7 +843,7 @@ void _dfr_create_async_task(wfnptr wfn, size_t num_params, size_t num_outputs, param16.get(), param17.get(), param18.get(), param19.get()}; mlir::concretelang::dfr::OpaqueInputData oid( wfnname, params, param_sizes, param_types, output_sizes, - output_types); + output_types, ctx); return gcc_target->execute_task(oid); }, *((dfr_refcounted_future_p)refcounted_futures[0])->future, @@ -969,6 +969,7 @@ void _dfr_set_use_omp(bool use_omp) { bool _dfr_is_jit() { return mlir::concretelang::dfr::is_jit_p; } bool _dfr_is_root_node() { return mlir::concretelang::dfr::is_root_node_p; } bool _dfr_use_omp() { return mlir::concretelang::dfr::use_omp_p; } +bool _dfr_is_distributed() { return num_nodes > 1; } } // namespace dfr } 
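Every arity case of _dfr_create_async_task above applies the same two changes: the runtime context pointer ctx joins the lambda's capture list, and it is forwarded into the OpaqueInputData bundle so the node that eventually executes the task can rebuild the evaluation-key context. A minimal, self-contained sketch of that pattern follows; it uses std::shared_future in place of hpx::shared_future, and OpaqueInput, execute_task and make_task are hypothetical stand-ins for the DFR types, not code from this patch.

#include <future>
#include <string>
#include <vector>

// Stand-in for mlir::concretelang::dfr::OpaqueInputData: the work-function
// name, the resolved parameters, and (new in this patch) the context pointer.
struct OpaqueInput {
  std::string wfnname;
  std::vector<void *> params;
  void *ctx;
};

// Stand-in for gcc_target->execute_task(oid); a real runtime would dispatch
// the task to a worker and use oid.ctx to materialize the evaluation keys.
void *execute_task(const OpaqueInput &oid) {
  (void)oid;
  return nullptr;
}

// The per-arity lambdas above all reduce to this shape: capture ctx by value,
// resolve the parameter futures, bundle everything, hand it to the target.
template <typename... Futs>
std::future<void *> make_task(std::string wfnname, void *ctx, Futs... futs) {
  return std::async(std::launch::deferred,
                    [wfnname, ctx, futs...]() -> void * {
                      std::vector<void *> params = {futs.get()...};
                      OpaqueInput oid{wfnname, params, ctx};
                      return execute_task(oid);
                    });
}

int main() {
  std::promise<void *> p0, p1;
  p0.set_value(nullptr);
  p1.set_value(nullptr);
  auto task = make_task("example_work_fn", /*ctx=*/nullptr,
                        p0.get_future().share(), p1.get_future().share());
  return task.get() == nullptr ? 0 : 1;
}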
// namespace concretelang } // namespace mlir @@ -1002,6 +1003,7 @@ static inline void _dfr_stop_impl() { } static inline void _dfr_start_impl(int argc, char *argv[]) { + BEGIN_TIME(&mlir::concretelang::dfr::init_timer); mlir::concretelang::dfr::dl_handle = dlopen(nullptr, RTLD_NOW); // If OpenMP is to be used, we need to force its initialization @@ -1126,15 +1128,15 @@ static inline void _dfr_start_impl(int argc, char *argv[]) { mlir::concretelang::dfr::num_nodes) .get(); } + END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization"); } /* Start/stop functions to be called from within user code (or during JIT invocation). These serve to pause/resume the runtime scheduler and to clean up used resources. */ -void _dfr_start(int64_t use_dfr_p) { +void _dfr_start(int64_t use_dfr_p, void *ctx) { BEGIN_TIME(&mlir::concretelang::dfr::whole_timer); if (use_dfr_p) { - BEGIN_TIME(&mlir::concretelang::dfr::init_timer); // The first invocation will initialise the runtime. As each call to // _dfr_start is matched with _dfr_stop, if this is not hte first, // we need to resume the HPX runtime. @@ -1146,16 +1148,11 @@ void _dfr_start(int64_t use_dfr_p) { if (mlir::concretelang::dfr::init_guard.compare_exchange_strong( expected, mlir::concretelang::dfr::active)) _dfr_start_impl(0, nullptr); - END_TIME(&mlir::concretelang::dfr::init_timer, "Initialization"); assert(mlir::concretelang::dfr::init_guard == mlir::concretelang::dfr::active && "DFR runtime failed to initialise"); - if (use_dfr_p == 1) { - BEGIN_TIME(&mlir::concretelang::dfr::compute_timer); - } - // If this is not the root node in a non-JIT execution, then this // node should only run the scheduler for any incoming work until // termination is flagged. If this is JIT, we need to run the @@ -1164,28 +1161,24 @@ void _dfr_start(int64_t use_dfr_p) { !mlir::concretelang::dfr::_dfr_is_jit()) _dfr_stop_impl(); } -} -// Startup entry point when a RuntimeContext is used -void _dfr_start_c(int64_t use_dfr_p, void *ctx) { - _dfr_start(2); + // If DFR is used and a runtime context is needed, and execution is + // distributed, then broadcast from root to all compute nodes. + if (use_dfr_p && (mlir::concretelang::dfr::num_nodes > 1) && + (ctx || !mlir::concretelang::dfr::_dfr_is_root_node())) { + BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer); + new mlir::concretelang::dfr::RuntimeContextManager(); + mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager + ->setContext(ctx); - if (use_dfr_p) { - if (mlir::concretelang::dfr::num_nodes > 1) { - BEGIN_TIME(&mlir::concretelang::dfr::broadcast_timer); - new mlir::concretelang::dfr::RuntimeContextManager(); - mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager - ->setContext(ctx); - - // If this is not JIT, then the remote nodes never reach _dfr_stop, - // so root should not instantiate this barrier. - if (mlir::concretelang::dfr::_dfr_is_root_node() && - mlir::concretelang::dfr::_dfr_is_jit()) - mlir::concretelang::dfr::_dfr_startup_barrier->wait(); - END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting"); - } - BEGIN_TIME(&mlir::concretelang::dfr::compute_timer); + // If this is not JIT, then the remote nodes never reach _dfr_stop, + // so root should not instantiate this barrier. 
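With _dfr_start_c removed, _dfr_start(use_dfr_p, ctx) is the single entry point that initialises or resumes the runtime and, when a RuntimeContext is supplied and execution is distributed, broadcasts it to the compute nodes before compute timing begins. The sketch below shows how code emitted by the StartStop pass is expected to bracket a computation; only the two function signatures come from this patch, while the extern "C" linkage, run_circuit and its body are illustrative and assume linking against the concretelang runtime.

#include <cstdint>

extern "C" void _dfr_start(int64_t use_dfr_p, void *ctx);
extern "C" void _dfr_stop(int64_t use_dfr_p);

// Illustrative compiled entry point; runtime_context may be null when the
// circuit needs no evaluation keys.
void run_circuit(void *runtime_context) {
  // Initialise or resume the runtime; on distributed runs this also
  // broadcasts the context (keys) from the root to the compute nodes.
  _dfr_start(/*use_dfr_p=*/1, runtime_context);

  // ... dataflow task graph is created and executed here ...

  // Pause the runtime and clear the broadcast context.
  _dfr_stop(/*use_dfr_p=*/1);
}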
+ if (mlir::concretelang::dfr::_dfr_is_root_node() && + mlir::concretelang::dfr::_dfr_is_jit()) + mlir::concretelang::dfr::_dfr_startup_barrier->wait(); + END_TIME(&mlir::concretelang::dfr::broadcast_timer, "Key broadcasting"); } + BEGIN_TIME(&mlir::concretelang::dfr::compute_timer); } // This function cannot be used to terminate the runtime as it is @@ -1216,8 +1209,8 @@ void _dfr_stop(int64_t use_dfr_p) { mlir::concretelang::dfr::_dfr_node_level_runtime_context_manager ->clearContext(); } - END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute"); } + END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute"); END_TIME(&mlir::concretelang::dfr::whole_timer, "Total execution"); } @@ -1298,6 +1291,7 @@ namespace dfr { namespace { static bool is_jit_p = false; static bool use_omp_p = false; +static size_t num_nodes = 1; static struct timespec compute_timer; } // namespace @@ -1307,15 +1301,15 @@ void _dfr_set_use_omp(bool use_omp) { use_omp_p = use_omp; } bool _dfr_is_jit() { return is_jit_p; } bool _dfr_is_root_node() { return true; } bool _dfr_use_omp() { return use_omp_p; } +bool _dfr_is_distributed() { return num_nodes > 1; } } // namespace dfr } // namespace concretelang } // namespace mlir -void _dfr_start(int64_t use_dfr_p) { +void _dfr_start(int64_t use_dfr_p, void *ctx) { BEGIN_TIME(&mlir::concretelang::dfr::compute_timer); } -void _dfr_start_c(int64_t use_dfr_p, void *ctx) { _dfr_start(2); } void _dfr_stop(int64_t use_dfr_p) { END_TIME(&mlir::concretelang::dfr::compute_timer, "Compute"); } diff --git a/compiler/lib/Support/Pipeline.cpp b/compiler/lib/Support/Pipeline.cpp index 2fd3cbe78..321857d25 100644 --- a/compiler/lib/Support/Pipeline.cpp +++ b/compiler/lib/Support/Pipeline.cpp @@ -153,6 +153,8 @@ mlir::LogicalResult autopar(mlir::MLIRContext &context, mlir::ModuleOp &module, addPotentiallyNestedPass( pm, mlir::concretelang::createBuildDataflowTaskGraphPass(), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass); return pm.run(module.getOperation()); } @@ -294,6 +296,9 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, mlir::bufferization::createOneShotBufferizePass(bufferizationOptions); addPotentiallyNestedPass(pm, std::move(comprBuffPass), enablePass); + addPotentiallyNestedPass( + pm, mlir::concretelang::createBufferizeDataflowTaskOpsPass(), enablePass); + if (parallelizeLoops) { addPotentiallyNestedPass(pm, mlir::concretelang::createForLoopToParallel(), enablePass); @@ -305,14 +310,16 @@ lowerStdToLLVMDialect(mlir::MLIRContext &context, mlir::ModuleOp &module, // Lower affine addPotentiallyNestedPass(pm, mlir::createLowerAffinePass(), enablePass); - // Lower Dataflow tasks to DRF + // Finalize the lowering of RT/DFR which includes: + // - adding type and typesize information for dependences + // - issue _dfr_start and _dfr_stop calls to start/stop the runtime + // - remove deallocation calls for buffers managed through refcounting addPotentiallyNestedPass( - pm, mlir::concretelang::createFixupDataflowTaskOpsPass(), enablePass); - addPotentiallyNestedPass( - pm, mlir::concretelang::createLowerDataflowTasksPass(), enablePass); - // Use the buffer deallocation interface to insert future deallocation calls + pm, mlir::concretelang::createFinalizeTaskCreationPass(), enablePass); addPotentiallyNestedPass( pm, mlir::bufferization::createBufferDeallocationPass(), enablePass); + addPotentiallyNestedPass(pm, mlir::concretelang::createStartStopPass(), + enablePass); addPotentiallyNestedPass( 
pm, mlir::concretelang::createFixupBufferDeallocationPass(), enablePass); diff --git a/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc b/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc index f1b3b2e92..b4d1a3e42 100644 --- a/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc +++ b/compiler/tests/end_to_end_tests/end_to_end_jit_distributed.cc @@ -13,7 +13,7 @@ std::vector distributed_results; -TEST(DISABLED_Distributed, nn_med_nested) { +TEST(Distributed, nn_med_nested) { checkedJit(lambda, R"XXX( func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> { %cst = arith.constant dense<"0x01010100010100000001010101000101010101010101010001000101000001010001010100000101000001000001010001000001010100010001000000010100010001010001000001000101010101000100010001000000000100010001000101000001000101010100010001000000000101000100000000000001000100000100000100000001010000010001000101000100010001000100000100000100010101010000000000000000010001010000000100000100010100000100000000010001000101000100000000000101010101000101010101010100010100010100000000000101010100000100010100000001000101000000010101000101000100000101010100010101010000010101010100010000000000000001010101000100010101000001010001010000010001010101000000000000000001000001000000010100000100000101010100010001000000000000010100010101000000010100000100010001010001000000000100010001000101010100010100000001010100010101010100010100010001000001000000000101000101010001000100000101010100000101010100000100010101000100000101000101010100010001000101010100010001010001010000010000010001010000000001000101010001000000000101000000010000010100010001000001000001010101000100010001010100000101000000010001000000000101000101000000010000000001000101010100010001000000000001010000010001000001010101000101010101010100000000000001000100000100000001000000010101010101000000000101010101000100000101000100000000000001000100000101000101010100010000000101000000000100000100000101010000010100000000010000000000010001000100000101010001010101000000000000010000010101010001000000010001010001010000000000000101000000010101010101000001010101000001000001010100000000010001010100000100000101000101010100010001010001000001000100000101000100010100000100010000000101000000010000010001010101010000000101000000010101000001010100000100010001000000000001010000000100010000000000000000000000000001010101010101010101000001010101000001010100000001000101010101010000010101000101010100010101010000010101010100000100000000000101010000000000010101010000000001000000010100000100000001000101010000000001000001000001010001010000010001000101010001010001010101000100010000000100000100010101000000000101010101010001000100000000000101010000010101000001010001010000000001010100000101000001010000000001010101000100010000010101000000000001000101000001010101000101000001000001000000010100010001000101010100010001010000000101000000010001000001000100000101010001000001000001000101010000010001000001000101000000000000000101010000010000000101010100010100010001010101010000000000010001000101010000000001010100000000010001010100010001000001000101000000010100010000010000010001010100010000010001010100010000010100010101010001000100010100010101000100000101010100000100010100000100000000010101000000010001000001010000000101000100000100010101000000010100000101000001010001010100010000000101010000000001010001000000010100010101010001000100010001000001010101000000010001000100000100010101000000000000010100010000000100000000010100010000000100000101010000010101000100010000010100000001000100000000
000100000001010101010101000100010001000000010101010100000001000001000001010001000101010100000001010001010100010101000101000000010001010100010101000100000101000101000001000001000001000101010100010001010000000100000101010100000001000000000000010101000100010001000001000001000000000000010100000100000001"> : tensor<200x8xi5> @@ -106,7 +106,7 @@ func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> ASSERT_EXPECTED_FAILURE(lambda.operator()>()); } -TEST(DISABLED_Distributed, nn_med_sequential) { +TEST(Distributed, nn_med_sequential) { if (mlir::concretelang::dfr::_dfr_is_root_node()) { checkedJit(lambda, R"XXX( func.func @main(%arg0: tensor<200x4x!FHE.eint<4>>) -> tensor<200x8x!FHE.eint<4>> {
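The distributed end-to-end tests are re-enabled here, and as in nn_med_sequential above, anything that must happen exactly once (JIT compilation, argument setup, result checks) is guarded to run only on the root node, while the remaining nodes serve the scheduler started by _dfr_start. A small sketch of that guard follows, with do_root_only_work as a placeholder and the forward declarations standing in for the real runtime header.

namespace mlir {
namespace concretelang {
namespace dfr {
bool _dfr_is_root_node();
bool _dfr_is_distributed(); // new in this patch
} // namespace dfr
} // namespace concretelang
} // namespace mlir

// Placeholder for work that should run once (e.g. JIT compile and checks).
static void do_root_only_work() {}

void example_driver() {
  // Only the root node drives the computation; on a single-node build
  // _dfr_is_root_node() always returns true, so nothing changes there.
  if (mlir::concretelang::dfr::_dfr_is_root_node())
    do_root_only_work();

  // _dfr_is_distributed() lets callers skip distribution-only paths such as
  // key broadcasting when only one node is involved.
  bool distributed = mlir::concretelang::dfr::_dfr_is_distributed();
  (void)distributed;
}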