[FRONTEND] Accommodate new triton IR format (#2294)

- Support memory space for pointers (e.g., `!tt.ptr<f32, 1>`).
- Support parsing function attribute, though not used yet.
This commit is contained in:
Keren Zhou
2023-09-14 12:03:23 -04:00
committed by GitHub
parent 36087a108f
commit 08c1658957
3 changed files with 45 additions and 37 deletions

View File

@@ -1779,28 +1779,28 @@ def test_scan_layouts(M, N, src_layout, axis, device):
ir = f"""
#blocked = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked>
%3 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%3 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>
%4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
%5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>
%6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>
%8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked>
%11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{
^bb0(%arg2: i32, %arg3: i32):
%16 = arith.addi %arg2, %arg3 : i32
tt.scan.return %16 : i32
}}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%12 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x1x!tt.ptr<i32>, #blocked>
%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32>, #blocked>, tensor<{M}x1xi32, #blocked>
%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32>, #blocked>
%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%12 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>
%13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
%14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr<i32, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>
%15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked>
tt.return
}}
@@ -1871,7 +1871,7 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
}}
}}
""" if reduce2d else f"""
%14 = tt.splat %arg2 : (!tt.ptr<{ty}>) -> tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>
%14 = tt.splat %arg2 : (!tt.ptr<{ty}, 1>) -> tensor<{rdims_2d}x!tt.ptr<{ty}, 1>, #blocked>
%15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked>
%16 = {GPU_DIALECT}.convert_layout %13 : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>
%17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}x{ty}, #blocked>
@@ -1885,18 +1885,18 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op,
#blocked = {blocked}
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
%2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked>
%3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<{ty}>) -> tensor<{M}x1x!tt.ptr<{ty}>, #blocked>
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}>, #blocked>, tensor<{M}x1xi32, #blocked>
%4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>
%5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x1xi32, #blocked>
%6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>
%7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>
%8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>
%9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked>
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x{ty}, #blocked>
%12 = {GPU_DIALECT}.convert_layout %11 : (tensor<{M}x{N}x{ty}, #blocked>) -> tensor<{M}x{N}x{ty}, #src>
%13 = "tt.reduce"(%12) ({{
@@ -1945,16 +1945,16 @@ def test_store_op(M, src_layout, device):
ir = f"""
#src = {src_layout}
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel(%arg0: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32, 1> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32, 1>) -> tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32, 1>, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32, 1>) -> tensor<{M}x1x!tt.ptr<f32, 1>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32, 1>, #src>, tensor<{M}x1xi32, #src>
tt.store %8, %4 : tensor<{M}x1xf32, #src>
tt.return
}}
@@ -1998,14 +1998,14 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device):
#dst = {dst_layout}
#src = {src_layout}
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
tt.func public @kernel(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32, 1>) -> tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32, 1>, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
%7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.return
@@ -2069,7 +2069,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
ir = f"""
#src = {src_layout}
module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32, 1> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
@@ -2079,8 +2079,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32, 1>) -> tensor<{M}x{N}x!tt.ptr<i32, 1>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32, 1>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
@@ -3664,22 +3664,22 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device):
ir = layouts + f"""
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16> {{tt.divisibility = 16 : i32}}) {{
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f16, 1> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%2 = tt.splat %arg0 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
%2 = tt.splat %arg0 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>
%4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
%6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #src>, tensor<{M}x{N}xi32, #src>
%11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src>
%3 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
%3 = tt.splat %arg1 : (!tt.ptr<f16, 1>) -> tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>
""" + conversion + f"""
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
%14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16, 1>, #dst>, tensor<{M}x{N}xi32, #dst>
tt.store %14, %13 : tensor<{M}x{N}xf16, #dst>
tt.return
}}

View File

@@ -330,7 +330,7 @@ def test_compile_link_autotune_matmul():
def test_ttgir_to_ptx():
src = """
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} {
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32, 1>, %arg1: !tt.ptr<i32, 1>) {
tt.return
}
}

View File

@@ -204,7 +204,9 @@ def get_kernel_name(src: str, pattern: str) -> str:
def convert_type_repr(x):
match = re.search(r'!tt\.ptr<(.*)>', x)
# Currently we only capture the pointer type and assume the pointer is on global memory.
# TODO: Capture and support shared memory space
match = re.search(r'!tt\.ptr<([^,]+)', x)
if match is not None:
return '*' + convert_type_repr(match.group(1))
return x
@@ -241,7 +243,8 @@ def make_hash(fn, arch, env_vars, **kwargs):
# (letters, digits, or underscores), and capture it as group 1 (the function name)
# - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
# zero or more arguments separated by commas, and capture it as group 2 (the argument list)
mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$'
# - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$"
ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
prototype_pattern = {
"ttir": mlir_prototype_pattern,
@@ -249,7 +252,11 @@ prototype_pattern = {
"ptx": ptx_prototype_pattern,
}
mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?'
# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either:
# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol.
# |: OR
# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between.
mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?'
ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
arg_type_pattern = {
"ttir": mlir_arg_type_pattern,
@@ -422,6 +429,7 @@ def compile(fn, **kwargs):
src = Path(fn).read_text()
import re
match = re.search(prototype_pattern[ir_name], src, re.MULTILINE)
# TODO: support function attributes at group 3 (e.g., device function)
name, signature = match.group(1), match.group(2)
types = re.findall(arg_type_pattern[ir_name], signature)
if ir_name == 'ttgir':