[OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411)

Support chains of MMA operations with mixed sizes.
Serialize the different block calculations in backward attention to
work around a problem with ptxas and wgmma.
This commit is contained in:
Thomas Raoux
2023-09-28 10:29:08 -07:00
committed by GitHub
parent 1e093fbfff
commit 721bdebee1
11 changed files with 164 additions and 195 deletions

View File

@@ -81,8 +81,7 @@ public:
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
if (isMmaToMmaShortcut(srcTy, dstTy)) {
rewriter.replaceOp(op, op.getSrc());
return success();
return lowerMmaToMma(op, adaptor, rewriter);
}
}
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
@@ -963,6 +962,43 @@ private:
return failure();
}
// mma -> mma
// Lowers a ConvertLayoutOp whose source and result both use an mma encoding.
//
// If both tensor types report the same total number of elements per thread,
// the op is replaced by its source unchanged. Otherwise the unpacked source
// values are copied into a result of the destination's per-thread extent
// (only the first two per-thread dimensions are consulted), and every
// position that the source does not cover is filled with an undef value.
//
// Always returns success().
LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op,
                            OpAdaptor adaptor,
                            ConversionPatternRewriter &rewriter) const {
  // NOTE(review): `undef(...)` below presumably picks up `loc` and `rewriter`
  // by name -- confirm before renaming either of them.
  auto loc = op.getLoc();
  auto inputTy = op.getSrc().getType().cast<RankedTensorType>();
  auto outputTy = op.getResult().getType().cast<RankedTensorType>();

  const auto totalInputElems = triton::gpu::getTotalElemsPerThread(inputTy);
  const auto totalOutputElems = triton::gpu::getTotalElemsPerThread(outputTy);
  if (totalInputElems == totalOutputElems) {
    // Same per-thread element count: forward the source as-is.
    rewriter.replaceOp(op, op.getSrc());
    return success();
  }

  // Unpack the converted source into its individual per-thread values.
  auto inputVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                        rewriter, inputTy);

  SmallVector<unsigned> outputExtent = triton::gpu::getElemsPerThread(outputTy);
  SmallVector<unsigned> inputExtent = triton::gpu::getElemsPerThread(inputTy);

  // Walk the destination extent with dimension 0 as the outer index and
  // dimension 1 as the inner index.
  SmallVector<Value> outputVals;
  for (unsigned outer = 0; outer < outputExtent[0]; ++outer) {
    for (unsigned inner = 0; inner < outputExtent[1]; ++inner) {
      const bool coveredBySource =
          inner < inputExtent[1] && outer < inputExtent[0];
      if (coveredBySource) {
        // Source values are addressed with dimension 1 varying fastest.
        outputVals.push_back(inputVals[inner + outer * inputExtent[1]]);
      } else {
        // No matching source element: emit a placeholder of the element type.
        outputVals.push_back(undef(inputVals[0].getType()));
      }
    }
  }
  assert(outputVals.size() == totalOutputElems);

  Value packed =
      getTypeConverter()->packLLElements(loc, outputVals, rewriter, outputTy);
  rewriter.replaceOp(op, packed);
  return success();
}
// shared -> dot_operand if the result layout is mma
Value lowerSharedToDotOperandMMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,