mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Support MMA V3 with register operand (#2375)
MMA V3 supports taking operand A from registers. This helps chained matmul operations such as those in attention. Add an optimization that uses this mode when it helps, and add the lowering for it.
This commit is contained in:
@@ -424,7 +424,21 @@ bool supportMMA(Value value, int version) {
|
||||
(elemTy.isInteger(8) && version >= 2);
|
||||
}
|
||||
|
||||
// For MMAV3 dotOperand layout matches mma operand for f16 case.
|
||||
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
RankedTensorType dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto mmaLayout = srcLayout.cast<triton::gpu::MmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout &&
|
||||
srcTy.getElementType().isF16();
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
|
||||
return true;
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
|
||||
Reference in New Issue
Block a user