mfma loop in asm dsl (#14349)

* mfma loop in asm dsl

* work
This commit is contained in:
qazal
2026-01-26 21:11:37 -05:00
committed by GitHub
parent 0793319929
commit f866b2a513
2 changed files with 12 additions and 9 deletions

View File

@@ -5,7 +5,7 @@ os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from extra.assembly.amd.dsl import Reg
from extra.assembly.amd.dsl import Reg, Inst, s, v
NUM_WORKGROUPS = 96
WAVE_SIZE = 32
@@ -17,6 +17,14 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1"
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> bytes:
  """Assemble a counted loop that executes `insts` `n` times and then ends the program.

  The emitted byte stream is, in order:
    s_mov_b32      counter_sreg, n
    <insts>                              (loop body, branch target)
    s_sub_u32      counter_sreg, counter_sreg, 1
    s_cmp_lg_i32   counter_sreg, 0
    s_cbranch_scc1 <back to start of insts>
    s_endpgm

  The instruction constructors are looked up in module globals at call time, so they must
  already have been imported for the target architecture before this is called.

  Args:
    insts: instructions forming the loop body; each must provide `to_bytes()`.
    n: number of loop iterations, must be at least 1.
    counter_sreg: scalar register used as the loop counter; it is overwritten.

  Returns:
    The encoded machine code for the whole loop, including the trailing s_endpgm.

  Raises:
    ValueError: if `n` is below 1, or if the loop is too large for the backward branch
      offset to be encoded in the 16-bit immediate.
  """
  # the counter is decremented before it is compared against 0, so n < 1 would wrap around
  # instead of skipping the body
  if n < 1: raise ValueError(f"loop count must be >= 1, got {n}")
  body_bytes = b"".join(inst.to_bytes() for inst in insts)
  sub_inst, cmp_inst = s_sub_u32(counter_sreg, counter_sreg, 1), s_cmp_lg_i32(counter_sreg, 0)
  # NOTE(review): size() is assumed to return a byte count, since it is summed with len(bytes)
  loop_sz = len(body_bytes) + sub_inst.size() + cmp_inst.size()
  # distance to jump back, in 4-byte words; the extra 1 presumably covers the branch
  # instruction itself (offset relative to the following instruction) -- confirm against the ISA
  back_dwords = (loop_sz // 4) + 1
  # simm16 is masked to 16 bits below; anything past the signed 16-bit range would silently
  # wrap and branch to the wrong address
  if back_dwords > 0x8000:
    raise ValueError(f"loop body of {loop_sz} bytes is too large for a 16-bit branch offset")
  branch_inst = s_cbranch_scc1(simm16=-back_dwords & 0xFFFF)
  preamble = s_mov_b32(counter_sreg, n).to_bytes()
  return preamble + body_bytes + sub_inst.to_bytes() + cmp_inst.to_bytes() + branch_inst.to_bytes() + s_endpgm().to_bytes()
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
if accum:
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs)
@@ -27,7 +35,7 @@ def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs)
vgprs:set = set()
for n,_ in inst._fields:
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)])
inst_bytes = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
src = src.replace("DIRECTIVE", DIRECTIVE)
@@ -54,6 +62,8 @@ if __name__=="__main__":
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11))
elif arch in {'gfx1200', 'gfx1201'}:
from extra.assembly.amd.autogen.rdna4.ins import *
# s_sub_u32 does not exist in the RDNA4 ISA; use the carry-out ("co") variant s_sub_co_u32 instead
s_sub_u32 = s_sub_co_u32
NUM_WORKGROUPS = 64
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7))
launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7))

View File

@@ -3,14 +3,7 @@
.p2align 8
.type matmul,@function
matmul:
s_mov_b32 s1, INTERNAL_LOOP
s_mov_b32 s2, 0
inner_loop:
INSTRUCTION
s_sub_u32 s1, s1, 1
s_cmp_lg_i32 s1, s2
s_cbranch_scc1 inner_loop
s_endpgm
.rodata
.p2align 6