[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)

* [WIP][FA OPTIMIZATION] Optimize chain dot

This commit optimizes the chain-dot operation by keeping the
results of the first dot operation in registers.

* [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers

* Move operand swapping in ttgir -> llir lowering phase

* Refactor emitMfmaOffsetForCTA function to be more readable

* Fix accidental change in 06-fused-attention.py

* Address review comments

* Fix rebase errors
This commit is contained in:
oplavsic
2023-07-12 22:25:55 +02:00
committed by GitHub
parent 4d0deef45f
commit d6e51fd221
17 changed files with 299 additions and 100 deletions

View File

@@ -195,6 +195,22 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
!srcTy.getElementType().isF32();
}
#ifdef USE_ROCM
// Returns true when an MFMA-encoded tensor can be fed directly into a dot
// operand (keeping the chain-dot intermediate in registers) without going
// through shared memory.
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto mfmaLayout = srcTy.getEncoding().cast<triton::gpu::MfmaEncodingAttr>();
  auto dotOperandLayout =
      dstTy.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
  // TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
  // improved. In addition, we can enable this shortcut for regular MFMA
  // layout when opIdx == 1.
  if (mfmaLayout.getWarpsPerCTA()[1] != 1)
    return false;
  // Only operand A (opIdx == 0) of a dot whose parent is this very MFMA
  // layout qualifies.
  if (dotOperandLayout.getOpIdx() != 0 ||
      dotOperandLayout.getParent() != mfmaLayout)
    return false;
  // The shortcut additionally requires a transposed MFMA layout and f16
  // elements.
  return mfmaLayout.getIsTransposed() && srcTy.getElementType().isF16();
}
#endif
bool isSingleValue(Value value) {
// Don't consider load as expensive if it is loading a scalar.
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())