mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix for FP8 QK inputs in flash attention forward pass (#2435)
This commit is contained in:
@@ -434,7 +434,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
// when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
|
||||
return src && dst && src.getVersionMajor() == 3 &&
|
||||
src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
|
||||
dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
|
||||
dst.getWarpsPerCTA()[1] == 1;
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
|
||||
Reference in New Issue
Block a user