mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
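Fix ScanLoweringHelper::getAxisBlockStride when the CTA coverage is larger than the tile: clamp each dimension's block count to at least 1 so the integer division cannot zero the stride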
This commit is contained in:
@@ -271,8 +271,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
|
||||
for (unsigned dim : order) {
|
||||
if (dim == getAxis())
|
||||
return stride;
|
||||
stride *= type.getShape()[dim] /
|
||||
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
|
||||
stride *= std::max<unsigned int>(
|
||||
1, type.getShape()[dim] /
|
||||
(sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]));
|
||||
}
|
||||
llvm_unreachable("Axis not found in order");
|
||||
}
|
||||
@@ -390,7 +391,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getKWidth() == 8 &&
|
||||
dotOperandLayout.getParent() == mfmaLayout &&
|
||||
mfmaLayout.getIsTransposed() && (srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
mfmaLayout.getIsTransposed() &&
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
Reference in New Issue
Block a user