Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00.
[OPTIMIZATION] Fix performance for the attention backward path with MMA v3 (#2411)
Support chains of MMA operations with mixed sizes. Serialize the different block calculations in the backward attention pass to work around a problem with ptxas and wgmma.
This commit is contained in:
@@ -81,8 +81,7 @@ public:
|
||||
// forwarding on mma->mma shortcut, lower distributed->distributed otherwise
|
||||
if (srcLayout.isa<MmaEncodingAttr>() && dstLayout.isa<MmaEncodingAttr>()) {
|
||||
if (isMmaToMmaShortcut(srcTy, dstTy)) {
|
||||
rewriter.replaceOp(op, op.getSrc());
|
||||
return success();
|
||||
return lowerMmaToMma(op, adaptor, rewriter);
|
||||
}
|
||||
}
|
||||
if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) {
|
||||
@@ -963,6 +962,43 @@ private:
|
||||
return failure();
|
||||
}
|
||||
|
||||
// mma -> mma
|
||||
LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto dstTy = op.getResult().getType().cast<RankedTensorType>();
|
||||
if (triton::gpu::getTotalElemsPerThread(srcTy) ==
|
||||
triton::gpu::getTotalElemsPerThread(dstTy)) {
|
||||
rewriter.replaceOp(op, op.getSrc());
|
||||
return success();
|
||||
}
|
||||
// get source values
|
||||
auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
|
||||
rewriter, srcTy);
|
||||
SmallVector<Value> retVals;
|
||||
SmallVector<unsigned> dstElementPerThread =
|
||||
triton::gpu::getElemsPerThread(dstTy);
|
||||
SmallVector<unsigned> srcElementPerThread =
|
||||
triton::gpu::getElemsPerThread(srcTy);
|
||||
for (unsigned j = 0; j < dstElementPerThread[0]; j++) {
|
||||
for (unsigned i = 0; i < dstElementPerThread[1]; i++) {
|
||||
if (i >= srcElementPerThread[1] || j >= srcElementPerThread[0]) {
|
||||
retVals.push_back(undef(vals[0].getType()));
|
||||
continue;
|
||||
}
|
||||
unsigned index = i + j * srcElementPerThread[1];
|
||||
retVals.push_back(vals[index]);
|
||||
}
|
||||
}
|
||||
assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy));
|
||||
Value view =
|
||||
getTypeConverter()->packLLElements(loc, retVals, rewriter, dstTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
return success();
|
||||
}
|
||||
|
||||
// shared -> dot_operand if the result layout is mma
|
||||
Value lowerSharedToDotOperandMMA(
|
||||
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
|
||||
|
||||
Reference in New Issue
Block a user