late compile the cdna gemm (#14783)

* late compile the cdna gemm

* remove old things

* finalize inplace

---------

Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2026-02-17 12:04:22 +08:00
committed by GitHub
parent 275319c789
commit 5bd2862d1a
3 changed files with 32 additions and 48 deletions

View File

@@ -32,7 +32,7 @@ def compute_gemm_args(M:int, N:int, K:int, batch:int) -> tuple[int, int, int, in
return NUM_WG, iters, total, magic, shift
class Kernel:
def __init__(self, name="gemm"): self.name, self.instructions, self.labels, self.label_at_pos, self.pos = name, [], {}, {}, 0
def __init__(self): self.instructions, self.labels, self.label_at_pos, self.pos = [], {}, {}, 0
def label(self, name):
self.labels[name] = self.pos
@@ -49,42 +49,12 @@ class Kernel:
waitcnt = (vmcnt & 0xF) | ((expcnt & 0x7) << 4) | ((lgkmcnt & 0xF) << 8) | (((vmcnt >> 4) & 0x3) << 14)
self.emit(s_waitcnt(waitcnt))
def to_asm(self):
# patch branches
def finalize(self):
"""Patch branch offsets and return the finalized instruction list."""
for inst in self.instructions:
if inst._target is None: continue
inst.simm16 = (self.labels[inst._target] - inst._pos - inst.size()) // 4
# convert instructions to bytes, pack hsa
inst_bytes = b"".join(inst.to_bytes() for inst in self.instructions)
body = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16))
hsa = [('group_segment_fixed_size', 133120), ('private_segment_fixed_size', 0), ('kernarg_size', 24),
('next_free_vgpr', 512), ('next_free_sgpr', 96), ('system_sgpr_workgroup_id_x', 1),
('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 1), ('user_sgpr_kernarg_segment_ptr', 1),
('user_sgpr_count', 2), ('user_sgpr_kernarg_preload_length', 0), ('user_sgpr_kernarg_preload_offset', 0),
('accum_offset', 256), ('uses_dynamic_stack', 0), ('tg_split', 0), ('float_round_mode_32', 0),
('float_round_mode_16_64', 0), ('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3),
('ieee_mode', 1), ('fp16_overflow', 0), ('dx10_clamp', 1)]
args = '\n'.join(f' - .address_space: generic\n .name: {n}\n .offset: {i*8}\n'
f' .size: 8\n .value_kind: global_buffer' for i,n in enumerate(['C', 'A', 'B']))
n = self.name
return '\n'.join(['.text', '.section\t.text.', f'.global\t{n}', '.p2align\t8', f'.type\t{n},@function', '', f'{n}:',
body, '', '.section .rodata,"a",@progbits', '.p2align 6, 0x0', f'.amdhsa_kernel {n}',
*[f' .amdhsa_{k} {v}' for k, v in hsa], '.end_amdhsa_kernel', '', '.amdgpu_metadata', '---', 'amdhsa.kernels:',
' - .args:', args, ' .group_segment_fixed_size: 133120', ' .kernarg_segment_align: 8',
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 256', f' .name: {n}',
' .private_segment_fixed_size: 0', ' .sgpr_count: 95', ' .sgpr_spill_count: 0', f' .symbol: {n}.kd',
' .vgpr_count: 249', ' .vgpr_spill_count: 0', ' .wavefront_size: 64', 'amdhsa.version:', ' - 1',
' - 1', '...', '.end_amdgpu_metadata', ''])
# outputs readable source code for this kernel
def to_text(self) -> str:
lines, pos = [], 0
for inst in self.instructions:
if (label := self.label_at_pos.get(pos)) is not None: lines.append(f"{label}:")
from test.amd.disasm import disasm
lines.append(f" {disasm(inst)}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}")
pos += inst.size()
return "\n".join(lines)
return self.instructions
def build_kernel(batch, M, N, K, dtype):
numWG, iters, total, magic, shift = compute_gemm_args(M, N, K, batch)
@@ -92,7 +62,7 @@ def build_kernel(batch, M, N, K, dtype):
v_mfma_16x16x32 = {dtypes.half:v_mfma_f32_16x16x32_f16, dtypes.bfloat16:v_mfma_f32_16x16x32_bf16}[dtype]
v_cvt_pk = {dtypes.half:v_cvt_pk_f16_f32, dtypes.bfloat16:v_cvt_pk_bf16_f32}[dtype]
v_cvt = {dtypes.half:v_cvt_f32_f16_e32, dtypes.bfloat16:v_cvt_f32_bf16_e32}[dtype]
k = Kernel(f"gemm_{batch}_{M}_{N}_{K}")
k = Kernel()
# load D, A, B pointers
k.emit(s_load_dwordx2(s[24:25], s[0:1], s[0], 0, 0, 0, 0, 1))
k.emit(s_load_dwordx2(s[30:31], s[0:1], s[0], 8, 0, 0, 0, 1))
@@ -11528,4 +11498,4 @@ def build_kernel(batch, M, N, K, dtype):
k.emit(s_branch(), target='PersistentLoopStart')
k.label('KernelEnd')
k.emit(s_endpgm())
return k
return k.finalize()

View File

@@ -1,6 +1,6 @@
import atexit, functools
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad import Tensor, Device, dtypes
from tinygrad.dtype import AddrSpace
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType
from tinygrad.renderer import Estimates
from tinygrad.helpers import getenv, all_same, dedup
@@ -17,13 +17,12 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
assert K == K2
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(wg, "gidx0")
k = build_kernel(batch, M, N, K, A.dtype.base)
sink = UOp.sink(C.base, A.base, B.base, lidx, gidx,
arg=KernelInfo(name=k.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
# TODO: you shouldn't have to call the compiler here, BINARY should be auto-added
binary = HIPCompiler(arch).compile(k.to_asm())
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=k.to_text()), UOp(Ops.BINARY, arg=binary)))
insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname),
UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
counters = {"used":0, "todos":[]}
def todo(msg:str) -> bool: counters["todos"].append(msg); return False

