Fix the issue when CTA coverage is larger than the tile

This commit is contained in:
Lixun Zhang
2023-09-08 12:38:49 -05:00
committed by Lixun Zhang
parent ed20089bc8
commit ea397b49aa
3 changed files with 134 additions and 7 deletions

View File

@@ -1188,6 +1188,62 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
else:
np.testing.assert_equal(z_ref, z_tri)
# ---------------
# test scan
# ---------------
scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)]
scan_configs = [
(op, type, shape, axis, num_warps)
for num_warps in [4, 16]
for type in ['int32', 'float32']
for axis in [1, 0]
for shape in scan2d_shapes
for op in ['cumsum', 'cumprod']
]
@pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs)
def test_scan2d(op, dtype_str, shape, axis, num_warps, device):
check_type_supported(dtype_str, device)
# triton kernel
@triton.jit
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
range_m = tl.arange(0, BLOCK_M)
range_n = tl.arange(0, BLOCK_N)
x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :])
z = GENERATE_TEST_HERE
tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z)
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'})
# input
rs = RandomState(17)
x = numpy_random(shape, dtype_str=dtype_str, rs=rs)
z = np.empty_like(x)
x_tri = to_triton(x, device=device)
numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op]
z_dtype_str = dtype_str
z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str))
# triton result
z_tri = to_triton(z, device=device)
kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps)
z_tri = to_numpy(z_tri)
# compare
if dtype_str == 'float32':
if op == 'cumprod':
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01, atol=1e-3)
else:
np.testing.assert_allclose(z_ref, z_tri, rtol=0.01)
else:
np.testing.assert_equal(z_ref, z_tri)
# ---------------
# test permute
# ---------------
@@ -2555,6 +2611,75 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
scan_layouts = [
BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1]),
BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1]),
BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1]),
BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1]),
BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1]),
BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0]),
BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0]),
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0]),
]
@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]])
@pytest.mark.parametrize("src_layout", scan_layouts)
@pytest.mark.parametrize("axis", [0, 1])
def test_scan_layouts(M, N, src_layout, axis, device):
ir = f"""
#blocked = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
%3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
%11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
^bb0(%arg2: i32, %arg3: i32):
%16 = arith.addi %arg2, %arg3 : i32
tt.scan.return %16 : i32
}}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(-100, 100, (M, N)).astype('int32')
z = np.zeros((M, N)).astype('int32')
x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)
kernel[(1, 1, 1)](x_tri, z_tri)
z_ref = np.cumsum(x, axis=axis)
np.testing.assert_equal(z_ref, z_tri.cpu().numpy())
@pytest.mark.parametrize("shape", [(64, 64)])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)])