mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Make sure getAxisBlockStride does not return 0 (#2273)
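A zero stride can occur when the CTA shape is larger than the tensor shape along a non-axis dimension during scanOp lowering: the integer division of the tensor shape by the CTA shape then truncates to 0. Use a ceiling division instead so that each non-axis dimension contributes a factor of at least 1.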
This commit is contained in:
@@ -317,8 +317,9 @@ unsigned ScanLoweringHelper::getAxisBlockStride() {
   for (unsigned dim : order) {
     if (dim == getAxis())
       return stride;
-    stride *= type.getShape()[dim] /
-              (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
+    stride *= ceil<unsigned int>(type.getShape()[dim], sizePerThreads[dim] *
+                                                           threadsPerWarp[dim] *
+                                                           warpsPerCTA[dim]);
   }
   llvm_unreachable("Axis not found in order");
 }
|
||||
Reference in New Issue
Block a user