From 5bd2862d1a9f8b50cf2fff292ef98a99523a794e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Tue, 17 Feb 2026 12:04:22 +0800 Subject: [PATCH] late compile the cdna gemm (#14783) * late compile the cdna gemm * remove old things * finalize inplace --------- Co-authored-by: qazal --- extra/gemm/asm/cdna/asm.py | 42 ++++++------------------------------ extra/gemm/asm/cdna/gemm.py | 15 ++++++------- tinygrad/renderer/amd/elf.py | 23 ++++++++++++++++---- 3 files changed, 32 insertions(+), 48 deletions(-) diff --git a/extra/gemm/asm/cdna/asm.py b/extra/gemm/asm/cdna/asm.py index c23084bbcf..821f0f1390 100644 --- a/extra/gemm/asm/cdna/asm.py +++ b/extra/gemm/asm/cdna/asm.py @@ -32,7 +32,7 @@ def compute_gemm_args(M:int, N:int, K:int, batch:int) -> tuple[int, int, int, in return NUM_WG, iters, total, magic, shift class Kernel: - def __init__(self, name="gemm"): self.name, self.instructions, self.labels, self.label_at_pos, self.pos = name, [], {}, {}, 0 + def __init__(self): self.instructions, self.labels, self.label_at_pos, self.pos = [], {}, {}, 0 def label(self, name): self.labels[name] = self.pos @@ -49,42 +49,12 @@ class Kernel: waitcnt = (vmcnt & 0xF) | ((expcnt & 0x7) << 4) | ((lgkmcnt & 0xF) << 8) | (((vmcnt >> 4) & 0x3) << 14) self.emit(s_waitcnt(waitcnt)) - def to_asm(self): - # patch branches + def finalize(self): + """Patch branch offsets and return the finalized instruction list.""" for inst in self.instructions: if inst._target is None: continue inst.simm16 = (self.labels[inst._target] - inst._pos - inst.size()) // 4 - # convert instructions to bytes, pack hsa - inst_bytes = b"".join(inst.to_bytes() for inst in self.instructions) - body = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) - hsa = [('group_segment_fixed_size', 133120), ('private_segment_fixed_size', 0), ('kernarg_size', 24), - ('next_free_vgpr', 512), ('next_free_sgpr', 96), ('system_sgpr_workgroup_id_x', 1), - ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 1), ('user_sgpr_kernarg_segment_ptr', 1), - ('user_sgpr_count', 2), ('user_sgpr_kernarg_preload_length', 0), ('user_sgpr_kernarg_preload_offset', 0), - ('accum_offset', 256), ('uses_dynamic_stack', 0), ('tg_split', 0), ('float_round_mode_32', 0), - ('float_round_mode_16_64', 0), ('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), - ('ieee_mode', 1), ('fp16_overflow', 0), ('dx10_clamp', 1)] - args = '\n'.join(f' - .address_space: generic\n .name: {n}\n .offset: {i*8}\n' - f' .size: 8\n .value_kind: global_buffer' for i,n in enumerate(['C', 'A', 'B'])) - n = self.name - return '\n'.join(['.text', '.section\t.text.', f'.global\t{n}', '.p2align\t8', f'.type\t{n},@function', '', f'{n}:', - body, '', '.section .rodata,"a",@progbits', '.p2align 6, 0x0', f'.amdhsa_kernel {n}', - *[f' .amdhsa_{k} {v}' for k, v in hsa], '.end_amdhsa_kernel', '', '.amdgpu_metadata', '---', 'amdhsa.kernels:', - ' - .args:', args, ' .group_segment_fixed_size: 133120', ' .kernarg_segment_align: 8', - ' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 256', f' .name: {n}', - ' .private_segment_fixed_size: 0', ' .sgpr_count: 95', ' .sgpr_spill_count: 0', f' .symbol: {n}.kd', - ' .vgpr_count: 249', ' .vgpr_spill_count: 0', ' .wavefront_size: 64', 'amdhsa.version:', ' - 1', - ' - 1', '...', '.end_amdgpu_metadata', '']) - - # outputs readable source code for this kernel - def to_text(self) -> str: - lines, pos = [], 0 - for inst in self.instructions: - if (label := self.label_at_pos.get(pos)) is not None: lines.append(f"{label}:") - from test.amd.disasm import disasm - lines.append(f" {disasm(inst)}" if inst._target is None else f" {inst.op_name.lower()} {inst._target}") - pos += inst.size() - return "\n".join(lines) + return self.instructions def build_kernel(batch, M, N, K, dtype): numWG, iters, total, magic, shift = compute_gemm_args(M, N, K, batch) @@ -92,7 +62,7 @@ def build_kernel(batch, M, N, K, dtype): v_mfma_16x16x32 = {dtypes.half:v_mfma_f32_16x16x32_f16, dtypes.bfloat16:v_mfma_f32_16x16x32_bf16}[dtype] v_cvt_pk = {dtypes.half:v_cvt_pk_f16_f32, dtypes.bfloat16:v_cvt_pk_bf16_f32}[dtype] v_cvt = {dtypes.half:v_cvt_f32_f16_e32, dtypes.bfloat16:v_cvt_f32_bf16_e32}[dtype] - k = Kernel(f"gemm_{batch}_{M}_{N}_{K}") + k = Kernel() # load D, A, B pointers k.emit(s_load_dwordx2(s[24:25], s[0:1], s[0], 0, 0, 0, 0, 1)) k.emit(s_load_dwordx2(s[30:31], s[0:1], s[0], 8, 0, 0, 0, 1)) @@ -11528,4 +11498,4 @@ def build_kernel(batch, M, N, K, dtype): k.emit(s_branch(), target='PersistentLoopStart') k.label('KernelEnd') k.emit(s_endpgm()) - return k + return k.finalize() diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py index 2d249e0322..547e133e46 100644 --- a/extra/gemm/asm/cdna/gemm.py +++ b/extra/gemm/asm/cdna/gemm.py @@ -1,6 +1,6 @@ import atexit, functools -from tinygrad.runtime.support.compiler_amd import HIPCompiler from tinygrad import Tensor, Device, dtypes +from tinygrad.dtype import AddrSpace from tinygrad.uop.ops import UOp, Ops, KernelInfo, AxisType from tinygrad.renderer import Estimates from tinygrad.helpers import getenv, all_same, dedup @@ -17,13 +17,12 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp: assert K == K2 lidx = UOp.special(WORKGROUP_SIZE, "lidx0") gidx = UOp.special(wg, "gidx0") - k = build_kernel(batch, M, N, K, A.dtype.base) - sink = UOp.sink(C.base, A.base, B.base, lidx, gidx, - arg=KernelInfo(name=k.name, estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2))) - # TODO: you shouldn't have to call the compiler here, BINARY should be auto-added - binary = HIPCompiler(arch).compile(k.to_asm()) - return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), - UOp(Ops.SOURCE, arg=k.to_text()), UOp(Ops.BINARY, arg=binary))) + insts = build_kernel(batch, M, N, K, A.dtype.base) + lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds') + sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx, + arg=KernelInfo(name=f"gemm_{batch}_{M}_{N}_{K}", estimates=Estimates(ops=2*batch*M*N*K, mem=(batch*M*K + K*N + batch*M*N)*2))) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), + UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts])))) counters = {"used":0, "todos":[]} def todo(msg:str) -> bool: counters["todos"].append(msg); return False diff --git a/tinygrad/renderer/amd/elf.py b/tinygrad/renderer/amd/elf.py index 91692e26aa..b07101f13d 100644 --- a/tinygrad/renderer/amd/elf.py +++ b/tinygrad/renderer/amd/elf.py @@ -4,6 +4,7 @@ from tinygrad.helpers import ceildiv, round_up from tinygrad.uop.ops import UOp, Ops from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc from tinygrad.renderer.amd.dsl import Reg, FixedBitField +from tinygrad.runtime.autogen.amd.common import OpType # instructions used for padding from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as RDNA4 @@ -13,14 +14,23 @@ _arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4 def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp: insts = [u.arg for u in lin.src] - # ** scan for max vgpr/sgpr - max_vgpr, max_sgpr = 0, 0 + # ** scan for max vgpr/sgpr/accvgpr + max_vgpr, max_sgpr, max_accvgpr = 0, 0, 0 + _ACCVGPR_TYPES = {OpType.OPR_ACCVGPR, OpType.OPR_SRC_ACCVGPR} for inst in insts: + # build set of field names that are AccVGPR for this instruction + accvgpr_fields: set[str] = set() + for opr_name, (_, _, opr_type) in inst.operands.items(): + if opr_type in _ACCVGPR_TYPES: accvgpr_fields.add(opr_name) + elif opr_type in {OpType.OPR_VGPR_OR_ACCVGPR, OpType.OPR_SRC_VGPR_OR_ACCVGPR, OpType.OPR_SRC_VGPR_OR_ACCVGPR_OR_CONST}: + if getattr(inst, 'acc_cd', 0) == 1: accvgpr_fields.add(opr_name) for name, field in inst._fields: if isinstance(field, FixedBitField): continue val = getattr(inst, name) if not isinstance(val, Reg): continue - if 256 <= val.offset < 512: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz) + if 256 <= val.offset < 512: + if name in accvgpr_fields: max_accvgpr = max(max_accvgpr, (val.offset - 256) + val.sz) + else: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz) elif val.offset < 106: max_sgpr = max(max_sgpr, val.offset + val.sz) # ** scan sink for metadata @@ -41,7 +51,10 @@ def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp: text_offset = round_up(ctypes.sizeof(libc.Elf64_Ehdr), hsa.AMD_ISA_ALIGN_BYTES) # ** pack kernel descriptor (rodata) - next_free_vgpr, next_free_sgpr = round_up(max_vgpr, 8), round_up(max_sgpr, 8) + # CDNA: total VGPRs = regular VGPRs + AccVGPRs, each rounded to granularity of 4 + accum_offset = round_up(max_vgpr, 4) if max_accvgpr > 0 else 0 + next_free_vgpr = round_up(accum_offset + max_accvgpr, 8) if max_accvgpr > 0 else round_up(max_vgpr, 8) + next_free_sgpr = round_up(max_sgpr, 8) vgpr_granule = max(0, (next_free_vgpr + 7) // 8 - 1) # CDNA: add 6 for VCC(2) + FLAT_SCRATCH(2) + XNACK_MASK(2), next_free_sgpr is unused in RDNA. sgpr_granule = max(0, ceildiv(next_free_sgpr + 6, 8) - 1) if is_cdna else 0 @@ -64,6 +77,8 @@ def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp: int(2 in gids) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Z_SHIFT) desc.kernel_code_properties = (1 << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_KERNARG_SEGMENT_PTR_SHIFT | (0 if is_cdna else 1) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_WAVEFRONT_SIZE32_SHIFT) + if is_cdna and max_accvgpr > 0: + desc.compute_pgm_rsrc3 = max(0, accum_offset // 4 - 1) << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET_SHIFT rodata = bytes(desc) # ** pack ELF