[BACKEND] Make sure getAxisBlockStride does not return 0 (#2273)

This can happen when the CTA shape is larger than the tensor shape along
the non-axis dim during scanOp lowering.
This commit is contained in:
Lixun Zhang
2023-09-11 13:02:56 -05:00
committed by GitHub
parent f6828e1a6f
commit 28d4c3bdb4
2 changed files with 4 additions and 3 deletions

View File

@@ -1750,7 +1750,7 @@ scan_layouts = [
]
@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]])
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
@pytest.mark.parametrize("src_layout", scan_layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_scan_layouts(M, N, src_layout, axis, device):