[BACKEND] Make sure getAxisBlockStride does not return 0 (#2273)

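getAxisBlockStride can return 0 during scanOp lowering when the CTA shape is
larger than the tensor shape along a non-axis dim: the integer division of the
tensor shape by the CTA shape then truncates to 0, which zeroes the stride.
Use a ceiling division instead.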
This commit is contained in:
Lixun Zhang
2023-09-11 13:02:56 -05:00
committed by GitHub
parent f6828e1a6f
commit 28d4c3bdb4
2 changed files with 4 additions and 3 deletions

View File

@@ -317,8 +317,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
- stride *= type.getShape()[dim] /
- (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
+ stride *= ceil<unsigned int>(type.getShape()[dim], sizePerThreads[dim] *
+ threadsPerWarp[dim] *
+ warpsPerCTA[dim]);
}
llvm_unreachable("Axis not found in order");
}