mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FA OPTIMIZATION] Keep results of FA dot operations in registers (#247)
* [WIP][FA OPTIMIZATION] Optimize chain dot This commit optimizes chain dot operation by keeping results of the first dot operation in registers. * [FA OPTIMIZATION] Enable lowering pipeline for keeping result of chain dot in registers * Move operand swapping in ttgir -> llir lowering phase * Refactor emitMfmaOffsetForCTA function to be more readable * Fix accidental change in 06-fused-attention.py * Address review comments * Fix rebase errors
This commit is contained in:
@@ -195,6 +195,22 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
!srcTy.getElementType().isF32();
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
auto srcLayout = srcTy.getEncoding();
|
||||
auto dstLayout = dstTy.getEncoding();
|
||||
auto mfmaLayout = srcLayout.cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
// TODO: Remove the restriction on the warpsPerCTA once chain dot testing is
|
||||
// improved. In addition, we can enable this shortcut for regular MFMA
|
||||
// layout when opIdx == 1.
|
||||
return mfmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mfmaLayout &&
|
||||
mfmaLayout.getIsTransposed() && srcTy.getElementType().isF16();
|
||||
}
|
||||
#endif
|
||||
|
||||
bool isSingleValue(Value value) {
|
||||
// Don't consider load as expensive if it is loading a scalar.
|
||||
if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
|
||||
|
||||
Reference in New Issue
Block a user