[BACKEND] Fix for FP8 QK inputs in flash attention forward pass (#2435)

This commit is contained in:
Thomas Raoux
2023-10-03 21:02:13 -07:00
committed by GitHub
parent 0d84a7d70c
commit c656a139d3
2 changed files with 4 additions and 1 deletion

View File

@@ -434,7 +434,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
   // when #mma = MmaEncoding<version=3, warpsPerCTA=[..., 1]>
   return src && dst && src.getVersionMajor() == 3 &&
          src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 &&
-         dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2];
+         dst.getWarpsPerCTA()[1] == 1;
 }
 bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {