mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER]]BACKEND] Some backend and optimization passes clean-up (#1284)
* Cleaned up pipeline pass. Now works when there are element-wise ops between the load and the dot * Made `splat` compatible with varibales that have DotOperandLayout * Moves rematerialization utils to separate Transforms/Utility.cpp file.
This commit is contained in:
@@ -54,13 +54,14 @@ public:
|
||||
private:
|
||||
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
unsigned elemId, ArrayRef<int64_t> shape,
|
||||
unsigned elemId, RankedTensorType type,
|
||||
ArrayRef<unsigned> multiDimCTAInRepId,
|
||||
ArrayRef<unsigned> shapePerCTA) const {
|
||||
auto shape = type.getShape();
|
||||
unsigned rank = shape.size();
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
auto multiDimOffsetFirstElem =
|
||||
emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
|
||||
emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
|
||||
elemId, getSizePerThread(layout), getOrder(layout));
|
||||
@@ -73,9 +74,12 @@ private:
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
unsigned dim = sliceLayout.getDim();
|
||||
auto parentEncoding = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(shape);
|
||||
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
|
||||
parentEncoding);
|
||||
auto multiDimOffsetParent =
|
||||
getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
|
||||
sliceLayout.paddedShape(shape),
|
||||
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
|
||||
sliceLayout.paddedShape(multiDimCTAInRepId),
|
||||
sliceLayout.paddedShape(shapePerCTA));
|
||||
SmallVector<Value> multiDimOffset(rank);
|
||||
@@ -193,7 +197,7 @@ private:
|
||||
// of performance issue observed.
|
||||
for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type,
|
||||
multiDimCTAInRepId, shapePerCTA);
|
||||
Value offset =
|
||||
linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
|
||||
@@ -283,7 +287,7 @@ private:
|
||||
// TODO[Superjomn]: Move the coordinate computation out of loop, it is
|
||||
// duplicate in Volta.
|
||||
SmallVector<Value> multiDimOffset =
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
|
||||
getMultiDimOffset(layout, loc, rewriter, elemId, type,
|
||||
multiDimCTAInRepId, shapePerCTA);
|
||||
coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
|
||||
}
|
||||
@@ -476,7 +480,7 @@ private:
|
||||
|
||||
auto dstStrides =
|
||||
getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
|
||||
storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
|
||||
smemBase, elemTy, loc, rewriter);
|
||||
auto smemObj =
|
||||
|
||||
Reference in New Issue
Block a user