# Files
# tinygrad/extra/mmapeak/mmapeak.py
# 2025-12-31 16:52:23 +09:00
#
# 101 lines
# 5.1 KiB
# Python

import os, pathlib
# TODO: there is a timing bug without this
os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.runtime.ops_amd import AMDProgram
# Launch geometry read by launchBenchmark: the grid is NUM_WORKGROUPS workgroups,
# each of WAVE_SIZE * NUM_WAVES threads. The __main__ section rebinds these per architecture.
NUM_WORKGROUPS = 96
WAVE_SIZE = 32
NUM_WAVES = 2

# Operations credited to one matrix instruction (16x16x16 tile, times 2 --
# presumably one multiply and one add per element; confirm against the ISA docs).
FLOPS_PER_MATMUL = 2 * 16**3

# Value substituted for the INTERNAL_LOOP placeholder in the assembly template.
INTERNAL_LOOP = 100_000

# How many times the instruction line is repeated in the INSTRUCTION placeholder.
INSTRUCTIONS_PER_LOOP = 200

# Text substituted for the DIRECTIVE placeholder in the assembly template.
DIRECTIVE = ".amdhsa_wavefront_size32 1"
# Kernel source template, read once at import time from template.s in this file's directory.
# launchBenchmark substitutes its INTERNAL_LOOP, INSTRUCTION and DIRECTIVE placeholders.
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, extra=""):
    """Assemble a kernel that repeats one matrix instruction, run it, and print its throughput.

    Args:
        instruction: Mnemonic placed at the start of each generated assembly line;
            also used as the label of the printed result.
        vgprIndices: Register indices spliced into the operand ranges. Entries 0-2 are
            always read; entries 3-5 are read only when both `accum` and `dense` are false.
        dense: When true (and `accum` is false), both source operands use the same
            v[...] range and the last operand is the literal 1.
        accum: When true, the destination is written as a[0:N] and `extra` is appended
            to the line. Takes precedence over `dense`.
        extra: Text appended right after the trailing "1" operand; used only when
            `accum` is true.

    Reads the module-level COMPILER and DEV (bound in the __main__ section) as well as
    the launch/size constants, so rebinding those globals affects later calls.
    """
    dst_last, src_first, src_last = vgprIndices[0], vgprIndices[1], vgprIndices[2]
    # First source range; the accum and dense forms reuse it for the second operand.
    src_range = f"v[{src_first}:{src_last}]"

    if accum:
        line = f"{instruction} a[0:{dst_last}], {src_range}, {src_range}, 1{extra}\n"
    elif dense:
        line = f"{instruction} v[0:{dst_last}], {src_range}, {src_range}, 1\n"
    else:
        # Remaining form: a distinct second source range plus one extra register operand.
        second_first, second_last, last_reg = vgprIndices[3], vgprIndices[4], vgprIndices[5]
        line = f"{instruction} v[0:{dst_last}], {src_range}, v[{second_first}:{second_last}], v{last_reg}\n"

    # Fill in the template placeholders in the same order as before: loop count, body, directive.
    src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP))
    src = src.replace("INSTRUCTION", line * INSTRUCTIONS_PER_LOOP)
    src = src.replace("DIRECTIVE", DIRECTIVE)

    binary = COMPILER.compile(src)
    kernel = AMDProgram(DEV, "matmul", binary)

    # Two timed launches; the smaller reported time is kept.
    grid = (NUM_WORKGROUPS, 1, 1)
    block = (WAVE_SIZE * NUM_WAVES, 1, 1)
    elapsed = min(kernel(global_size=grid, local_size=block, wait=True) for _ in range(2))

    total_ops = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
    # Assumes the launch returns elapsed time in seconds (the result is labelled tera-ops) -- confirm.
    tera_ops = total_ops / elapsed / 10**12
    print(f"{instruction:<29} : {tera_ops:.2f} T(FL)OPS")
