From 2d91fe6310759f23f700ef689d9deefd4fc8d1e6 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 26 Jan 2026 08:03:43 -0500 Subject: [PATCH] use amdgpu dsl in mmapeak (#14342) * use amdgpu dsl in mmapeak * don't rely on llvm for vgpr counting * llvm roundtrip assert * rm it, add ci * vgpr_count * move emulated test to amd, it needs comgr * env * arch * inst._fields -> inst.operands * vgpr offset --- .github/workflows/test.yml | 4 ++ extra/mmapeak/mmapeak.py | 100 +++++++++++++++++++------------------ extra/mmapeak/template.s | 6 +-- 3 files changed, 58 insertions(+), 52 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f67a9f6b97..b6ab08f1ac 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -652,6 +652,10 @@ jobs: run: | VIZ=-2 DEBUG=5 python3 test/test_ops.py TestOps.test_add extra/sqtt/rgptool.py create "/tmp/profile.pkl.$USER" -o /tmp/gpu0.rgp + - name: Run AMD emulated mmapeak on NULL backend + env: + AMD: 0 + run: PYTHONPATH=. NULL=1 EMULATE=AMD python extra/mmapeak/mmapeak.py - name: Run process replay tests uses: ./.github/actions/process-replay diff --git a/extra/mmapeak/mmapeak.py b/extra/mmapeak/mmapeak.py index 7046d58fe3..43e61b0279 100644 --- a/extra/mmapeak/mmapeak.py +++ b/extra/mmapeak/mmapeak.py @@ -5,6 +5,7 @@ os.environ["AMD_AQL"] = "1" from tinygrad.device import Device from tinygrad.runtime.support.compiler_amd import HIPCompiler +from extra.assembly.amd.dsl import Reg NUM_WORKGROUPS = 96 WAVE_SIZE = 32 @@ -16,27 +17,25 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1" assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text() -def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, extra=""): +def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs): if accum: - instructions = "{} a[0:{}], v[{}:{}], v[{}:{}], 1{}\n".format(instruction, vgprIndices[0], - vgprIndices[1], vgprIndices[2], - vgprIndices[1], vgprIndices[2], extra) + inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs) elif dense: - instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], 1\n".format(instruction, vgprIndices[0], - vgprIndices[1], vgprIndices[2], - vgprIndices[1], vgprIndices[2]) + inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1) else: - instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], v{}\n".format(instruction, vgprIndices[0], - vgprIndices[1], vgprIndices[2], - vgprIndices[3], vgprIndices[4], - vgprIndices[5]) - src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", instructions*INSTRUCTIONS_PER_LOOP) + inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[3]:vgprIndices[4]], v[vgprIndices[5]]) + vgprs:set = set() + for n,_ in inst._fields: + if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)} + inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)]) + inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n" + src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs))) src = src.replace("DIRECTIVE", DIRECTIVE) lib = COMPILER.compile(src) fxn = DEV.runtime("matmul", lib) elapsed = min([fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) for _ in range(2)]) FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP - print(f"{instruction:<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS") + print(f"{inst.op_name.lower():<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS") if __name__=="__main__": DEV = Device[Device.DEFAULT] @@ -44,53 +43,56 @@ if __name__=="__main__": COMPILER = HIPCompiler(arch) if arch in {'gfx1100', 'gfx1103', 'gfx1151'}: + from extra.assembly.amd.autogen.rdna3.ins import * if arch == 'gfx1103': NUM_WORKGROUPS = 8 if arch == 'gfx1151': NUM_WORKGROUPS = 32 - launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15)) - launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15)) - launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15)) - launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,15)) - launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,9)) - launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,11)) - elif arch == 'gfx1201': + launchBenchmark(v_wmma_bf16_16x16x16_bf16, (7,8,15)) + launchBenchmark(v_wmma_f16_16x16x16_f16, (7,8,15)) + launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,15)) + launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,15)) + launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,9)) + launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11)) + elif arch in {'gfx1200', 'gfx1201'}: + from extra.assembly.amd.autogen.rdna4.ins import * NUM_WORKGROUPS = 64 - launchBenchmark("v_wmma_bf16_16x16x16_bf16", (3,4,7)) - launchBenchmark("v_wmma_f16_16x16x16_f16", (3,4,7)) - launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,11)) - launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,11)) - launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,8)) - launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,9)) - launchBenchmark("v_wmma_f32_16x16x16_fp8_fp8", (7,8,9)) - launchBenchmark("v_wmma_f32_16x16x16_fp8_bf8", (7,8,9)) - launchBenchmark("v_wmma_f32_16x16x16_bf8_fp8", (7,8,9)) - launchBenchmark("v_wmma_f32_16x16x16_bf8_bf8", (7,8,9)) + launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7)) + launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7)) + launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,11)) + launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,11)) + launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,8)) + launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,9)) + launchBenchmark(v_wmma_f32_16x16x16_fp8_fp8, (7,8,9)) + launchBenchmark(v_wmma_f32_16x16x16_fp8_bf8, (7,8,9)) + launchBenchmark(v_wmma_f32_16x16x16_bf8_fp8, (7,8,9)) + launchBenchmark(v_wmma_f32_16x16x16_bf8_bf8, (7,8,9)) FLOPS_PER_MATMUL = 16*16*32*2 - launchBenchmark("v_wmma_i32_16X16X32_iu4", (7,8,9)) - launchBenchmark("v_swmmac_f32_16x16x32_f16", (7,8,11,12,19,20), False) - launchBenchmark("v_swmmac_f32_16x16x32_bf16", (7,8,11,12,19,20), False) - launchBenchmark("v_swmmac_f16_16x16x32_f16", (3,4,7,8,15,16), False) - launchBenchmark("v_swmmac_bf16_16x16x32_bf16", (3,4,7,8,15,16), False) - launchBenchmark("v_swmmac_i32_16x16x32_iu8", (7,8,9,10,13,14), False) - launchBenchmark("v_swmmac_i32_16x16x32_iu4", (7,8,8,9,10,11), False) - launchBenchmark("v_swmmac_f32_16x16x32_fp8_fp8", (7,8,9,10,13,14), False) - launchBenchmark("v_swmmac_f32_16x16x32_fp8_bf8", (7,8,9,10,13,14), False) - launchBenchmark("v_swmmac_f32_16x16x32_bf8_fp8", (7,8,9,10,13,14), False) - launchBenchmark("v_swmmac_f32_16x16x32_bf8_bf8", (7,8,9,10,13,14), False) + launchBenchmark(v_wmma_i32_16x16x32_iu4, (7,8,9)) + launchBenchmark(v_swmmac_f32_16x16x32_f16, (7,8,11,12,19,20), False) + launchBenchmark(v_swmmac_f32_16x16x32_bf16, (7,8,11,12,19,20), False) + launchBenchmark(v_swmmac_f16_16x16x32_f16, (3,4,7,8,15,16), False) + launchBenchmark(v_swmmac_bf16_16x16x32_bf16, (3,4,7,8,15,16), False) + launchBenchmark(v_swmmac_i32_16x16x32_iu8, (7,8,9,10,13,14), False) + launchBenchmark(v_swmmac_i32_16x16x32_iu4, (7,8,8,9,10,11), False) + launchBenchmark(v_swmmac_f32_16x16x32_fp8_fp8, (7,8,9,10,13,14), False) + launchBenchmark(v_swmmac_f32_16x16x32_fp8_bf8, (7,8,9,10,13,14), False) + launchBenchmark(v_swmmac_f32_16x16x32_bf8_fp8, (7,8,9,10,13,14), False) + launchBenchmark(v_swmmac_f32_16x16x32_bf8_bf8, (7,8,9,10,13,14), False) FLOPS_PER_MATMUL = 16*16*64*2 - launchBenchmark("v_swmmac_i32_16x16x64_iu4", (7,8,9,10,13,14), False) + launchBenchmark(v_swmmac_i32_16x16x64_iu4, (7,8,9,10,13,14), False) elif arch == 'gfx950': + from extra.assembly.amd.autogen.cdna.ins import * DIRECTIVE = ".amdhsa_accum_offset 4" NUM_WORKGROUPS = 256 WAVE_SIZE = 64 NUM_WAVES = 4 - launchBenchmark("v_mfma_f32_16x16x16_f16", (3,0,1), accum=True) - launchBenchmark("v_mfma_f32_16x16x16_bf16", (3,0,1), accum=True) + launchBenchmark(v_mfma_f32_16x16x16_f16, (3,0,1), accum=True) + launchBenchmark(v_mfma_f32_16x16x16_bf16, (3,0,1), accum=True) FLOPS_PER_MATMUL = 16*16*32*2 - launchBenchmark("v_mfma_f32_16x16x32_f16", (3,0,3), accum=True) - launchBenchmark("v_mfma_f32_16x16x32_bf16", (3,0,3), accum=True) + launchBenchmark(v_mfma_f32_16x16x32_f16, (3,0,3), accum=True) + launchBenchmark(v_mfma_f32_16x16x32_bf16, (3,0,3), accum=True) FLOPS_PER_MATMUL = 16*16*128*2 - launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,7), accum=True) # fp8 - launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,5), accum=True, extra=", cbsz:2 blgp:2") # fp6 - launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,3), accum=True, extra=", cbsz:4 blgp:4") # fp4 + launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,7), accum=True) # fp8 + launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,5), accum=True, cbsz=2, blgp=2) # fp6 + launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,3), accum=True, cbsz=4, blgp=4) # fp4 else: raise RuntimeError(f"arch {arch} not supported.") diff --git a/extra/mmapeak/template.s b/extra/mmapeak/template.s index b84aba74f3..f8b60d52eb 100644 --- a/extra/mmapeak/template.s +++ b/extra/mmapeak/template.s @@ -15,8 +15,8 @@ matmul: .rodata .p2align 6 .amdhsa_kernel matmul - .amdhsa_next_free_vgpr .amdgcn.next_free_vgpr - .amdhsa_next_free_sgpr .amdgcn.next_free_sgpr + .amdhsa_next_free_vgpr VGPR_COUNT + .amdhsa_next_free_sgpr 3 DIRECTIVE .end_amdhsa_kernel @@ -37,4 +37,4 @@ amdhsa.kernels: .vgpr_count: 32 .max_flat_workgroup_size: 1024 ... -.end_amdgpu_metadata \ No newline at end of file +.end_amdgpu_metadata