mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Convert layout illegal mem access fix (#2287)
This commit is contained in:
@@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getShapePerCTATile;
|
||||
using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::getUniqueContigPerThread;
|
||||
using ::mlir::triton::gpu::MmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
using ::mlir::triton::gpu::SliceEncodingAttr;
|
||||
@@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
|
||||
return {inOrd, outOrd};
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
@@ -76,15 +75,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
}
|
||||
}
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpected layout in getScratchConfigForCvtLayout()");
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
|
||||
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");
|
||||
|
||||
auto srcShapePerCTA = getShapePerCTA(srcTy);
|
||||
auto dstShapePerCTA = getShapePerCTA(dstTy);
|
||||
@@ -92,21 +83,44 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
|
||||
|
||||
unsigned rank = dstTy.getRank();
|
||||
SmallVector<unsigned> paddedRepShape(rank);
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
SmallVector<unsigned> repShape(rank);
|
||||
for (unsigned d = 0; d < rank; ++d) {
|
||||
paddedRepShape[d] =
|
||||
repShape[d] =
|
||||
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
|
||||
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
|
||||
}
|
||||
if (rank == 1)
|
||||
return paddedRepShape;
|
||||
return repShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
unsigned &outVec) {
|
||||
auto repShape = getRepShapeForCvtLayout(op);
|
||||
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
unsigned srcContigPerThread =
|
||||
getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
|
||||
unsigned dstContigPerThread =
|
||||
getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
|
||||
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
|
||||
// that we cannot do vectorization.
|
||||
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
|
||||
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
|
||||
|
||||
if (repShape.size() <= 1)
|
||||
return repShape;
|
||||
unsigned paddedDim = 1;
|
||||
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
paddedDim = dstBlockedLayout.getOrder()[0];
|
||||
}
|
||||
paddedRepShape[paddedDim] += pad;
|
||||
return paddedRepShape;
|
||||
unsigned pad = std::max(inVec, outVec);
|
||||
repShape[paddedDim] += pad;
|
||||
return repShape;
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
|
||||
Reference in New Issue
Block a user