if __name__ == "__main__":
    # Script entry point: pick the AMD device, then benchmark every matrix instruction
    # listed for its architecture. launchBenchmark reads NUM_WORKGROUPS, WAVE_SIZE,
    # NUM_WAVES, FLOPS_PER_MATMUL and DIRECTIVE as module globals, so the rebinding
    # below changes how the subsequent calls are launched and scored.

    # Device index comes from the DEVICENUM environment variable (default "0").
    DEVICENUM = os.getenv("DEVICENUM", "0")
    try:
        DEV = Device['AMD:' + DEVICENUM]
    except Exception as e:
        # Catch Exception rather than everything so Ctrl-C / SystemExit propagate untouched,
        # and chain the cause so the underlying failure stays visible in the traceback.
        raise RuntimeError("Error while initiating AMD device") from e

    COMPILER = HIPCompiler(DEV.arch)

    if DEV.arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
        # Per-arch workgroup counts; gfx1100 keeps the module default.
        if DEV.arch == 'gfx1103':
            NUM_WORKGROUPS = 8
        if DEV.arch == 'gfx1151':
            NUM_WORKGROUPS = 32
        launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
        launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
        launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))
        launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,15))
        launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,9))
        launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,11))
    elif DEV.arch == 'gfx1201':
        NUM_WORKGROUPS = 64
        launchBenchmark("v_wmma_bf16_16x16x16_bf16", (3,4,7))
        launchBenchmark("v_wmma_f16_16x16x16_f16", (3,4,7))
        launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,11))
        launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,11))
        launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,8))
        launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,9))
        launchBenchmark("v_wmma_f32_16x16x16_fp8_fp8", (7,8,9))
        launchBenchmark("v_wmma_f32_16x16x16_fp8_bf8", (7,8,9))
        launchBenchmark("v_wmma_f32_16x16x16_bf8_fp8", (7,8,9))
        launchBenchmark("v_wmma_f32_16x16x16_bf8_bf8", (7,8,9))
        # The following instructions are 16x16x32, so the per-instruction op count doubles.
        FLOPS_PER_MATMUL = 16*16*32*2
        launchBenchmark("v_wmma_i32_16X16X32_iu4", (7,8,9))
        # dense=False selects the six-index operand form in launchBenchmark.
        launchBenchmark("v_swmmac_f32_16x16x32_f16", (7,8,11,12,19,20), False)
        launchBenchmark("v_swmmac_f32_16x16x32_bf16", (7,8,11,12,19,20), False)
        launchBenchmark("v_swmmac_f16_16x16x32_f16", (3,4,7,8,15,16), False)
        launchBenchmark("v_swmmac_bf16_16x16x32_bf16", (3,4,7,8,15,16), False)
        launchBenchmark("v_swmmac_i32_16x16x32_iu8", (7,8,9,10,13,14), False)
        launchBenchmark("v_swmmac_i32_16x16x32_iu4", (7,8,8,9,10,11), False)
        launchBenchmark("v_swmmac_f32_16x16x32_fp8_fp8", (7,8,9,10,13,14), False)
        launchBenchmark("v_swmmac_f32_16x16x32_fp8_bf8", (7,8,9,10,13,14), False)
        launchBenchmark("v_swmmac_f32_16x16x32_bf8_fp8", (7,8,9,10,13,14), False)
        launchBenchmark("v_swmmac_f32_16x16x32_bf8_bf8", (7,8,9,10,13,14), False)
        # 16x16x64 variant.
        FLOPS_PER_MATMUL = 16*16*64*2
        launchBenchmark("v_swmmac_i32_16x16x64_iu4", (7,8,9,10,13,14), False)
    elif DEV.arch == 'gfx950':
        # This arch uses a different template directive and launch geometry than the defaults.
        DIRECTIVE = ".amdhsa_accum_offset 4"
        NUM_WORKGROUPS = 256
        WAVE_SIZE = 64
        NUM_WAVES = 4
        # accum=True makes launchBenchmark emit an a[0:N] destination.
        launchBenchmark("v_mfma_f32_16x16x16_f16", (3,0,1), accum=True)
        launchBenchmark("v_mfma_f32_16x16x16_bf16", (3,0,1), accum=True)
        FLOPS_PER_MATMUL = 16*16*32*2
        launchBenchmark("v_mfma_f32_16x16x32_f16", (3,0,3), accum=True)
        launchBenchmark("v_mfma_f32_16x16x32_bf16", (3,0,3), accum=True)
        FLOPS_PER_MATMUL = 16*16*128*2
        launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,7), accum=True) # fp8
        launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,5), accum=True, extra=", cbsz:2 blgp:2") # fp6
        launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,3), accum=True, extra=", cbsz:4 blgp:4") # fp4
    else:
        raise RuntimeError(f"arch {DEV.arch} not supported.")