[OPTIMIZER]]BACKEND] Some backend and optimization passes clean-up (#1284)

* Cleaned up pipeline pass. Now works when there are element-wise ops
between the load and the dot
* Made `splat` compatible with varibales that have DotOperandLayout
* Moves rematerialization utils to separate Transforms/Utility.cpp file.
This commit is contained in:
Philippe Tillet
2023-03-06 17:17:59 -08:00
committed by GitHub
parent 73d55eb59c
commit 3db55c5f94
22 changed files with 451 additions and 344 deletions

View File

@@ -54,13 +54,14 @@ public:
private:
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
unsigned elemId, ArrayRef<int64_t> shape,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
ArrayRef<unsigned> shapePerCTA) const {
auto shape = type.getShape();
unsigned rank = shape.size();
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
auto multiDimOffsetFirstElem =
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
SmallVector<Value> multiDimOffset(rank);
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
elemId, getSizePerThread(layout), getOrder(layout));
@@ -73,9 +74,12 @@ private:
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
sliceLayout.paddedShape(shape),
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
@@ -193,7 +197,7 @@ private:
// of performance issue observed.
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
Value offset =
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
@@ -283,7 +287,7 @@ private:
// TODO[Superjomn]: Move the coordinate computation out of loop, it is
// duplicate in Volta.
SmallVector<Value> multiDimOffset =
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
getMultiDimOffset(layout, loc, rewriter, elemId, type,
multiDimCTAInRepId, shapePerCTA);
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
}
@@ -476,7 +480,7 @@ private:
auto dstStrides =
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
smemBase, elemTy, loc, rewriter);
auto smemObj =