mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Fix layout formatting
This commit is contained in:
committed by
Jason Furmanek
parent
336c4b5f3c
commit
8ccc4b0cce
@@ -13,6 +13,7 @@ from numpy.random import RandomState
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
import triton.language as tl
|
||||
from triton.common.build import is_hip
|
||||
from triton.runtime.jit import JITFunction, TensorWrapper, reinterpret
|
||||
|
||||
int_dtypes = ['int8', 'int16', 'int32', 'int64']
|
||||
@@ -22,6 +23,13 @@ dtypes = int_dtypes + uint_dtypes + float_dtypes
|
||||
dtypes_with_bfloat16 = dtypes + ['bfloat16']
|
||||
torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16']
|
||||
|
||||
if is_hip():
|
||||
GPU_DIALECT = "triton_gpu"
|
||||
THREADS_PER_WARP = 64
|
||||
else:
|
||||
GPU_DIALECT = "triton_gpu"
|
||||
THREADS_PER_WARP = 32
|
||||
|
||||
|
||||
def _bitwidth(dtype: str) -> int:
|
||||
# ex.: "int64" -> 64
|
||||
@@ -2429,46 +2437,58 @@ def test_while():
|
||||
# -----------------------
|
||||
# TODO: backend should be tested separately
|
||||
|
||||
|
||||
class MmaLayout:
|
||||
def __init__(self, version, warps_per_cta):
|
||||
def __init__(self, version, warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape):
|
||||
self.version = version
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.ctas_per_cga = str(ctas_per_cga)
|
||||
self.cta_split_num = str(cta_split_num)
|
||||
self.cta_order = str(cta_order)
|
||||
self.instr_shape = str(instr_shape)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}}}>"
|
||||
return f"#{GPU_DIALECT}.mma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA={self.warps_per_cta}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}, instrShape={self.instr_shape}}}>"
|
||||
|
||||
|
||||
class MfmaLayout:
|
||||
def __init__(self, non_k_dim, warps_per_cta, isTransposed):
|
||||
def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order):
|
||||
self.non_k_dim = str(non_k_dim)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.isTransposed = str(isTransposed).lower()
|
||||
self.is_transposed = str(is_transposed).lower()
|
||||
self.ctas_per_cga = str(ctas_per_cga)
|
||||
self.cta_split_num = str(cta_split_num)
|
||||
self.cta_order = str(cta_order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.isTransposed}}}>"
|
||||
return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
|
||||
|
||||
|
||||
class BlockedLayout:
|
||||
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order):
|
||||
def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order):
|
||||
self.sz_per_thread = str(size_per_thread)
|
||||
self.threads_per_warp = str(threads_per_warp)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.order = str(order)
|
||||
self.ctas_per_cga = str(ctas_per_cga)
|
||||
self.cta_split_num = str(cta_split_num)
|
||||
self.cta_order = str(cta_order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}}}>"
|
||||
return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
|
||||
|
||||
|
||||
class SharedLayout:
|
||||
def __init__(self, vec, per_phase, max_phase, order):
|
||||
def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order):
|
||||
self.vec = str(vec)
|
||||
self.per_phase = str(per_phase)
|
||||
self.max_phase = str(max_phase)
|
||||
self.order = str(order)
|
||||
self.ctas_per_cga = str(ctas_per_cga)
|
||||
self.cta_split_num = str(cta_split_num)
|
||||
self.cta_order = str(cta_order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#triton_gpu.shared<{{vec = {self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}}}>"
|
||||
return f"#{GPU_DIALECT}.shared<{{vec={self.vec}, perPhase={self.per_phase}, maxPhase={self.max_phase}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("vec_size", [2, 4])
|
||||
@@ -2501,17 +2521,20 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
|
||||
if vec_size != 4:
|
||||
pytest.skip()
|
||||
|
||||
shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0])
|
||||
shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0])
|
||||
blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
|
||||
ir = f"""
|
||||
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}}>
|
||||
#blocked = {blocked}
|
||||
#shared1 = {shared_a}
|
||||
#shared2 = {shared_b}
|
||||
#mfma = {mfma}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
|
||||
tt.func public @kernel_0d1d2d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mfma>
|
||||
%cst_0 = arith.constant dense<32> : tensor<32x1xi32, #blocked>
|
||||
%0 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked>
|
||||
@@ -2533,12 +2556,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
|
||||
%17 = tt.addptr %16, %8 : tensor<32x32x!tt.ptr<f16>, #blocked>, tensor<32x32xi32, #blocked>
|
||||
%18 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
|
||||
%19 = triton_gpu.convert_layout %18 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared1>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
|
||||
%20 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #shared1>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>>
|
||||
%21 = tt.load %13 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf16, #blocked>
|
||||
%22 = triton_gpu.convert_layout %21 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #shared2>
|
||||
%23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>>
|
||||
%24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #triton_gpu.mfma<{nonKDim = 32, kDim = 8, warpsPerCTA = [4, 1], isTransposed = false}>, kWidth=4}>> -> tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>
|
||||
%25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>>) -> tensor<32x32xf32, #blocked>
|
||||
%23 = triton_gpu.convert_layout %22 : (tensor<32x32xf16, #shared2>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=4}>>
|
||||
%24 = tt.dot %20, %23, %cst {allowTF32 = false} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth=8}>> -> tensor<32x32xf32, #mfma>
|
||||
%25 = triton_gpu.convert_layout %24 : (tensor<32x32xf32, #mfma>) -> tensor<32x32xf32, #blocked>
|
||||
%26 = arith.truncf %25 : tensor<32x32xf32, #blocked> to tensor<32x32xf16, #blocked>
|
||||
tt.store %17, %26 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf16, #blocked>
|
||||
tt.return
|
||||
@@ -2586,28 +2609,28 @@ if _get_warp_size() == 64:
|
||||
# MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
# MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 2], [2, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0]),
|
||||
BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0])
|
||||
BlockedLayout([1, 2], [2, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 1], [1, 64], [2, 2], [1, 0],[1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [16, 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [8, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 1], [32, 2], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [1, 64], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1])
|
||||
]
|
||||
else:
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]),
|
||||
# MmaLayout(version=1, warps_per_cta=[4, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1]),
|
||||
BlockedLayout([1, 2], [2, 16], [2, 2], [1, 0]),
|
||||
BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([4, 2], [16, 2], [1, 4], [0, 1]),
|
||||
BlockedLayout([4, 2], [8, 4], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 2], [4, 8], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 1], [16, 2], [2, 2], [0, 1]),
|
||||
BlockedLayout([4, 2], [1, 32], [4, 1], [1, 0])
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1]),
|
||||
BlockedLayout([1, 2], [2, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [16, 2], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [4, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 1], [16, 2], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 2], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1])
|
||||
]
|
||||
|
||||
|
||||
@@ -2625,7 +2648,7 @@ def test_convert2d(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<128> : tensor<128x1xi32, #src>
|
||||
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
@@ -2665,15 +2688,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
|
||||
|
||||
if torch.version.hip is not None and _get_warp_size() == 64:
|
||||
layouts = [
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], isTransposed=False),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
]
|
||||
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
|
||||
else:
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
|
||||
MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]),
|
||||
]
|
||||
shapes = [[128, 16], [128, 128], [32, 128]]
|
||||
|
||||
@@ -2683,15 +2709,16 @@ else:
|
||||
@pytest.mark.parametrize("axis", [0, 1])
|
||||
def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
if torch.version.hip is not None and _get_warp_size() == 64:
|
||||
if src_layout.isTransposed and axis == 0:
|
||||
if src_layout.is_transposed and axis == 0:
|
||||
pytest.skip("Reduce along axis 0 is not supported in transposed mfma layout")
|
||||
rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1"
|
||||
rdims_1d = f"{N}" if axis == 0 else f"{M}"
|
||||
store_range = "%7" if axis == 0 else "%1"
|
||||
blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1])
|
||||
ir = f"""
|
||||
#blocked = #triton_gpu.blocked<{{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1]}}>
|
||||
#blocked = {blocked}
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {_get_warp_size()} : i32}} {{
|
||||
module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = {_get_warp_size()} : i32}} {{
|
||||
tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked>
|
||||
@@ -2747,17 +2774,17 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
|
||||
|
||||
scan_layouts = [
|
||||
BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1]),
|
||||
BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1]),
|
||||
BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1]),
|
||||
BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1]),
|
||||
BlockedLayout([1, 4], [4, 16], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 4], [8, 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 1], [4, 16], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [8, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
|
||||
|
||||
BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0]),
|
||||
BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0]),
|
||||
BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 4], [4, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([1, 4], [8, 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([4, 1], [4, 16], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [4, 16], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
BlockedLayout([2, 2], [8, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
|
||||
]
|
||||
|
||||
|
||||
@@ -2767,7 +2794,7 @@ scan_layouts = [
|
||||
def test_scan_layouts(M, N, src_layout, axis, device):
|
||||
ir = f"""
|
||||
#blocked = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
|
||||
module attributes {{"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 64 : i32}} {{
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked>
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>>
|
||||
@@ -2817,14 +2844,15 @@ def test_scan_layouts(M, N, src_layout, axis, device):
|
||||
|
||||
@pytest.mark.parametrize("shape", [(64, 64)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], isTransposed=False), MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], isTransposed=True)])
|
||||
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0])])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])])
|
||||
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
|
||||
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
#dst = {dst_layout}
|
||||
""" + """
|
||||
module attributes {"triton_gpu.num-warps" = """ + str(128 // _get_warp_size()) + """ : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} {
|
||||
module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = """ + str(128 // _get_warp_size()) + """ : i32, "triton_gpu.threads-per-warp" = """ + str(_get_warp_size()) + """ : i32} {
|
||||
tt.func public @kernel_0d1d(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
|
||||
%cst = arith.constant dense<64> : tensor<64x1xi32, #src>
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>
|
||||
|
||||
Reference in New Issue
Block a user