[BACKEND] Handle repetitive threads in scan op when the tensor dim is small (#2345)

https://github.com/openai/triton/issues/2298
This commit is contained in:
Keren Zhou
2023-09-20 12:25:52 -04:00
committed by GitHub
parent e5eda098b3
commit ed5a53057d
4 changed files with 43 additions and 37 deletions

View File

@@ -156,7 +156,7 @@ def test_avg_pool_bw():
torch.testing.assert_close(out, out_ref)
@pytest.mark.parametrize("RBLOCK", [32, 64, 128])
@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
def test_scan2d_broadcast(RBLOCK, num_warps):
@triton.jit(debug=True)