[OPTIMIZER] Thread local reduction optimization (#2542)

Co-authored-by: Phil Tillet <phil@openai.com>
This commit is contained in:
Zahi Moudallal
2023-10-31 16:13:36 -07:00
committed by GitHub
parent 258399c114
commit 3650213218
12 changed files with 986 additions and 31 deletions

View File

@@ -1815,6 +1815,51 @@ scan_layouts = [
]
@pytest.mark.parametrize("op", ['sum', 'max', 'min'])
@pytest.mark.parametrize("BLOCK_N", [32, 64, 128])
@pytest.mark.parametrize("N", [512, 1024, 2048])
@pytest.mark.parametrize("num_pid_n", [2, 4])
def test_locality(op, BLOCK_N, N, num_pid_n):
    """Check the compiled IR and the result of a looped row-wise reduction.

    Each program along grid axis 1 walks the column blocks of ``X`` with a
    stride of ``num_pid_n``, folds every loaded tile into a per-row
    accumulator and stores that accumulator into its own column of ``Y``.
    The test then asserts that the generated TTGIR contains exactly two
    ``"tt.reduce"`` occurrences and that reducing the per-program partial
    results with NumPy matches reducing the full input with NumPy.

    The kernel source holds two placeholder identifiers that are textually
    substituted through ``patch_kernel`` before launch, so the names
    ``local``, ``x`` and ``BLOCK_M`` inside the kernel must stay in sync with
    the replacement snippets below.
    """

    @triton.jit
    def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
        pid_m = tl.program_id(0)
        pid_n = tl.program_id(1)
        num_pid_n = tl.num_programs(1)
        rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        local = INITIALIZE_PATCH
        # Column blocks are distributed round-robin over the programs of axis 1.
        for block_idx in range(pid_n, tl.cdiv(N, BLOCK_N), num_pid_n):
            cols = block_idx * BLOCK_N + tl.arange(0, BLOCK_N)
            x = tl.load(X + rows[:, None] * N + cols[None, :])
            local = ACCUMULATE_PATCH
        # One partial result per row, written to this program's column of Y.
        tl.store(Y + rows * num_pid_n + pid_n, local)

    # op -> (accumulator initializer source, accumulate-step source, NumPy reducer)
    variants = {
        'sum': (
            'tl.zeros([BLOCK_M], dtype=tl.float32)',
            'local + tl.sum(x, axis=1)',
            np.sum,
        ),
        'max': (
            'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)',
            'tl.maximum(local, tl.max(x, axis=1))',
            np.max,
        ),
        'min': (
            'tl.full([BLOCK_M], float("inf"), dtype=tl.float32)',
            'tl.minimum(local, tl.min(x, axis=1))',
            np.min,
        ),
    }
    init_src, accumulate_src, reference_reduce = variants[op]

    substitutions = {'ACCUMULATE_PATCH': accumulate_src, 'INITIALIZE_PATCH': init_src}
    kernel = patch_kernel(kernel, substitutions)

    torch.manual_seed(0)
    BLOCK_M = 32
    inp = torch.randn((BLOCK_M, N), dtype=torch.float32, device="cuda")
    partials = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device="cuda")

    grid = (1, num_pid_n, 1)
    compiled = kernel[grid](inp, partials, N, BLOCK_M, BLOCK_N)

    reduce_count = compiled.asm['ttgir'].count('"tt.reduce"')
    assert reduce_count == 2, "tt.reduce should be called twice, otherwise the optimization didn't work"

    # Reducing the per-program partials must agree with reducing the whole row.
    expected = reference_reduce(inp.cpu().numpy(), axis=1, keepdims=True)
    actual = reference_reduce(partials.cpu().numpy(), axis=1, keepdims=True)
    np.testing.assert_allclose(actual, expected, rtol=0.01, atol=1e-3)
@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
@pytest.mark.parametrize("src_layout", scan_layouts)
@pytest.mark.parametrize("axis", [0, 1])