[TEST] Added convert layout test from/to sliced blocked/mma (#1620)

This commit is contained in:
Zahi Moudallal
2023-05-05 17:20:52 -07:00
committed by GitHub
parent 9d7980fa3b
commit 125d9d1cc7

View File

@@ -1543,7 +1543,53 @@ def test_store_op(M, src_layout, device='cuda'):
pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]
@pytest.mark.parametrize("M", [64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
@pytest.mark.parametrize("src_dim", [0, 1])
@pytest.mark.parametrize("dst_dim", [0, 1])
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device='cuda'):
ir = f"""
#dst = {dst_layout}
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, )).astype('int32')
y = np.zeros((M, ), dtype='int32')
x_tri = torch.tensor(x, device=device)
y_tri = torch.tensor(y, device=device)
pgm = kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)