Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Accommodate new triton IR format (#2294)
- Support memory space for pointers (e.g., `!tt.ptr<f32, 1>`).
- Support parsing function attributes, though they are not used yet.
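Before the diff, a hand-written contrast of the two formats (the function name and the `noinline` attribute dictionary are invented for illustration; per the comment this patch adds to `convert_type_repr`, memory space 1 is assumed to be global memory):

```python
# Old format: a pointer spells only its pointee type.
OLD = "tt.func public @axpy(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {"
# New format: pointee type plus a memory-space integer, and an optional
# "attributes {...}" dictionary on the function (now parsed, not yet used).
NEW = "tt.func public @axpy(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {"
```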
@@ -1779,28 +1779,28 @@ def test_scan_layouts(M, N, src_layout, axis, device):
 ir = f"""
 #blocked = {src_layout}
 module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
-tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
+tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
 %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
 %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
 %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
-%3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
+%3 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>
+%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
 %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
 %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
-%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
+%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>
 %8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
-%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
 %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
 %11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
 ^bb0(%arg2: i32, %arg3: i32):
 %16 = arith.addi %arg2, %arg3 : i32
 tt.scan.return %16 : i32
 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
-%12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
-%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
-%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
-%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+%12 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>
+%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
+%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>
+%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
 tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
 tt.return
 }}
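For orientation, the TTGIR above computes a cumulative sum of an M x N tile along `axis`. A rough Triton-Python equivalent is sketched below (kernel and names invented; the test writes TTGIR directly precisely so it can pin the `#blocked` layout, which the Python path leaves to the compiler):

```python
import triton
import triton.language as tl

@triton.jit
def cumsum_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr, AXIS: tl.constexpr):
    # Offsets for one M x N tile, mirroring the make_range/expand_dims/broadcast
    # sequence in the IR above.
    offs = tl.arange(0, M)[:, None] * N + tl.arange(0, N)[None, :]
    x = tl.load(x_ptr + offs)
    z = tl.cumsum(x, axis=AXIS)  # lowers to "tt.scan" with an arith.addi combine region
    tl.store(z_ptr + offs, z)
```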
@@ -1871,7 +1871,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
 }}
 }}
 """ if reduce2d else f"""
-%14 = tt.splat %arg2 : (!tt.ptr<{ty}>) -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
+%14 = tt.splat %arg2 : (!tt.ptr<{ty}, 1>) -> tensor<{rdims_2d}x!tt.ptr<{ty}, 1>, #blocked>
 %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
 %16 = {GPU_DIALECT}.convert_layout %13 : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>
 %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}x{ty}, #blocked>
@@ -1885,18 +1885,18 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
 #blocked = {blocked}
 #src = {src_layout}
 module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
-tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
+tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{
 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
 %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
 %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
 %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
-%4 = tt.splat %arg0 : (!tt.ptr<{ty}>) -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
-%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
+%4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>
+%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
 %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
 %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
-%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
+%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>
 %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
-%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
+%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
 %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x{ty}, #blocked>
 %12 = {GPU_DIALECT}.convert_layout %11 : (tensor<{M}x{N}x{ty}, #blocked>) -> tensor<{M}x{N}x{ty}, #src>
 %13 = "tt.reduce"(%12) ({{
@@ -1945,16 +1945,16 @@ def test_store_op(M, src_layout, device):
 ir = f"""
 #src = {src_layout}
 module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
+tt.func public @kernel(%arg0: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}) {{
 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
-%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
-%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
+%1 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
+%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
 %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
 %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
 %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
 %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
-%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
-%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
+%7 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<{M}x1x!tt.ptr<f32, 1>, #src>
+%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32, 1>, #src>, tensor<{M}x1xi32, #src>
 tt.store %8, %4 : tensor<{M}x1xf32, #src>
 tt.return
 }}
@@ -1998,14 +1998,14 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
 #dst = {dst_layout}
 #src = {src_layout}
 module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
-%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
+tt.func public @kernel(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
+%0 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
 %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
-%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
+%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
 %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
-%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
+%4 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
 %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
-%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
+%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
 %7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
 tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
 tt.return
@@ -2069,7 +2069,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
 ir = f"""
 #src = {src_layout}
 module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
-tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
+tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
 %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
 %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
@@ -2079,8 +2079,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
 %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
 %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
 %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
-%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
-%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
+%8 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #src>
+%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #src>, tensor<{M}x{N}xi32, #src>
 %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
 %11 = "tt.reduce"(%10) ({{
 ^bb0(%arg2: i32, %arg3: i32):
@@ -3664,22 +3664,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
 
 ir = layouts + f"""
 module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
-tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
+tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}) {{
 %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
 %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
-%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
+%2 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>
 %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
 %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
 %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
 %7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
 %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
 %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
-%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
+%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>, tensor<{M}x{N}xi32, #src>
 %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
-%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
+%3 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>
 """ + conversion + f"""
-%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
+%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>, tensor<{M}x{N}xi32, #dst>
 tt.store %14, %13 : tensor<{M}x{N}xf16, #dst>
 tt.return
 }}
@@ -330,7 +330,7 @@ def test_compile_link_autotune_matmul():
 def test_ttgir_to_ptx():
     src = """
 module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {
-    tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
+    tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1>, %arg1: !tt.ptr<i32, 1>) {
     tt.return
     }
 }
@@ -204,7 +204,9 @@ def get_kernel_name(src: str, pattern: str) -> str:
 
 
 def convert_type_repr(x):
-    match = re.search(r'!tt\.ptr<(.*)>', x)
+    # Currently we only capture the pointer type and assume the pointer is on global memory.
+    # TODO: Capture and support shared memory space
+    match = re.search(r'!tt\.ptr<([^,]+)', x)
     if match is not None:
         return '*' + convert_type_repr(match.group(1))
     return x
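A standalone check of the updated capture (the function is reproduced from the hunk above so the snippet runs on its own):

```python
import re

def convert_type_repr(x):
    # As updated above: capture up to the first comma, dropping the memory space.
    match = re.search(r'!tt\.ptr<([^,]+)', x)
    if match is not None:
        return '*' + convert_type_repr(match.group(1))
    return x

print(convert_type_repr('!tt.ptr<f32, 1>'))  # -> *f32 (space dropped; assumed global)
print(convert_type_repr('i32'))              # -> i32 (non-pointer types pass through)
```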
@@ -241,7 +243,8 @@ def make_hash(fn, arch, env_vars, **kwargs):
 #   (letters, digits, or underscores), and capture it as group 1 (the function name)
 # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
 #   zero or more arguments separated by commas, and capture it as group 2 (the argument list)
-mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
+# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
+mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
 ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
 prototype_pattern = {
     "ttir": mlir_prototype_pattern,
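As a sanity check, the new pattern run on a hand-written prototype in the new format (the kernel name and `noinline` attribute are invented; group 3 is the newly captured, not-yet-used attribute dictionary):

```python
import re

mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"

line = "tt.func public @add_kernel(%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32}) attributes {noinline = false} {"
m = re.search(mlir_prototype_pattern, line, re.MULTILINE)
print(m.group(1))  # @add_kernel
print(m.group(2))  # (%arg0: !tt.ptr<f32, 1> {tt.divisibility = 16 : i32})
print(m.group(3))  # attributes {noinline = false}
```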
@@ -249,7 +252,11 @@ prototype_pattern = {
     "ptx": ptx_prototype_pattern,
 }
 
-mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
+# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either:
+#   [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol.
+#   |: OR
+#   <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between.
+mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'
 ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
 arg_type_pattern = {
     "ttir": mlir_arg_type_pattern,
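The reason the argument-type pattern had to change: on the new spelling, the old character class stops at the comma inside the angle brackets and truncates the type, whereas the new alternation consumes a whole `<...>` group as one unit. A standalone comparison (argument text invented for illustration):

```python
import re

old_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
new_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'

args = "%arg0: !tt.ptr<i32, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32, 1>"
print(re.findall(old_pattern, args))  # ['!tt.ptr<i32', '!tt.ptr<f32'] -- truncated at the comma
print(re.findall(new_pattern, args))  # ['!tt.ptr<i32, 1>', '!tt.ptr<f32, 1>'] -- whole type captured
```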
@@ -422,6 +429,7 @@ def compile(fn, **kwargs):
         src = Path(fn).read_text()
         import re
         match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
+        # TODO: support function attributes at group 3 (e.g., device function)
         name, signature = match.group(1), match.group(2)
         types = re.findall(arg_type_pattern[ir_name], signature)
         if ir_name == 'ttgir':