[MFMA] Remove CTA related code from layout (#429)

This PR removes the CTALayout attribute from the MFMA layout, because it is NVIDIA-specific.
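Concretely, the MFMA encoding printed by MfmaLayout.__str__ loses its CTA fields. A before/after sketch of the attribute string, assuming GPU_DIALECT resolves to "triton_gpu" as suggested by the module attributes used in these tests:

    before: #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false, CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}>
    after:  #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>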
Alexander Efimov authored on 2023-12-27 20:01:28 +03:00, committed by GitHub
parent 1e2fd0dd1a
commit 98589ac013
8 changed files with 52 additions and 73 deletions


@@ -2700,16 +2700,13 @@ class MmaLayout:


 class MfmaLayout:

-    def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order):
+    def __init__(self, non_k_dim, warps_per_cta, is_transposed):
         self.non_k_dim = str(non_k_dim)
         self.warps_per_cta = str(warps_per_cta)
         self.is_transposed = str(is_transposed).lower()
-        self.ctas_per_cga = str(ctas_per_cga)
-        self.cta_split_num = str(cta_split_num)
-        self.cta_order = str(cta_order)

     def __str__(self):
-        return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
+        return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>"


 class BlockedLayout:
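For reference, a minimal usage sketch of the trimmed class (not part of the diff; the printed string assumes GPU_DIALECT == "triton_gpu"):

    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)
    print(mfma)
    # #triton_gpu.mfma<{nonKDim = 32, warpsPerCTA = [4, 1], isTransposed = false}>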
@@ -2776,7 +2773,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
     blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
     shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
-    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
+    mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)
     ir = f"""
 #blocked = {blocked}
@@ -2938,8 +2935,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
 if torch.version.hip is not None and _get_warp_size() == 64:
     layouts = [
-        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
-        MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
+        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
+        MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False),
     ]
     shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
 else:
@@ -3097,8 +3094,8 @@ def test_scan_layouts(M, N, src_layout, axis, device):
 @pytest.mark.parametrize("shape", [(64, 64)])
 @pytest.mark.parametrize("dtype", ['float16'])
-@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
-                                        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])])
+@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False),
+                                        MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)])
 @pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
 def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
     ir = f"""