[BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330)

1. On the axis, use `getAxisNumWarpsWithUniqueData` instead of the raw
number of warps, to avoid communication among warps that handle the
same piece of data.
2. When there is a single warp on the axis, use warp intrinsics for
communication and skip shared memory entirely.

A follow-up PR is needed for code cleanup.
This commit is contained in:
Keren Zhou
2023-09-18 17:45:05 -04:00
committed by GitHub
parent a9ae9886dc
commit 307b5caa49
6 changed files with 195 additions and 24 deletions

View File

@@ -1,3 +1,4 @@
import pytest
import torch
import triton
@@ -153,3 +154,23 @@ def test_avg_pool_bw():
out_ref[:, :, 0::7, 1:7] = 2 / 3
out_ref[:, :, 0::7, 0::7] = 4 / 9
torch.testing.assert_close(out, out_ref)
@pytest.mark.parametrize("RBLOCK", [32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
def test_scan2d_broadcast(RBLOCK, num_warps):
    """Regression test: cumsum over a row broadcast across XBLOCK rows.

    Every row of the (XBLOCK, RBLOCK) tile loads the same 1-row input, so
    the scan result must equal a single-row cumsum broadcast to all rows,
    regardless of how many warps cover the axis.
    """
    @triton.jit(debug=True)
    def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
        rindex = tl.arange(0, RBLOCK)[None, :]
        xindex = tl.arange(0, XBLOCK)[:, None]
        # All rows read the same RBLOCK-wide slice (broadcast load).
        data = tl.load(in_ptr + rindex)
        scan = tl.cumsum(data, 1)
        # Sanity check inside the kernel: a running prefix sum of
        # non-negative inputs can never exceed the row's total sum.
        expected_max = tl.sum(data, 1)
        tl.device_assert(scan <= expected_max)
        tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

    XBLOCK = 4
    # Avoid shadowing the `input` builtin; keep values non-negative so the
    # device_assert above is a valid invariant.
    inp = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda')
    out = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda')
    fn[(1,)](inp, out, XBLOCK, RBLOCK, num_warps=num_warps)
    # torch.testing.assert_allclose is deprecated; use assert_close to match
    # the rest of this file. check_dtype=False because torch promotes an
    # int32 cumsum to int64, while the kernel output stays int32.
    torch.testing.assert_close(out, inp.cumsum(1).broadcast_to((XBLOCK, RBLOCK)),
                               check_dtype=False)