From f866b2a513186efc5e57b4c90eea10243d8a9a72 Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Mon, 26 Jan 2026 21:11:37 -0500
Subject: [PATCH] mfma loop in asm dsl (#14349)

* mfma loop in asm dsl

* work
---
Review notes (this text sits between the "---" marker and the first
"diff --git" line, so it is commentary and not part of any hunk):

* repeat(insts, n, counter_sreg) returns the concatenated bytes of:
  s_mov_b32 counter_sreg, n; every instruction in insts;
  s_sub_u32 counter_sreg, counter_sreg, 1; s_cmp_lg_i32 counter_sreg, 0;
  s_cbranch_scc1; s_endpgm. template.s drops its own s_mov/s_sub/s_cmp/
  s_cbranch/s_endpgm lines in the same patch, leaving only INSTRUCTION.
* The branch immediate is -((loop_sz // 4) + 1) masked to 16 bits, where
  loop_sz is the byte length of insts plus the sub and cmp encodings.
  Presumably the offset is counted in 4-byte units and the "+ 1" covers
  the branch instruction itself; confirm against the ISA document.
* NOTE(review): nothing checks that (loop_sz // 4) + 1 fits a signed
  16-bit field or that loop_sz is a multiple of 4; a larger loop would
  wrap silently through the "& 0xFFFF" mask.
* NOTE(review): the old template compared the counter against s2 (set to
  0); the new code compares against the literal 0, so s2 is no longer
  written by the kernel.
* NOTE(review): launchBenchmark still calls
  .replace("INTERNAL_LOOP", ...) on the template, but this patch removes
  the only INTERNAL_LOOP visible in these hunks; it looks like a no-op
  now unless the token appears elsewhere in template.s.
* NOTE(review): repeat uses s_mov_b32, s_sub_u32, s_cmp_lg_i32,
  s_cbranch_scc1 and s_endpgm, none of which are brought in by the import
  line changed here; they presumably come from the per-arch
  "from ...ins import *" under __main__ (the rdna4 branch also aliases
  s_sub_u32 = s_sub_co_u32). Verify every arch branch provides them.
* NOTE(review): line breaks, indentation and blank context lines in this
  copy were reconstructed from the hunk line counts; check against the
  upstream commit before applying.

 extra/mmapeak/mmapeak.py | 14 ++++++++++++--
 extra/mmapeak/template.s |  7 -------
 2 files changed, 12 insertions(+), 9 deletions(-)

diff --git a/extra/mmapeak/mmapeak.py b/extra/mmapeak/mmapeak.py
index 43e61b0279..e05b6ec362 100644
--- a/extra/mmapeak/mmapeak.py
+++ b/extra/mmapeak/mmapeak.py
@@ -5,7 +5,7 @@ os.environ["AMD_AQL"] = "1"
 
 from tinygrad.device import Device
 from tinygrad.runtime.support.compiler_amd import HIPCompiler
-from extra.assembly.amd.dsl import Reg
+from extra.assembly.amd.dsl import Reg, Inst, s, v
 
 NUM_WORKGROUPS = 96
 WAVE_SIZE = 32
@@ -17,6 +17,14 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1"
 
 assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
 
+def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> bytes:
+  preamble = s_mov_b32(counter_sreg, n).to_bytes()
+  insts_bytes = b"".join([inst.to_bytes() for inst in insts])
+  sub_inst, cmp_inst = s_sub_u32(counter_sreg, counter_sreg, 1), s_cmp_lg_i32(counter_sreg, 0)
+  loop_sz = len(insts_bytes) + sub_inst.size() + cmp_inst.size()
+  branch_inst = s_cbranch_scc1(simm16=-((loop_sz // 4) + 1) & 0xFFFF)
+  return preamble + insts_bytes + sub_inst.to_bytes() + cmp_inst.to_bytes() + branch_inst.to_bytes() + s_endpgm().to_bytes()
+
 def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
   if accum:
     inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs)
@@ -27,7 +35,7 @@ def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs)
   vgprs:set = set()
   for n,_ in inst._fields:
     if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
-  inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)])
+  inst_bytes = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
   inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
   src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
   src = src.replace("DIRECTIVE", DIRECTIVE)
@@ -54,6 +62,8 @@ if __name__=="__main__":
     launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11))
   elif arch in {'gfx1200', 'gfx1201'}:
     from extra.assembly.amd.autogen.rdna4.ins import *
+    # this instruction does not exist in the rdna4 isa, use the co version
+    s_sub_u32 = s_sub_co_u32
     NUM_WORKGROUPS = 64
     launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7))
     launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7))
diff --git a/extra/mmapeak/template.s b/extra/mmapeak/template.s
index f8b60d52eb..b915b0e595 100644
--- a/extra/mmapeak/template.s
+++ b/extra/mmapeak/template.s
@@ -3,14 +3,7 @@
 .p2align 8
 .type matmul,@function
 matmul:
-  s_mov_b32 s1, INTERNAL_LOOP
-  s_mov_b32 s2, 0
-  inner_loop:
   INSTRUCTION
-  s_sub_u32 s1, s1, 1
-  s_cmp_lg_i32 s1, s2
-  s_cbranch_scc1 inner_loop
-  s_endpgm
 
 .rodata
 .p2align 6