View File

@@ -4,6 +4,7 @@ from tinygrad.helpers import ceildiv, round_up
from tinygrad.uop.ops import UOp, Ops
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
from tinygrad.renderer.amd.dsl import Reg, FixedBitField
from tinygrad.runtime.autogen.amd.common import OpType
# instructions used for padding
from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as RDNA4
@@ -13,14 +14,23 @@ _arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
insts = [u.arg for u in lin.src]
# ** scan for max vgpr/sgpr
max_vgpr, max_sgpr = 0, 0
# ** scan for max vgpr/sgpr/accvgpr
max_vgpr, max_sgpr, max_accvgpr = 0, 0, 0
_ACCVGPR_TYPES = {OpType.OPR_ACCVGPR, OpType.OPR_SRC_ACCVGPR}
for inst in insts:
# build set of field names that are AccVGPR for this instruction
accvgpr_fields: set[str] = set()
for opr_name, (_, _, opr_type) in inst.operands.items():
if opr_type in _ACCVGPR_TYPES: accvgpr_fields.add(opr_name)
elif opr_type in {OpType.OPR_VGPR_OR_ACCVGPR, OpType.OPR_SRC_VGPR_OR_ACCVGPR, OpType.OPR_SRC_VGPR_OR_ACCVGPR_OR_CONST}:
if getattr(inst, 'acc_cd', 0) == 1: accvgpr_fields.add(opr_name)
for name, field in inst._fields:
if isinstance(field, FixedBitField): continue
val = getattr(inst, name)
if not isinstance(val, Reg): continue
if 256 <= val.offset < 512: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz)
if 256 <= val.offset < 512:
if name in accvgpr_fields: max_accvgpr = max(max_accvgpr, (val.offset - 256) + val.sz)
else: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz)
elif val.offset < 106: max_sgpr = max(max_sgpr, val.offset + val.sz)
# ** scan sink for metadata
@@ -41,7 +51,10 @@ def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
text_offset = round_up(ctypes.sizeof(libc.Elf64_Ehdr), hsa.AMD_ISA_ALIGN_BYTES)
# ** pack kernel descriptor (rodata)
next_free_vgpr, next_free_sgpr = round_up(max_vgpr, 8), round_up(max_sgpr, 8)
# CDNA: total VGPRs = regular VGPRs + AccVGPRs, each rounded to granularity of 4
accum_offset = round_up(max_vgpr, 4) if max_accvgpr > 0 else 0
next_free_vgpr = round_up(accum_offset + max_accvgpr, 8) if max_accvgpr > 0 else round_up(max_vgpr, 8)
next_free_sgpr = round_up(max_sgpr, 8)
vgpr_granule = max(0, (next_free_vgpr + 7) // 8 - 1)
# CDNA: add 6 for VCC(2) + FLAT_SCRATCH(2) + XNACK_MASK(2), next_free_sgpr is unused in RDNA.
sgpr_granule = max(0, ceildiv(next_free_sgpr + 6, 8) - 1) if is_cdna else 0
@@ -64,6 +77,8 @@ def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
int(2 in gids) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Z_SHIFT)
desc.kernel_code_properties = (1 << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_KERNARG_SEGMENT_PTR_SHIFT |
(0 if is_cdna else 1) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_WAVEFRONT_SIZE32_SHIFT)
if is_cdna and max_accvgpr > 0:
desc.compute_pgm_rsrc3 = max(0, accum_offset // 4 - 1) << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET_SHIFT
rodata = bytes(desc)
# ** pack ELF