[BACKEND] Convert layout illegal mem access fix (#2287)

This commit is contained in:
Zahi Moudallal
2023-09-13 10:02:25 -07:00
committed by GitHub
parent 994f7e4460
commit e95e1f12eb
5 changed files with 106 additions and 64 deletions

View File

@@ -18,6 +18,7 @@ using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getShapePerCTATile;
using ::mlir::triton::gpu::getSizePerThread;
using ::mlir::triton::gpu::getUniqueContigPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;
@@ -50,9 +51,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
return {inOrd, outOrd};
}
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
SmallVector<unsigned> getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) {
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
@@ -76,15 +75,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
}
}
assert(srcLayout && dstLayout &&
"Unexpected layout in getScratchConfigForCvtLayout()");
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread = getContigPerThread(srcLayout)[inOrd[0]];
unsigned dstContigPerThread = getContigPerThread(dstLayout)[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()");
auto srcShapePerCTA = getShapePerCTA(srcTy);
auto dstShapePerCTA = getShapePerCTA(dstTy);
@@ -92,21 +83,44 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
unsigned rank = dstTy.getRank();
SmallVector<unsigned> paddedRepShape(rank);
unsigned pad = std::max(inVec, outVec);
SmallVector<unsigned> repShape(rank);
for (unsigned d = 0; d < rank; ++d) {
paddedRepShape[d] =
repShape[d] =
std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
}
if (rank == 1)
return paddedRepShape;
return repShape;
}
SmallVector<unsigned>
getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
unsigned &outVec) {
auto repShape = getRepShapeForCvtLayout(op);
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
Attribute srcLayout = srcTy.getEncoding();
Attribute dstLayout = dstTy.getEncoding();
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
unsigned srcContigPerThread =
getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
unsigned dstContigPerThread =
getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
// TODO: Fix the legacy issue that ourOrd[0] == 0 always means
// that we cannot do vectorization.
inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread;
outVec = outOrd[0] == 0 ? 1 : dstContigPerThread;
if (repShape.size() <= 1)
return repShape;
unsigned paddedDim = 1;
if (auto dstBlockedLayout = dstLayout.dyn_cast<BlockedEncodingAttr>()) {
paddedDim = dstBlockedLayout.getOrder()[0];
}
paddedRepShape[paddedDim] += pad;
return paddedRepShape;
unsigned pad = std::max(inVec, outVec);
repShape[paddedDim] += pad;
return repShape;
}
SmallVector<unsigned>