Fix the issue when CTA coverage is larger than the tile

This commit is contained in:
Lixun Zhang
2023-09-08 12:38:49 -05:00
committed by Lixun Zhang
parent ed20089bc8
commit ea397b49aa
3 changed files with 134 additions and 7 deletions

View File

@@ -271,8 +271,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= type.getShape()[dim] /
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
stride *= std::max<unsigned int>(
1, type.getShape()[dim] /
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]));
}
llvm_unreachable("Axis not found in order");
}
@@ -390,7 +391,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 8 &&
dotOperandLayout.getParent() == mfmaLayout &&
mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif