mirror of
https://github.com/ROCm/ROCm.git
synced 2026-02-21 03:00:39 -05:00
[MFMA] Remove CTA related code from layout (#429)
This PR removes CTALayout attribute from MFMA layout, because it is NV specific.
This commit is contained in:
@@ -2700,16 +2700,13 @@ class MmaLayout:
|
||||
|
||||
|
||||
class MfmaLayout:
|
||||
def __init__(self, non_k_dim, warps_per_cta, is_transposed, ctas_per_cga, cta_split_num, cta_order):
|
||||
def __init__(self, non_k_dim, warps_per_cta, is_transposed):
|
||||
self.non_k_dim = str(non_k_dim)
|
||||
self.warps_per_cta = str(warps_per_cta)
|
||||
self.is_transposed = str(is_transposed).lower()
|
||||
self.ctas_per_cga = str(ctas_per_cga)
|
||||
self.cta_split_num = str(cta_split_num)
|
||||
self.cta_order = str(cta_order)
|
||||
|
||||
def __str__(self):
|
||||
return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>"
|
||||
return f"#{GPU_DIALECT}.mfma<{{nonKDim = {self.non_k_dim}, warpsPerCTA = {self.warps_per_cta}, isTransposed = {self.is_transposed}}}>"
|
||||
|
||||
|
||||
class BlockedLayout:
|
||||
@@ -2776,7 +2773,7 @@ def test_dot_mfma_vector_load(vec_size, swizzle, transposeA, transposeB):
|
||||
blocked = BlockedLayout(size_per_thread=[1, 4], threads_per_warp=[8, 8], warps_per_cta=[4, 1], order=[1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
shared_a = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeA else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
shared_b = SharedLayout(vec=vec_size, per_phase=1, max_phase=max_phase, order=[0, 1] if transposeB else [1, 0], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1])
|
||||
mfma = MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=False)
|
||||
|
||||
ir = f"""
|
||||
#blocked = {blocked}
|
||||
@@ -2938,8 +2935,8 @@ module attributes {"triton_gpu.num-ctas" = 1, "triton_gpu.num-warps" = 4 : i32,
|
||||
|
||||
if torch.version.hip is not None and _get_warp_size() == 64:
|
||||
layouts = [
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[2, 2], is_transposed=False),
|
||||
]
|
||||
shapes = [[128, 32], [128, 128], [32, 128], [64, 64]]
|
||||
else:
|
||||
@@ -3097,8 +3094,8 @@ def test_scan_layouts(M, N, src_layout, axis, device):
|
||||
|
||||
@pytest.mark.parametrize("shape", [(64, 64)])
|
||||
@pytest.mark.parametrize("dtype", ['float16'])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0]),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True, ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])])
|
||||
@pytest.mark.parametrize("src_layout", [MfmaLayout(non_k_dim=32, warps_per_cta=[2, 1], is_transposed=False),
|
||||
MfmaLayout(non_k_dim=32, warps_per_cta=[4, 1], is_transposed=True)])
|
||||
@pytest.mark.parametrize("dst_layout", [BlockedLayout([1, 4], [4, 16], [1, 1], [1, 0], [1, 1], [1, 1], [0, 1])])
|
||||
def test_make_range(dtype, shape, src_layout, dst_layout, device='cuda'):
|
||||
ir = f"""
|
||||
|
||||
Reference in New Issue
Block a user