mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Thread local reduction optimization (#2542)
Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
@@ -1815,6 +1815,51 @@ scan_layouts = [
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op", ['sum', 'max', 'min'])
@pytest.mark.parametrize("BLOCK_N", [32, 64, 128])
@pytest.mark.parametrize("N", [512, 1024, 2048])
@pytest.mark.parametrize("num_pid_n", [2, 4])
def test_locality(op, BLOCK_N, N, num_pid_n):
    """Check a looped row reduction both for its result and for its IR shape.

    Each program along grid axis 1 walks the column blocks
    ``pid_n, pid_n + num_programs, ...`` of a ``(BLOCK_M, N)`` input, folds
    every block into a per-row accumulator, and writes that accumulator to
    column ``pid_n`` of the output.  The host then reduces the per-program
    partials with NumPy and compares them with reducing the input directly.

    The generated TTGIR must also contain exactly two ``"tt.reduce"``
    occurrences; per the assertion message, any other count means the
    optimization under test did not apply.
    """

    @triton.jit
    def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        n_programs = tl.num_programs(1)
        # Placeholder below is swapped for the op-specific initial value.
        local = INITIALIZE_PATCH
        rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        # Strided walk over column blocks: this program takes every
        # n_programs-th block, starting at its own index.
        for block_n in range(pid_n, tl.cdiv(N, BLOCK_N), n_programs):
            cols = block_n * BLOCK_N + tl.arange(0, BLOCK_N)
            x_ptrs = X + rows[:, None] * N + cols[None, :]
            x = tl.load(x_ptrs)
            # Placeholder below is swapped for the op-specific fold step.
            local = ACCUMULATE_PATCH
        # One partial per (row, program): Y is indexed as [row, pid_n].
        tl.store(Y + rows * n_programs + pid_n, local)

    # Per-op table: (accumulator initializer, fold expression, NumPy reference).
    # The two strings are kernel source fragments; they refer to the kernel's
    # own names `local`, `x` and `BLOCK_M`.
    variants = {
        'sum': (
            'tl.zeros([BLOCK_M], dtype=tl.float32)',
            'local + tl.sum(x, axis=1)',
            np.sum,
        ),
        'max': (
            'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)',
            'tl.maximum(local, tl.max(x, axis=1))',
            np.max,
        ),
        'min': (
            'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)',
            'tl.minimum(local, tl.min(x, axis=1))',
            np.min,
        ),
    }
    init_src, fold_src, numpy_op = variants[op]

    # NOTE(review): patch_kernel is defined outside this block; presumably it
    # substitutes the placeholder names in the kernel source -- confirm.
    substitutions = {
        'ACCUMULATE_PATCH': fold_src,
        'INITIALIZE_PATCH': init_src,
    }
    kernel = patch_kernel(kernel, substitutions)

    torch.manual_seed(0)
    BLOCK_M = 32
    x = torch.randn((BLOCK_M, N), dtype=torch.float32, device="cuda")
    # Output starts as random data; the kernel stores to every
    # (row, pid_n) slot, so all of it is overwritten.
    y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device="cuda")

    grid = (1, num_pid_n, 1)
    h = kernel[grid](x, y, N, BLOCK_M, BLOCK_N)
    reduce_count = h.asm['ttgir'].count('"tt.reduce"')
    assert reduce_count == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"

    # Collapse the per-program partials on the host and compare against the
    # same reduction applied to the whole input row.
    expected = numpy_op(x.cpu().numpy(), axis=1, keepdims=True)
    actual = numpy_op(y.cpu().numpy(), axis=1, keepdims=True)
    np.testing.assert_allclose(actual, expected, rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
|
||||
@pytest.mark.parametrize("src_layout", scan_layouts)
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
|
||||
Reference in New Issue
Block a user