mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -5,7 +5,7 @@ os.environ["AMD_AQL"] = "1"
|
||||
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.runtime.support.compiler_amd import HIPCompiler
|
||||
from extra.assembly.amd.dsl import Reg
|
||||
from extra.assembly.amd.dsl import Reg, Inst, s, v
|
||||
|
||||
NUM_WORKGROUPS = 96
|
||||
WAVE_SIZE = 32
|
||||
@@ -17,6 +17,14 @@ DIRECTIVE = ".amdhsa_wavefront_size32 1"
|
||||
|
||||
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
|
||||
|
||||
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> bytes:
|
||||
preamble = s_mov_b32(counter_sreg, n).to_bytes()
|
||||
insts_bytes = b"".join([inst.to_bytes() for inst in insts])
|
||||
sub_inst, cmp_inst = s_sub_u32(counter_sreg, counter_sreg, 1), s_cmp_lg_i32(counter_sreg, 0)
|
||||
loop_sz = len(insts_bytes) + sub_inst.size() + cmp_inst.size()
|
||||
branch_inst = s_cbranch_scc1(simm16=-((loop_sz // 4) + 1) & 0xFFFF)
|
||||
return preamble + insts_bytes + sub_inst.to_bytes() + cmp_inst.to_bytes() + branch_inst.to_bytes() + s_endpgm().to_bytes()
|
||||
|
||||
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
|
||||
if accum:
|
||||
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1, acc_cd=1, **kwargs)
|
||||
@@ -27,7 +35,7 @@ def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs)
|
||||
vgprs:set = set()
|
||||
for n,_ in inst._fields:
|
||||
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
|
||||
inst_bytes = b"".join([inst.to_bytes() for _ in range(INSTRUCTIONS_PER_LOOP)])
|
||||
inst_bytes = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
|
||||
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
|
||||
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
|
||||
src = src.replace("DIRECTIVE", DIRECTIVE)
|
||||
@@ -54,6 +62,8 @@ if __name__=="__main__":
|
||||
launchBenchmark(v_wmma_i32_16x16x16_iu8, (7,8,11))
|
||||
elif arch in {'gfx1200', 'gfx1201'}:
|
||||
from extra.assembly.amd.autogen.rdna4.ins import *
|
||||
# this instruction does not exist in the rdna4 isa, use the co version
|
||||
s_sub_u32 = s_sub_co_u32
|
||||
NUM_WORKGROUPS = 64
|
||||
launchBenchmark(v_wmma_bf16_16x16x16_bf16, (3,4,7))
|
||||
launchBenchmark(v_wmma_f16_16x16x16_f16, (3,4,7))
|
||||
|
||||
@@ -3,14 +3,7 @@
|
||||
.p2align 8
|
||||
.type matmul,@function
|
||||
matmul:
|
||||
s_mov_b32 s1, INTERNAL_LOOP
|
||||
s_mov_b32 s2, 0
|
||||
inner_loop:
|
||||
INSTRUCTION
|
||||
s_sub_u32 s1, s1, 1
|
||||
s_cmp_lg_i32 s1, s2
|
||||
s_cbranch_scc1 inner_loop
|
||||
s_endpgm
|
||||
|
||||
.rodata
|
||||
.p2align 6
|
||||
|
||||
Reference in New Issue
Block a user