use amdgpu dsl in mmapeak (#14342)

* use amdgpu dsl in mmapeak

* don't rely on llvm for vgpr counting

* llvm roundtrip assert

* rm it, add ci

* vgpr_count

* move emulated test to amd, it needs comgr

* env

* arch

* inst._fields -> inst.operands

* vgpr offset
This commit is contained in:
qazal
2026-01-26 08:03:43 -05:00
committed by GitHub
parent b2e2ace85b
commit 2d91fe6310
3 changed files with 58 additions and 52 deletions

View File

@@ -652,6 +652,10 @@ jobs:
run: |
VIZ=-2 DEBUG=5 python3 test/test_ops.py TestOps.test_add
extra/sqtt/rgptool.py create "/tmp/profile.pkl.$USER" -o /tmp/gpu0.rgp
- name: Run AMD emulated mmapeak on NULL backend
env:
AMD: 0
run: PYTHONPATH=. NULL=1 EMULATE=AMD python extra/mmapeak/mmapeak.py
- name: Run process replay tests
uses: ./.github/actions/process-replay

View File

@@ -5,6 +5,7 @@ os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from extra.assembly.amd.dsl import Reg
NUM_WORKGROUPS = 96
WAVE_SIZE = 32
@@ -16,27 +17,25 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1"
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, extra=""):
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
if accum:
instructions = "{} a[0:{}], v[{}:{}], v[{}:{}], 1{}\n".format(instruction, vgprIndices[0],
vgprIndices[1], vgprIndices[2],
vgprIndices[1], vgprIndices[2], extra)
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs)
elif dense:
instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], 1\n".format(instruction, vgprIndices[0],
vgprIndices[1], vgprIndices[2],
vgprIndices[1], vgprIndices[2])
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1)
else:
instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], v{}\n".format(instruction, vgprIndices[0],
vgprIndices[1], vgprIndices[2],
vgprIndices[3], vgprIndices[4],
vgprIndices[5])
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", instructions*INSTRUCTIONS_PER_LOOP)
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[3]:vgprIndices[4]], v[vgprIndices[5]])
vgprs:set = set()
for n,_ in inst._fields:
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)])
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
src = src.replace("DIRECTIVE", DIRECTIVE)
lib = COMPILER.compile(src)
fxn = DEV.runtime("matmul", lib)
elapsed = min([fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) for _ in range(2)])
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
print(f"{instruction:<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
print(f"{inst.op_name.lower():<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
if __name__=="__main__":
DEV = Device[Device.DEFAULT]
@@ -44,53 +43,56 @@ if __name__=="__main__":
COMPILER = HIPCompiler(arch)
if arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
from extra.assembly.amd.autogen.rdna3.ins import *
if arch == 'gfx1103': NUM_WORKGROUPS = 8
if arch == 'gfx1151': NUM_WORKGROUPS = 32
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))
launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,15))
launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,9))
launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,11))
elif arch == 'gfx1201':
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (7,8,15))
launchBenchmark(v_wmma_f16_16x16x16_f16, (7,8,15))
launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,15))
launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,15))
launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,9))
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11))
elif arch in {'gfx1200', 'gfx1201'}:
from extra.assembly.amd.autogen.rdna4.ins import *
NUM_WORKGROUPS = 64
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (3,4,7))
launchBenchmark("v_wmma_f16_16x16x16_f16", (3,4,7))
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,11))
launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,11))
launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,8))
launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,9))
launchBenchmark("v_wmma_f32_16x16x16_fp8_fp8", (7,8,9))
launchBenchmark("v_wmma_f32_16x16x16_fp8_bf8", (7,8,9))
launchBenchmark("v_wmma_f32_16x16x16_bf8_fp8", (7,8,9))
launchBenchmark("v_wmma_f32_16x16x16_bf8_bf8", (7,8,9))
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7))
launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7))
launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,11))
launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,11))
launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,8))
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,9))
launchBenchmark(v_wmma_f32_16x16x16_fp8_fp8, (7,8,9))
launchBenchmark(v_wmma_f32_16x16x16_fp8_bf8, (7,8,9))
launchBenchmark(v_wmma_f32_16x16x16_bf8_fp8, (7,8,9))
launchBenchmark(v_wmma_f32_16x16x16_bf8_bf8, (7,8,9))
FLOPS_PER_MATMUL = 16*16*32*2
launchBenchmark("v_wmma_i32_16X16X32_iu4", (7,8,9))
launchBenchmark("v_swmmac_f32_16x16x32_f16", (7,8,11,12,19,20), False)
launchBenchmark("v_swmmac_f32_16x16x32_bf16", (7,8,11,12,19,20), False)
launchBenchmark("v_swmmac_f16_16x16x32_f16", (3,4,7,8,15,16), False)
launchBenchmark("v_swmmac_bf16_16x16x32_bf16", (3,4,7,8,15,16), False)
launchBenchmark("v_swmmac_i32_16x16x32_iu8", (7,8,9,10,13,14), False)
launchBenchmark("v_swmmac_i32_16x16x32_iu4", (7,8,8,9,10,11), False)
launchBenchmark("v_swmmac_f32_16x16x32_fp8_fp8", (7,8,9,10,13,14), False)
launchBenchmark("v_swmmac_f32_16x16x32_fp8_bf8", (7,8,9,10,13,14), False)
launchBenchmark("v_swmmac_f32_16x16x32_bf8_fp8", (7,8,9,10,13,14), False)
launchBenchmark("v_swmmac_f32_16x16x32_bf8_bf8", (7,8,9,10,13,14), False)
launchBenchmark(v_wmma_i32_16x16x32_iu4, (7,8,9))
launchBenchmark(v_swmmac_f32_16x16x32_f16, (7,8,11,12,19,20), False)
launchBenchmark(v_swmmac_f32_16x16x32_bf16, (7,8,11,12,19,20), False)
launchBenchmark(v_swmmac_f16_16x16x32_f16, (3,4,7,8,15,16), False)
launchBenchmark(v_swmmac_bf16_16x16x32_bf16, (3,4,7,8,15,16), False)
launchBenchmark(v_swmmac_i32_16x16x32_iu8, (7,8,9,10,13,14), False)
launchBenchmark(v_swmmac_i32_16x16x32_iu4, (7,8,8,9,10,11), False)
launchBenchmark(v_swmmac_f32_16x16x32_fp8_fp8, (7,8,9,10,13,14), False)
launchBenchmark(v_swmmac_f32_16x16x32_fp8_bf8, (7,8,9,10,13,14), False)
launchBenchmark(v_swmmac_f32_16x16x32_bf8_fp8, (7,8,9,10,13,14), False)
launchBenchmark(v_swmmac_f32_16x16x32_bf8_bf8, (7,8,9,10,13,14), False)
FLOPS_PER_MATMUL = 16*16*64*2
launchBenchmark("v_swmmac_i32_16x16x64_iu4", (7,8,9,10,13,14), False)
launchBenchmark(v_swmmac_i32_16x16x64_iu4, (7,8,9,10,13,14), False)
elif arch == 'gfx950':
from extra.assembly.amd.autogen.cdna.ins import *
DIRECTIVE = ".amdhsa_accum_offset 4"
NUM_WORKGROUPS = 256
WAVE_SIZE = 64
NUM_WAVES = 4
launchBenchmark("v_mfma_f32_16x16x16_f16", (3,0,1), accum=True)
launchBenchmark("v_mfma_f32_16x16x16_bf16", (3,0,1), accum=True)
launchBenchmark(v_mfma_f32_16x16x16_f16, (3,0,1), accum=True)
launchBenchmark(v_mfma_f32_16x16x16_bf16, (3,0,1), accum=True)
FLOPS_PER_MATMUL = 16*16*32*2
launchBenchmark("v_mfma_f32_16x16x32_f16", (3,0,3), accum=True)
launchBenchmark("v_mfma_f32_16x16x32_bf16", (3,0,3), accum=True)
launchBenchmark(v_mfma_f32_16x16x32_f16, (3,0,3), accum=True)
launchBenchmark(v_mfma_f32_16x16x32_bf16, (3,0,3), accum=True)
FLOPS_PER_MATMUL = 16*16*128*2
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,7), accum=True) # fp8
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,5), accum=True, extra=", cbsz:2 blgp:2") # fp6
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,3), accum=True, extra=", cbsz:4 blgp:4") # fp4
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,7), accum=True) # fp8
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,5), accum=True, cbsz=2, blgp=2) # fp6
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,3), accum=True, cbsz=4, blgp=4) # fp4
else:
raise RuntimeError(f"arch {arch} not supported.")

View File

@@ -15,8 +15,8 @@ matmul:
.rodata
.p2align 6
.amdhsa_kernel matmul
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_next_free_vgpr VGPR_COUNT
.amdhsa_next_free_sgpr 3
DIRECTIVE
.end_amdhsa_kernel
@@ -37,4 +37,4 @@ amdhsa.kernels:
.vgpr_count: 32
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata
.end_amdgpu_metadata