mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use amdgpu dsl in mmapeak (#14342)
* use amdgpu dsl in mmapeak
* don't rely on llvm for vgpr counting
* llvm roundtrip assert
* rm it, add ci
* vgpr_count
* move emulated test to amd, it needs comgr
* env
* arch
* inst._fields -> inst.operands
* vgpr offset
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -652,6 +652,10 @@ jobs:
|
||||
run: |
|
||||
VIZ=-2 DEBUG=5 python3 test/test_ops.py TestOps.test_add
|
||||
extra/sqtt/rgptool.py create "/tmp/profile.pkl.$USER" -o /tmp/gpu0.rgp
|
||||
- name: Run AMD emulated mmapeak on NULL backend
|
||||
env:
|
||||
AMD: 0
|
||||
run: PYTHONPATH=. NULL=1 EMULATE=AMD python extra/mmapeak/mmapeak.py
|
||||
- name: Run process replay tests
|
||||
uses: ./.github/actions/process-replay
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ os.environ["AMD_AQL"] = "1"
|
||||
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from extra.assembly.amd.dsl import Reg
|
||||
|
||||
NUM_WORKGROUPS = 96
|
||||
WAVE_SIZE = 32
|
||||
@@ -16,27 +17,25 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1"
|
||||
|
||||
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
|
||||
|
||||
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, extra=""):
|
||||
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
|
||||
if accum:
|
||||
instructions = "{} a[0:{}], v[{}:{}], v[{}:{}], 1{}\n".format(instruction, vgprIndices[0],
|
||||
vgprIndices[1], vgprIndices[2],
|
||||
vgprIndices[1], vgprIndices[2], extra)
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs)
|
||||
elif dense:
|
||||
instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], 1\n".format(instruction, vgprIndices[0],
|
||||
vgprIndices[1], vgprIndices[2],
|
||||
vgprIndices[1], vgprIndices[2])
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1)
|
||||
else:
|
||||
instructions = "{} v[0:{}], v[{}:{}], v[{}:{}], v{}\n".format(instruction, vgprIndices[0],
|
||||
vgprIndices[1], vgprIndices[2],
|
||||
vgprIndices[3], vgprIndices[4],
|
||||
vgprIndices[5])
|
||||
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", instructions*INSTRUCTIONS_PER_LOOP)
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[3]:vgprIndices[4]], v[vgprIndices[5]])
|
||||
vgprs:set = set()
|
||||
for n,_ in inst._fields:
|
||||
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
|
||||
inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)])
|
||||
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
|
||||
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
|
||||
src = src.replace("DIRECTIVE", DIRECTIVE)
|
||||
lib = COMPILER.compile(src)
|
||||
fxn = DEV.runtime("matmul", lib)
|
||||
elapsed = min([fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) for _ in range(2)])
|
||||
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
|
||||
print(f"{instruction:<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
|
||||
print(f"{inst.op_name.lower():<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
|
||||
|
||||
if __name__=="__main__":
|
||||
DEV = Device[Device.DEFAULT]
|
||||
@@ -44,53 +43,56 @@ if __name__=="__main__":
|
||||
|
||||
COMPILER = HIPCompiler(arch)
|
||||
if arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
|
||||
from extra.assembly.amd.autogen.rdna3.ins import *
|
||||
if arch == 'gfx1103': NUM_WORKGROUPS = 8
|
||||
if arch == 'gfx1151': NUM_WORKGROUPS = 32
|
||||
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (7,8,15))
|
||||
launchBenchmark("v_wmma_f16_16x16x16_f16", (7,8,15))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,15))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,15))
|
||||
launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,9))
|
||||
launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,11))
|
||||
elif arch == 'gfx1201':
|
||||
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (7,8,15))
|
||||
launchBenchmark(v_wmma_f16_16x16x16_f16, (7,8,15))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,15))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,15))
|
||||
launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,9))
|
||||
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11))
|
||||
elif arch in {'gfx1200', 'gfx1201'}:
|
||||
from extra.assembly.amd.autogen.rdna4.ins import *
|
||||
NUM_WORKGROUPS = 64
|
||||
launchBenchmark("v_wmma_bf16_16x16x16_bf16", (3,4,7))
|
||||
launchBenchmark("v_wmma_f16_16x16x16_f16", (3,4,7))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_bf16", (7,8,11))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_f16", (7,8,11))
|
||||
launchBenchmark("v_wmma_i32_16x16x16_iu4", (7,8,8))
|
||||
launchBenchmark("v_wmma_i32_16x16x16_iu8", (7,8,9))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_fp8_fp8", (7,8,9))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_fp8_bf8", (7,8,9))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_bf8_fp8", (7,8,9))
|
||||
launchBenchmark("v_wmma_f32_16x16x16_bf8_bf8", (7,8,9))
|
||||
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7))
|
||||
launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_bf16, (7,8,11))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_f16, (7,8,11))
|
||||
launchBenchmark(v_wmma_i32_16x16x16_iu4, (7,8,8))
|
||||
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,9))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_fp8_fp8, (7,8,9))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_fp8_bf8, (7,8,9))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_bf8_fp8, (7,8,9))
|
||||
launchBenchmark(v_wmma_f32_16x16x16_bf8_bf8, (7,8,9))
|
||||
FLOPS_PER_MATMUL = 16*16*32*2
|
||||
launchBenchmark("v_wmma_i32_16X16X32_iu4", (7,8,9))
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_f16", (7,8,11,12,19,20), False)
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_bf16", (7,8,11,12,19,20), False)
|
||||
launchBenchmark("v_swmmac_f16_16x16x32_f16", (3,4,7,8,15,16), False)
|
||||
launchBenchmark("v_swmmac_bf16_16x16x32_bf16", (3,4,7,8,15,16), False)
|
||||
launchBenchmark("v_swmmac_i32_16x16x32_iu8", (7,8,9,10,13,14), False)
|
||||
launchBenchmark("v_swmmac_i32_16x16x32_iu4", (7,8,8,9,10,11), False)
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_fp8_fp8", (7,8,9,10,13,14), False)
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_fp8_bf8", (7,8,9,10,13,14), False)
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_bf8_fp8", (7,8,9,10,13,14), False)
|
||||
launchBenchmark("v_swmmac_f32_16x16x32_bf8_bf8", (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_wmma_i32_16x16x32_iu4, (7,8,9))
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_f16, (7,8,11,12,19,20), False)
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_bf16, (7,8,11,12,19,20), False)
|
||||
launchBenchmark(v_swmmac_f16_16x16x32_f16, (3,4,7,8,15,16), False)
|
||||
launchBenchmark(v_swmmac_bf16_16x16x32_bf16, (3,4,7,8,15,16), False)
|
||||
launchBenchmark(v_swmmac_i32_16x16x32_iu8, (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_swmmac_i32_16x16x32_iu4, (7,8,8,9,10,11), False)
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_fp8_fp8, (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_fp8_bf8, (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_bf8_fp8, (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_swmmac_f32_16x16x32_bf8_bf8, (7,8,9,10,13,14), False)
|
||||
FLOPS_PER_MATMUL = 16*16*64*2
|
||||
launchBenchmark("v_swmmac_i32_16x16x64_iu4", (7,8,9,10,13,14), False)
|
||||
launchBenchmark(v_swmmac_i32_16x16x64_iu4, (7,8,9,10,13,14), False)
|
||||
elif arch == 'gfx950':
|
||||
from extra.assembly.amd.autogen.cdna.ins import *
|
||||
DIRECTIVE = ".amdhsa_accum_offset 4"
|
||||
NUM_WORKGROUPS = 256
|
||||
WAVE_SIZE = 64
|
||||
NUM_WAVES = 4
|
||||
launchBenchmark("v_mfma_f32_16x16x16_f16", (3,0,1), accum=True)
|
||||
launchBenchmark("v_mfma_f32_16x16x16_bf16", (3,0,1), accum=True)
|
||||
launchBenchmark(v_mfma_f32_16x16x16_f16, (3,0,1), accum=True)
|
||||
launchBenchmark(v_mfma_f32_16x16x16_bf16, (3,0,1), accum=True)
|
||||
FLOPS_PER_MATMUL = 16*16*32*2
|
||||
launchBenchmark("v_mfma_f32_16x16x32_f16", (3,0,3), accum=True)
|
||||
launchBenchmark("v_mfma_f32_16x16x32_bf16", (3,0,3), accum=True)
|
||||
launchBenchmark(v_mfma_f32_16x16x32_f16, (3,0,3), accum=True)
|
||||
launchBenchmark(v_mfma_f32_16x16x32_bf16, (3,0,3), accum=True)
|
||||
FLOPS_PER_MATMUL = 16*16*128*2
|
||||
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,7), accum=True) # fp8
|
||||
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,5), accum=True, extra=", cbsz:2 blgp:2") # fp6
|
||||
launchBenchmark("v_mfma_f32_16x16x128_f8f6f4", (3,0,3), accum=True, extra=", cbsz:4 blgp:4") # fp4
|
||||
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,7), accum=True) # fp8
|
||||
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,5), accum=True, cbsz=2, blgp=2) # fp6
|
||||
launchBenchmark(v_mfma_f32_16x16x128_f8f6f4, (3,0,3), accum=True, cbsz=4, blgp=4) # fp4
|
||||
else:
|
||||
raise RuntimeError(f"arch {arch} not supported.")
|
||||
|
||||
@@ -15,8 +15,8 @@ matmul:
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel matmul
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_next_free_vgpr VGPR_COUNT
|
||||
.amdhsa_next_free_sgpr 3
|
||||
DIRECTIVE
|
||||
.end_amdhsa_kernel
|
||||
|
||||
@@ -37,4 +37,4 @@ amdhsa.kernels:
|
||||
.vgpr_count: 32
|
||||
.max_flat_workgroup_size: 1024
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
.end_amdgpu_metadata
|
||||
|
||||
Reference in New Issue
Block a user