[Backend] Refactor mfma selection (#441)
* Select mfma dimensions and instruction from static table
* Extend mfmaLayout to include version and instrShape
* Simplify generateMFMAOp by searching the mfma instruction in the table
* Fix getNonKDim() and non_k_dim
* Break instrShape into MDim and NDim
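The first and last bullets are the heart of the refactor: instead of deriving the MFMA instruction from nonKDim alone, the backend now looks it up in a static table keyed by version and separate M/N tile dimensions. Below is a minimal Python sketch of the idea, assuming a key of (versionMajor, MDim, NDim, element type); the table name, helper name, and fp16 entries are illustrative assumptions, not the backend's actual C++ API (the bf8 mnemonics do appear in the gcn assertions further down):

    # Illustrative sketch only; the real table lives in the C++ backend.
    # Key: (versionMajor, MDim, NDim, element type) -> instruction mnemonic.
    MFMA_TABLE = {
        (2, 32, 32, "fp16"): "v_mfma_f32_32x32x8f16",
        (2, 16, 16, "fp16"): "v_mfma_f32_16x16x16f16",
        (3, 32, 32, "bf8"): "v_mfma_f32_32x32x16_bf8_bf8",
        (3, 16, 16, "bf8"): "v_mfma_f32_16x16x32_bf8_bf8",
    }

    def select_mfma(version_major, m_dim, n_dim, elem_ty):
        # generateMFMAOp-style selection: search the table instead of
        # recomputing dimensions from nonKDim at each call site.
        key = (version_major, m_dim, n_dim, elem_ty)
        if key not in MFMA_TABLE:
            raise KeyError(f"no MFMA instruction for {key}")
        return MFMA_TABLE[key]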
@@ -1914,11 +1914,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
     if backend.get_matrix_core_version() > 0:
         ttgir = pgm.asm['ttgir']
         if non_k_dim == 16:
-            assert "#triton_gpu.mfma<{nonKDim = 16" in ttgir
-            assert "#triton_gpu.mfma<{nonKDim = 32" not in ttgir
+            assert "instrShape = [16, 16]" in ttgir
+            assert "instrShape = [32, 32]" not in ttgir
         elif non_k_dim == 32:
-            assert "#triton_gpu.mfma<{nonKDim = 32" in ttgir
-            assert "#triton_gpu.mfma<{nonKDim = 16" not in ttgir
+            assert "instrShape = [32, 32]" in ttgir
+            assert "instrShape = [16, 16]" not in ttgir
         gcn = pgm.asm['amdgcn']
         if backend.get_matrix_core_version() == 3 and effective_in_dtype == tl.float8e5b16:
             assert "v_mfma_f32_32x32x16_bf8_bf8" in gcn or "v_mfma_f32_16x16x32_bf8_bf8" in gcn
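Aside on the assertion change above: the mfma attribute no longer prints nonKDim, so the tests match on instrShape instead. For the square MFMA tiles exercised here, a non-K dimension of N corresponds to instrShape = [N, N]; a tiny sketch of that mapping (helper name hypothetical):

    def expected_mfma_marker(non_k_dim: int) -> str:
        # Square MFMA tile: non-K dim N renders as an N x N instrShape.
        return f"instrShape = [{non_k_dim}, {non_k_dim}]"

    assert expected_mfma_marker(16) == "instrShape = [16, 16]"
    assert expected_mfma_marker(32) == "instrShape = [32, 32]"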
@@ -2709,13 +2709,14 @@ class MmaLayout:


 class MfmaLayout:
-    def __init__(self, non_k_dim, warps_per_cta, is_transposed):
-        self.non_k_dim = str(non_k_dim)
+    def __init__(self, version, warps_per_cta, instr_shape, is_transposed):
+        self.version = version
         self.warps_per_cta = str(warps_per_cta)
+        self.instr_shape = str(instr_shape)
         self.is_transposed = str(is_transposed).lower()

     def __str__(self):
-        return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>"
+        return f"#{GPU_DIALECT}.mfma<{{versionMajor={self.version[0]}, versionMinor={self.version[1]}, warpsPerCTA = {self.warps_per_cta}, instrShape={self.instr_shape}, isTransposed = {self.is_transposed}}}>"


 class BlockedLayout:
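For reference, the updated constructor and the attribute string its __str__ renders, derived directly from the f-string above (GPU_DIALECT is "triton_gpu", per the ttgir assertions in the first hunk):

    layout = MfmaLayout(version=(2, 0), warps_per_cta=[4, 1], instr_shape=[32, 32], is_transposed=False)
    print(layout)
    # #triton_gpu.mfma<{versionMajor=2, versionMinor=0, warpsPerCTA = [4, 1], instrShape=[32, 32], isTransposed = false}>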
@@ -2782,7 +2783,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
     blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
-    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)
+    mfma = MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=False)

     ir = f"""
 #blocked = {blocked}
@@ -2947,8 +2948,8 @@ view_layouts = [
     BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
     BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
-    MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False),
+    MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=True),
+    MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=False),
 ]
 blocked_layouts = [
     BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
@@ -3018,8 +3019,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,

 if torch.version.hip is not None and _get_warp_size() == 64:
     layouts = [
-        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
-        MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False),
+        MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=True),
+        MfmaLayout(version=(2,0), warps_per_cta=[2, 2], instr_shape=[32,32], is_transposed=False),
     ]
     shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
 else:
@@ -3177,8 +3178,8 @@ def test_scan_layouts(M, N, src_layout, axis, device):

 @pytest.mark.parametrize("shape", [(64, 64)])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False),
-                                        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)])
+@pytest.mark.parametrize("src_layout", [MfmaLayout(version=(2,0), warps_per_cta=[2, 1], instr_shape=[32,32], is_transposed=False),
+                                        MfmaLayout(version=(2,0), warps_per_cta=[4, 1], instr_shape=[32,32], is_transposed=True)])
 @pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
 def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
     ir = f"""