mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix scan issues on repetitive warps and improve perf when there's a single warp on the axis (#2330)
1. On the axis, use `getAxisNumWarpsWithUniqueData` instead of the raw number of warps, to avoid communication among warps that handle the same piece of data. 2. When there is a single warp on the axis, use warp intrinsics for communication and skip shared memory. A follow-up PR is needed for code cleanup.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import triton
|
||||
@@ -153,3 +154,23 @@ def test_avg_pool_bw():
|
||||
out_ref[:, :, 0::7, 1:7] = 2 / 3
|
||||
out_ref[:, :, 0::7, 0::7] = 4 / 9
|
||||
torch.testing.assert_close(out, out_ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("RBLOCK", [32, 64, 128])
@pytest.mark.parametrize("num_warps", [1, 4])
def test_scan2d_broadcast(RBLOCK, num_warps):
    """Row-wise cumsum of a single (1, RBLOCK) row broadcast over XBLOCK rows.

    Regression test for scans over repetitive warps: with num_warps > 1 the
    warps along the non-scan axis hold copies of the same data, so the scan
    must not mix partial sums across them. The in-kernel device_assert
    catches a symptom of that bug (a prefix sum exceeding the row total);
    the host-side comparison checks the full result.
    """

    @triton.jit(debug=True)
    def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
        rindex = tl.arange(0, RBLOCK)[None, :]
        xindex = tl.arange(0, XBLOCK)[:, None]
        # Every row loads the same (1, RBLOCK) input row.
        data = tl.load(in_ptr + rindex)
        scan = tl.cumsum(data, 1)
        # Inputs are non-negative, so every prefix sum is bounded by the
        # row total; a larger value means warps leaked into each other.
        expected_max = tl.sum(data, 1)
        tl.device_assert(scan <= expected_max)
        tl.store(out_ptr + xindex * RBLOCK + rindex, scan)

    XBLOCK = 4
    input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int32, device='cuda')
    output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int32, device='cuda')
    fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps)
    # torch.testing.assert_allclose is deprecated; use assert_close like the
    # other tests in this file. check_dtype=False keeps the old leniency in
    # case cumsum promotes the integer dtype of the reference tensor.
    torch.testing.assert_close(output,
                               input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)),
                               check_dtype=False)
|
||||
|
||||
Reference in New Issue
Block a user