From 93802d128b637ec1f1780503934685d2b3dfb600 Mon Sep 17 00:00:00 2001 From: Antoniu Pop Date: Thu, 11 Aug 2022 16:24:25 +0100 Subject: [PATCH] fix(dfr-compiler): clone memref task arguments using identity maps for serialization. --- .../RT/Analysis/LowerDataflowTasksToRT.cpp | 73 ++++++++++++++----- 1 file changed, 56 insertions(+), 17 deletions(-) diff --git a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp index 8a7d0d548..2355905b4 100644 --- a/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp +++ b/compiler/lib/Dialect/RT/Analysis/LowerDataflowTasksToRT.cpp @@ -201,7 +201,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) { static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, func::FuncOp workFunction) { - DataLayout dataLayout = DataLayout::closest(DFTOp); Region &opBody = DFTOp->getParentOfType().getBody(); OpBuilder builder(DFTOp); @@ -213,7 +212,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, if (!val.getType().isa()) { OpBuilder::InsertionGuard guard(builder); Type futType = RT::FutureType::get(val.getType()); - Value memrefCloned, newval = val; + Value memrefCloned; // Find out if this value is needed in any other task SmallVector taskOps; @@ -230,20 +229,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, // the memref in order to allow the deallocation pass which does // not synchronize with task execution. if (val.getType().isa()) { - // Get the type of memref that we will clone. In case this is - // a subview, we discard the mapping so we copy to a contiguous - // layout which pre-serializes this. - MemRefType mrType = val.getType().dyn_cast(); - if (!mrType.getLayout().isIdentity()) { - unsigned rank = mrType.getRank(); - mrType = MemRefType::Builder(mrType) - .setShape(mrType.getShape()) - .setLayout(AffineMapAttr::get( - builder.getMultiDimIdentityMap(rank))); - } - newval = builder.create(val.getLoc(), mrType) - .getResult(); - builder.create(val.getLoc(), val, newval); memrefCloned = builder.create( val.getLoc(), builder.getI64IntegerAttr(1)); } else { @@ -252,7 +237,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp, } auto mrf = builder.create(val.getLoc(), futType, - newval, memrefCloned); + val, memrefCloned); replaceAllUsesInDFTsInRegionWith(val, mrf, opBody); } } @@ -356,6 +341,59 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) { SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } +static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val, + Value newval) { + for (auto &use : llvm::make_early_inc_range(val.getUses())) + if (use.getOwner()->getParentOfType() != nullptr) { + OpBuilder builder(use.getOwner()); + Value cast_newval = builder.create( + val.getLoc(), val.getType(), newval); + use.set(cast_newval); + } +} + +static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) { + OpBuilder builder(op); + for (Value val : op.getOperands()) { + if (val.getType().isa()) { + OpBuilder::InsertionGuard guard(builder); + + // Find out if this memref is needed in any other task to clone + // before all uses + SmallVector taskOps; + for (auto &use : val.getUses()) + if (isa(use.getOwner())) + taskOps.push_back(use.getOwner()); + Operation *first = op; + for (auto op : taskOps) + if (first->getBlock() == op->getBlock() && op->isBeforeInBlock(first)) + first = op; + builder.setInsertionPoint(first); + + // Get the type of memref that we will clone. In case this is + // a subview, we discard the mapping so we copy to a contiguous + // layout which pre-serializes this. + MemRefType mrType_base = val.getType().dyn_cast(); + MemRefType mrType = mrType_base; + if (!mrType_base.getLayout().isIdentity()) { + unsigned rank = mrType_base.getRank(); + mrType = MemRefType::Builder(mrType_base) + .setShape(mrType_base.getShape()) + .setLayout(AffineMapAttr::get( + builder.getMultiDimIdentityMap(rank))); + } + Value newval = builder.create(val.getLoc(), mrType) + .getResult(); + builder.create(val.getLoc(), val, newval); + // Value cast_newval = builder.create(val.getLoc(), + // mrType_base, newval); + replaceAllUsesInDFTsInRegionWith( + val, newval, op->getParentOfType().getBody()); + propagateMemRefLayoutInDFTs(op, val, newval); + } + } +} + /// For documentation see Autopar.td struct LowerDataflowTasksPass : public LowerDataflowTasksBase { @@ -376,6 +414,7 @@ struct LowerDataflowTasksPass SmallVector, 4> outliningMap; func.walk([&](RT::DataflowTaskOp op) { + cloneMemRefTaskArgumentsWithIdentityMaps(op); auto workFunctionName = Twine("_dfr_DFT_work_function__") + Twine(op->getParentOfType().getName()) +