fix(dfr-compiler): clone memref task arguments using identity maps for serialization.

This commit is contained in:
Antoniu Pop
2022-08-11 16:24:25 +01:00
committed by Quentin Bourgerie
parent 8cd3a3a599
commit 93802d128b

View File

@@ -201,7 +201,6 @@ getTaskArgumentSizeAndType(Value val, Location loc, OpBuilder builder) {
static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
func::FuncOp workFunction) {
DataLayout dataLayout = DataLayout::closest(DFTOp);
Region &opBody = DFTOp->getParentOfType<func::FuncOp>().getBody();
OpBuilder builder(DFTOp);
@@ -213,7 +212,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
if (!val.getType().isa<RT::FutureType>()) {
OpBuilder::InsertionGuard guard(builder);
Type futType = RT::FutureType::get(val.getType());
Value memrefCloned, newval = val;
Value memrefCloned;
// Find out if this value is needed in any other task
SmallVector<Operation *, 2> taskOps;
@@ -230,20 +229,6 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
// the memref in order to allow the deallocation pass which does
// not synchronize with task execution.
if (val.getType().isa<mlir::MemRefType>()) {
// Get the type of memref that we will clone. In case this is
// a subview, we discard the mapping so we copy to a contiguous
// layout which pre-serializes this.
MemRefType mrType = val.getType().dyn_cast<mlir::MemRefType>();
if (!mrType.getLayout().isIdentity()) {
unsigned rank = mrType.getRank();
mrType = MemRefType::Builder(mrType)
.setShape(mrType.getShape())
.setLayout(AffineMapAttr::get(
builder.getMultiDimIdentityMap(rank)));
}
newval = builder.create<mlir::memref::AllocOp>(val.getLoc(), mrType)
.getResult();
builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);
memrefCloned = builder.create<arith::ConstantOp>(
val.getLoc(), builder.getI64IntegerAttr(1));
} else {
@@ -252,7 +237,7 @@ static void lowerDataflowTaskOp(RT::DataflowTaskOp DFTOp,
}
auto mrf = builder.create<RT::MakeReadyFutureOp>(val.getLoc(), futType,
newval, memrefCloned);
val, memrefCloned);
replaceAllUsesInDFTsInRegionWith(val, mrf, opBody);
}
}
@@ -356,6 +341,59 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
/// Redirects the uses of `val` that sit inside dataflow task bodies to
/// `newval`.
///
/// Each use of `val` whose owning operation has an enclosing
/// RT::DataflowTaskOp is rewired to a new memref::CastOp that casts `newval`
/// back to the type of `val`; the cast is created right before the using
/// operation. Uses without an enclosing dataflow task are left alone.
/// The `op` parameter is not read by this function.
static void propagateMemRefLayoutInDFTs(RT::DataflowTaskOp op, Value val,
                                        Value newval) {
  // Early-increment iteration, because the loop body rewrites the very
  // operand currently being visited.
  for (auto &operand : llvm::make_early_inc_range(val.getUses())) {
    Operation *user = operand.getOwner();
    if (!user->getParentOfType<RT::DataflowTaskOp>())
      continue;

    OpBuilder castBuilder(user);
    Value castBack = castBuilder.create<mlir::memref::CastOp>(
        val.getLoc(), val.getType(), newval);
    operand.set(castBack);
  }
}
/// Clones every memref operand of the dataflow task `op` into a newly
/// allocated buffer whose layout is the identity map.
///
/// For each operand of memref type:
///   - the clone is inserted before the earliest RT::DataflowTaskOp that uses
///     the operand directly and lives in the same block as `op` (or before
///     `op` itself when there is none), so it precedes all such uses;
///   - a memref::AllocOp is created with the operand's type, except that a
///     non-identity layout (e.g. coming from a subview) is replaced by a
///     multi-dimensional identity map, and a memref::CopyOp copies the
///     operand into it;
///   - the operand is then handed, together with the clone and the body of
///     the enclosing func::FuncOp, to replaceAllUsesInDFTsInRegionWith
///     (defined elsewhere in this file), and uses nested inside dataflow task
///     bodies are rewired by propagateMemRefLayoutInDFTs.
/// Operands that are not memrefs are skipped.
static void cloneMemRefTaskArgumentsWithIdentityMaps(RT::DataflowTaskOp op) {
  OpBuilder builder(op);
  for (Value val : op.getOperands()) {
    // A single checked dyn_cast both filters non-memref operands and yields
    // the type used below.
    auto baseType = val.getType().dyn_cast<mlir::MemRefType>();
    if (!baseType)
      continue;

    OpBuilder::InsertionGuard guard(builder);

    // Find out if this memref is needed in any other task, so that the clone
    // is created before all of those uses.
    Operation *first = op;
    for (auto &use : val.getUses()) {
      Operation *user = use.getOwner();
      if (isa<RT::DataflowTaskOp>(user) &&
          first->getBlock() == user->getBlock() &&
          user->isBeforeInBlock(first))
        first = user;
    }
    builder.setInsertionPoint(first);

    // Type of the clone: same shape, element type and memory space as the
    // operand. A non-identity layout is discarded so that the copy lands in
    // a contiguous buffer, which pre-serializes the data.
    MemRefType cloneType = baseType;
    if (!baseType.getLayout().isIdentity()) {
      unsigned rank = baseType.getRank();
      cloneType = MemRefType::Builder(baseType).setLayout(
          AffineMapAttr::get(builder.getMultiDimIdentityMap(rank)));
    }

    Value newval =
        builder.create<mlir::memref::AllocOp>(val.getLoc(), cloneType)
            .getResult();
    builder.create<mlir::memref::CopyOp>(val.getLoc(), val, newval);

    replaceAllUsesInDFTsInRegionWith(
        val, newval, op->getParentOfType<func::FuncOp>().getBody());
    propagateMemRefLayoutInDFTs(op, val, newval);
  }
}
/// For documentation see Autopar.td
struct LowerDataflowTasksPass
: public LowerDataflowTasksBase<LowerDataflowTasksPass> {
@@ -376,6 +414,7 @@ struct LowerDataflowTasksPass
SmallVector<std::pair<RT::DataflowTaskOp, func::FuncOp>, 4> outliningMap;
func.walk([&](RT::DataflowTaskOp op) {
cloneMemRefTaskArgumentsWithIdentityMaps(op);
auto workFunctionName =
Twine("_dfr_DFT_work_function__") +
Twine(op->getParentOfType<func::FuncOp>().getName()) +