Add support for MFMA layout to view_slice op (#442)

Co-authored-by: Ognjen <oplavsic@luxoft.com>
This commit is contained in:
oplavsic
2024-01-03 19:13:36 +01:00
committed by GitHub
parent 6a520566a3
commit bcea3051af
2 changed files with 51 additions and 44 deletions

View File

@@ -2933,50 +2933,58 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
assert torch.equal(z, x)
layouts = [
view_layouts = [
BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False),
]
@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [256, 256, 128, 64, 64, 128], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
blocked_layouts = [
BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
]
@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", [[256, 128, 256, 32, 0, 0], [128, 128, 128, 32, 0, 0], [128, 128, 128, 32, 0, 64]])
@pytest.mark.parametrize("dtype", ['float16'])
@pytest.mark.parametrize("src_layout", layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, src_layout, device='cuda'):
@pytest.mark.parametrize("view_layout", view_layouts)
@pytest.mark.parametrize("blocked_layout", blocked_layouts)
def test_view_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, view_layout, device='cuda'):
if torch.version.hip is None:
pytest.skip("view_slice is AMD specific instruction.")
ir = f"""
#src = {src_layout}
#blocked = {blocked_layout}
#view_layout = {view_layout}
""" + """
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + f""" : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M_tile_size}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{M}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{M}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>
%37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N_tile_size}xi32, #src>
%38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #src>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%12 = triton_gpu.view_slice %11[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #src> to tensor<{M_tile_size}x{N_tile_size}xf16, #src>
%13 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #src>, tensor<{M_tile_size}x{N_tile_size}xi32, #src>
tt.store %13, %12 : tensor<{M_tile_size}x{N_tile_size}xf16, #src>
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #blocked>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%43 = tt.expand_dims %42 {{axis = 1 : i32}} : (tensor<{M_tile_size}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M_tile_size}x1xi32, #blocked>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked>
%44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{M}xi32, #blocked>
%7 = tt.broadcast %6 : (tensor<1x{M}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked>
%33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%34 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #blocked>
%37 = tt.expand_dims %33 {{axis = 0 : i32}} : (tensor<{N_tile_size}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N_tile_size}xi32, #blocked>
%38 = tt.broadcast %37 : (tensor<1x{N_tile_size}xi32, #blocked>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%39 = tt.broadcast %44 : (tensor<{M_tile_size}x1xi32, #blocked>) -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #blocked>
%12 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #blocked>) -> tensor<{M}x{N}xf16, #view_layout>
%13 = triton_gpu.view_slice %12[{M_tile_offset}, {N_tile_offset}] [{M_tile_size}, {N_tile_size}] [1, 1] : tensor<{M}x{N}xf16, #view_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout>
%14 = triton_gpu.convert_layout %13 : (tensor<{M_tile_size}x{N_tile_size}xf16, #view_layout>) -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
%15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr<f16>, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked>
tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}xf16, #blocked>
tt.return
}}
}}
@@ -2998,6 +3006,7 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
kernel[(1, 1, 1)](x.data_ptr(), z_tri)
np.testing.assert_equal(z_numpy, to_numpy(z_tri))
if torch.version.hip is not None and _get_warp_size() == 64:
layouts = [
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),