remove llvm requirement from amd (#14717)

* remove llvm requirement from amd

* tests pass

* test

* sink kernarg_size

* move stuff

* amd_asm_matmul to new style

* default type

* fix tests, simpler

* cu mode is faster and simpler

* darken
George Hotz
2026-02-13 10:50:12 +08:00
committed by GitHub
parent 9e33a08adb
commit 4088d686b2
15 changed files with 200 additions and 163 deletions

View File

@@ -9,10 +9,10 @@
# Accumulators: 128 vgprs (v[2-129])
import numpy as np
from pathlib import Path
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.helpers import getenv, colored
from tinygrad.dtype import dtypes, AddrSpace
from tinygrad.engine.realize import Estimates
from tinygrad.renderer.amd.dsl import s, v, VCC_LO, NULL
from tinygrad.runtime.autogen.amd.rdna3.ins import *
@@ -51,14 +51,14 @@ V_B_TILE_REGS = [132, 136, 140, 144, 148, 152, 156, 160] # B tile: banks 0,0,0,
# Named register assignments (SGPRs)
# =============================================================================
S_OUT_PTR = (0, 1) # output C matrix base pointer
S_TILE_X = 2 # workgroup_x << 7
S_TILE_Y = 3 # workgroup_y << 7
S_WORKGROUP_X = 2 # workgroup_id_x (system SGPR, follows user SGPRs)
S_WORKGROUP_Y = 3 # workgroup_id_y (system SGPR)
S_DIM_N = 4 # matrix dimension N
S_LOOP_BOUND = 7 # K-8 (loop termination bound)
S_LOOP_CTR = 12 # loop counter (increments by 8)
S_PREFETCH_FLAG = 13 # prefetch condition flag / row stride in epilogue
S_WORKGROUP_X = 14 # workgroup_id_x
S_WORKGROUP_Y = 15 # workgroup_id_y
S_TILE_X = 14 # workgroup_x << 7
S_TILE_Y = 15 # workgroup_y << 7
# Kernarg load destinations
S_KERNARG_A = (20, 21) # A pointer from kernarg
S_KERNARG_B = (22, 23) # B pointer from kernarg
@@ -183,48 +183,14 @@ class Kernel:
waitcnt = (expcnt & 0x7) | ((lgkmcnt & 0x3f) << 4) | ((vmcnt & 0x3f) << 10)
self.emit(s_waitcnt(simm16=waitcnt))
def to_asm(self):
# Patch branch offsets: simm16 = (target_pos - branch_end_pos) / 4
def finalize(self):
"""Patch branch offsets and return the finalized instruction list."""
for inst in self.instructions:
if inst._target is None: continue
offset_dwords = (self.labels[inst._target] - inst._pos - inst.size()) // 4
if not -32768 <= offset_dwords <= 32767: raise ValueError(f"branch to '{inst._target}' offset {offset_dwords} exceeds simm16 range")
inst.simm16 = offset_dwords
# TODO: replace this with direct ELF
from test.amd.disasm import disasm
body = ['\t' + disasm(inst) for inst in self.instructions]
# limit wave occupancy by using more LDS
lds_size = max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536))
# HSA kernel descriptor attributes (zeros included for compatibility)
hsa = [
('group_segment_fixed_size', lds_size), ('private_segment_fixed_size', 0), ('kernarg_size', 36),
('user_sgpr_count', 14), ('user_sgpr_dispatch_ptr', 0), ('user_sgpr_queue_ptr', 0),
('user_sgpr_kernarg_segment_ptr', 1), ('user_sgpr_dispatch_id', 0), ('user_sgpr_private_segment_size', 0),
('wavefront_size32', 1), ('uses_dynamic_stack', 0), ('enable_private_segment', 0),
('system_sgpr_workgroup_id_x', 1), ('system_sgpr_workgroup_id_y', 1), ('system_sgpr_workgroup_id_z', 0),
('system_sgpr_workgroup_info', 0), ('system_vgpr_workitem_id', 0), ('next_free_vgpr', 179),
('next_free_sgpr', 16), ('float_round_mode_32', 0), ('float_round_mode_16_64', 0),
('float_denorm_mode_32', 3), ('float_denorm_mode_16_64', 3), ('dx10_clamp', 1), ('ieee_mode', 1),
('fp16_overflow', 0), ('workgroup_processor_mode', 0), ('memory_ordered', 1), ('forward_progress', 0),
('shared_vgpr_count', 0)]
return '\n'.join([
'\t.text', f'\t.amdgcn_target "amdgcn-amd-amdhsa--{self.arch}"',
'\t.protected\tkernel', '\t.globl\tkernel', '\t.p2align\t8', '\t.type\tkernel,@function', 'kernel:',
*body,
'\t.section\t.rodata,"a",@progbits', '\t.p2align\t6, 0x0', '\t.amdhsa_kernel kernel',
*[f'\t\t.amdhsa_{k} {v}' for k, v in hsa],
'\t.end_amdhsa_kernel', '\t.text', '.Lfunc_end0:', '\t.size\tkernel, .Lfunc_end0-kernel',
'\t.amdgpu_metadata', '---', 'amdhsa.kernels:', ' - .args:',
*[f' - .address_space: global\n .offset: {i*8}\n .size: 8\n .value_kind: global_buffer' for i in range(3)],
f' .group_segment_fixed_size: {lds_size}', ' .kernarg_segment_align: 8',
' .kernarg_segment_size: 24', ' .max_flat_workgroup_size: 128', ' .name: kernel',
' .private_segment_fixed_size: 0', ' .sgpr_count: 60', ' .symbol: kernel.kd',
' .vgpr_count: 179', ' .wavefront_size: 32', f'amdhsa.target: amdgcn-amd-amdhsa--{self.arch}',
'amdhsa.version:', ' - 1', ' - 2', '...', '\t.end_amdgpu_metadata'])
return self.instructions
# =============================================================================
@@ -460,7 +426,7 @@ def build_kernel(arch='gfx1100'):
k.emit(s_sendmsg(simm16=3)) # DEALLOC_VGPRS
k.emit(s_endpgm())
return k.to_asm()
return k.finalize()
# =============================================================================
# Test harness
@@ -474,16 +440,7 @@ def test_matmul():
dev = Device[Device.DEFAULT]
print(f"Device arch: {dev.renderer.arch}")
if getenv("STOCK", 0):
# Load the stock kernel from amd_seb/kernel8_batched_gmem.s
stock_path = Path(__file__).parent / "amd_seb" / "kernel8_batched_gmem.s"
asm = stock_path.read_text()
print(f"Loaded stock kernel from {stock_path}")
else:
asm = build_kernel(dev.renderer.arch)
binary = dev.compiler.compile(asm)
print(f"Compiled! Binary size: {len(binary)} bytes")
insts = build_kernel(dev.renderer.arch)
rng = np.random.default_rng(42)
a = Tensor(rng.random((N, N), dtype=np.float32) - 0.5)
@@ -498,10 +455,10 @@ def test_matmul():
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)]
lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)]
sink = UOp.sink(A.base, B.base, C.base, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=asm),
UOp(Ops.BINARY, arg=binary)))
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=max(LDS_SIZE, 65536//getenv("LIMIT_OCC", 65536)), addrspace=AddrSpace.LOCAL), (), 'lds')
sink = UOp.sink(A.base, B.base, C.base, lds, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"),
estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
ei = c.schedule()[0].lower()
@@ -542,6 +499,6 @@ def run_sqtt():
print(f"Wrote {len(output)} bytes to /tmp/sqtt_trace.txt")
if __name__ == "__main__":
if getenv("ASM", 0): print(build_kernel(Device[Device.DEFAULT].arch))
if getenv("ASM", 0): print("\n".join(str(inst) for inst in build_kernel(Device[Device.DEFAULT].renderer.arch)))
elif getenv("SQTT", 0): run_sqtt()
else: test_matmul()
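For reference, a minimal standalone sketch of the simm16 patch that finalize() performs, assuming byte positions and a hypothetical patch_branch helper (the hardware reads simm16 as a signed dword offset measured from the end of the branch instruction):

# hypothetical standalone version of the loop in Kernel.finalize
def patch_branch(target_pos:int, branch_pos:int, branch_size:int) -> int:
  offset_dwords = (target_pos - branch_pos - branch_size) // 4
  if not -32768 <= offset_dwords <= 32767: raise ValueError(f"branch offset {offset_dwords} exceeds simm16 range")
  return offset_dwords & 0xFFFF  # two's-complement in 16 bits

# a backward branch from a 4-byte instruction at byte 204 to a loop head at byte 0:
# (0 - 204 - 4) // 4 = -52 dwords, encoded as 0xFFCC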

View File

@@ -1,10 +1,11 @@
import os, pathlib
import os
# TODO: there is a timing bug without this
os.environ["AMD_AQL"] = "1"
from tinygrad.device import Device
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad import Tensor, Device
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.renderer.amd.dsl import Reg, Inst, s, v
NUM_WORKGROUPS = 96
@@ -13,17 +14,22 @@ NUM_WAVES = 2
FLOPS_PER_MATMUL = 16*16*16*2
INTERNAL_LOOP = 1_000_00
INSTRUCTIONS_PER_LOOP = 200
DIRECTIVE = ".amdhsa_wavefront_size32 1"
assemblyTemplate = (pathlib.Path(__file__).parent / "template.s").read_text()
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> bytes:
preamble = s_mov_b32(counter_sreg, n).to_bytes()
def repeat(insts:list[Inst], n:int, counter_sreg:Reg) -> list[Inst]:
insts_bytes = b"".join([inst.to_bytes() for inst in insts])
sub_inst, cmp_inst = s_sub_u32(counter_sreg, counter_sreg, 1), s_cmp_lg_i32(counter_sreg, 0)
loop_sz = len(insts_bytes) + sub_inst.size() + cmp_inst.size()
branch_inst = s_cbranch_scc1(simm16=-((loop_sz // 4) + 1) & 0xFFFF)
return preamble + insts_bytes + sub_inst.to_bytes() + cmp_inst.to_bytes() + branch_inst.to_bytes() + s_endpgm().to_bytes()
return [s_mov_b32(counter_sreg, n)] + insts + [sub_inst, cmp_inst, branch_inst, s_endpgm()]
def make_kernel(insts:list[Inst]):
def fxn(A:UOp) -> UOp:
threads = UOp.special(WAVE_SIZE * NUM_WAVES, "lidx0")
gidx = UOp.special(NUM_WORKGROUPS, "gidx0")
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
sink = UOp.sink(A.base, threads, gidx, arg=KernelInfo("mmapeak", estimates=Estimates(ops=FLOPs, mem=0)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
return fxn
def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs):
if accum:
@@ -32,16 +38,12 @@ def launchBenchmark(instruction, vgprIndices, dense=True, accum=False, **kwargs)
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[1]:vgprIndices[2]], 1)
else:
inst = instruction(v[0:vgprIndices[0]], v[vgprIndices[1]:vgprIndices[2]], v[vgprIndices[3]:vgprIndices[4]], v[vgprIndices[5]])
vgprs:set = set()
for n,_ in inst._fields:
if isinstance(val:=getattr(inst, n), Reg) and val.offset >= v.offset: vgprs |= {val.offset+i for i in range(val.sz)}
inst_bytes = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
inst_hex = "\n".join(" .byte " + ",".join(f"0x{b:02x}" for b in inst_bytes[i:i+16]) for i in range(0, len(inst_bytes), 16)) + "\n"
src = assemblyTemplate.replace("INTERNAL_LOOP", str(INTERNAL_LOOP)).replace("INSTRUCTION", inst_hex).replace("VGPR_COUNT", str(len(vgprs)))
src = src.replace("DIRECTIVE", DIRECTIVE)
lib = COMPILER.compile(src)
fxn = DEV.runtime("matmul", lib)
elapsed = min([fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) for _ in range(2)])
insts = repeat([inst for _ in range(INSTRUCTIONS_PER_LOOP)], n=INTERNAL_LOOP, counter_sreg=s[1])
fxn = make_kernel(insts)
dummy = Tensor.zeros(1).contiguous().realize()
out = Tensor.custom_kernel(dummy, fxn=fxn)[0]
ei = out.schedule()[-1].lower()
elapsed = min([ei.run(wait=True) for _ in range(2)])
FLOPs = FLOPS_PER_MATMUL * NUM_WAVES * NUM_WORKGROUPS * INTERNAL_LOOP * INSTRUCTIONS_PER_LOOP
print(f"{inst.op_name.lower():<29} : {FLOPs/elapsed/10**12:.2f} T(FL)OPS")
@@ -49,7 +51,6 @@ if __name__=="__main__":
DEV = Device[Device.DEFAULT]
arch = DEV.renderer.arch
COMPILER = HIPCompiler(arch)
if arch in {'gfx1100', 'gfx1103', 'gfx1151'}:
from tinygrad.runtime.autogen.amd.rdna3.ins import *
if arch == 'gfx1103': NUM_WORKGROUPS = 8
@@ -91,7 +92,6 @@ if __name__=="__main__":
launchBenchmark(v_swmmac_i32_16x16x64_iu4, (7,8,9,10,13,14), False)
elif arch == 'gfx950':
from tinygrad.runtime.autogen.amd.cdna.ins import *
DIRECTIVE = ".amdhsa_accum_offset 4"
NUM_WORKGROUPS = 256
WAVE_SIZE = 64
NUM_WAVES = 4
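Unlike finalize() in amd_asm_matmul, repeat() hand-computes its single back-branch instead of going through labels: the offset skips back over the whole loop body plus the sub/cmp pair, and the extra +1 dword accounts for the s_cbranch itself. A worked example with assumed encoding sizes:

# assumed sizes: 200 eight-byte VOP3P wmma instructions per iteration, plus a
# 4-byte s_sub_u32 and a 4-byte s_cmp_lg_i32 (inline constants, no literals)
loop_sz = 200*8 + 4 + 4                  # bytes from loop head to the branch
simm16 = -((loop_sz // 4) + 1) & 0xFFFF  # -403 dwords, encoded as 0xFE6D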

View File

@@ -1,33 +0,0 @@
.text
.globl matmul
.p2align 8
.type matmul,@function
matmul:
INSTRUCTION
.rodata
.p2align 6
.amdhsa_kernel matmul
.amdhsa_next_free_vgpr VGPR_COUNT
.amdhsa_next_free_sgpr 3
DIRECTIVE
.end_amdhsa_kernel
.amdgpu_metadata
---
amdhsa.version:
- 1
- 0
amdhsa.kernels:
- .name: matmul
.symbol: matmul.kd
.kernarg_segment_size: 0
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 4
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 32
.max_flat_workgroup_size: 1024
...
.end_amdgpu_metadata

View File

@@ -1,23 +1,12 @@
import unittest
import functools
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.runtime.autogen.amd.rdna3.ins import *
from tinygrad.renderer.amd.dsl import s, v, Inst
from test.amd.disasm import disasm as disasm_inst
from tinygrad.renderer.amd.dsl import s, v
def assemble_insts(insts:list[Inst], name:str, arch:str, kernarg_size:int=8) -> tuple[UOp, UOp]:
kd = {"kernarg_size":kernarg_size, "user_sgpr_kernarg_segment_ptr":1, "next_free_vgpr":8, "next_free_sgpr":8, "wavefront_size32":1}
disasm = "\n".join([disasm_inst(inst) for inst in insts])
hsasrc = f".text\n.globl {name}\n.p2align 8\n.type fn_name,@function\n{name}:\n{disasm}\ns_code_end\n"
hsasrc += f".rodata\n.p2align 6\n.amdhsa_kernel {name}\n"+"\n".join([f".amdhsa_{k} {v}" for k,v in kd.items()])+"\n.end_amdhsa_kernel"
binary = HIPCompiler(arch).compile(hsasrc)
return UOp(Ops.SOURCE, arg=disasm), UOp(Ops.BINARY, arg=binary)
def custom_add_one(A:UOp, arch:str) -> UOp:
def custom_add_one(A:UOp) -> UOp:
A = A.flatten()
assert dtypes.is_float(A.dtype.base), f"buffer dtype must be float32, got {A.dtype}"
threads = UOp.special(A.size, "lidx0")
@@ -32,10 +21,10 @@ def custom_add_one(A:UOp, arch:str) -> UOp:
global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
s_endpgm(),
]
sink = UOp.sink(A.base, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)), *assemble_insts(insts, name, arch)))
sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.size}", estimates=Estimates(ops=A.size, mem=A.size*4*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
def custom_add_var(A:UOp, B:UOp) -> UOp:
A,B = A.flatten(), B.flatten()
assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
threads = UOp.special(A.size, "lidx0")
@@ -51,15 +40,14 @@ def custom_add_var(A:UOp, B:UOp, arch:str) -> UOp:
global_store_b32(addr=v[0], data=v[1], saddr=s[4:5]),
s_endpgm(),
]
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(name:=f"custom_add_one_{A.size}"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=(*sink.src, sink)),
*assemble_insts(insts, name, arch, kernarg_size=16)))
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.size}"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
class TestCustomKernel(unittest.TestCase):
def test_simple(self):
a = Tensor.full((16, 16), 1.).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_add_one, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
a = Tensor.custom_kernel(a, fxn=custom_add_one)[0]
ei = a.schedule()[-1].lower()
self.assertEqual(ei.prg.estimates.ops, a.numel())
self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
@@ -69,7 +57,7 @@ class TestCustomKernel(unittest.TestCase):
def test_variable(self):
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
a = Tensor.zeros_like(b).contiguous().realize()
a = Tensor.custom_kernel(a, b, fxn=functools.partial(custom_add_var, arch=Device[Device.DEFAULT].renderer.arch))[0] # type: ignore[attr-defined]
a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0]
ei = a.schedule()[-1].lower()
for i in range(4):
ei.run({"var":i})

View File

@@ -2,12 +2,9 @@
# allow define from star imports
import unittest
import functools
from tinygrad import Device, Tensor
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.device import Compiler
from tinygrad.runtime.support.compiler_amd import HIPCompiler
from tinygrad.viz.serve import amdgpu_cfg
from tinygrad.runtime.autogen.amd.rdna3.ins import *
@@ -16,24 +13,17 @@ from tinygrad.renderer.amd.dsl import s
# TODO: this belongs to the dsl infrastructure
from extra.gemm.amd_asm_matmul import Kernel
# TODO: shouldn't need compiler once we can output ELF
# outputs a text disassembly for humans and a machine readable binary
def assemble(name:str, k:Kernel, compiler:Compiler) -> tuple[str, bytes]:
src = k.to_asm()
return (src, compiler.compile(src))
def asm_kernel(out:UOp, k:Kernel, name:str, device:str, compiler:Compiler, n_threads:int=1, n_workgroups:int=1) -> UOp:
lidx = UOp.special(n_threads, "lidx0")
gidx = UOp.special(n_workgroups, "gidx0")
sink = UOp.sink(out, lidx, gidx, arg=KernelInfo(name=name))
src, lib = assemble(name, k, compiler)
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=device), UOp(Ops.LINEAR, src=(*sink.src, sink)),
UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=lib)))
def run_asm(name:str, k:Kernel) -> None:
fxn = functools.partial(asm_kernel, k=k, name=name, device=Device.DEFAULT, compiler=HIPCompiler(Device[Device.DEFAULT].renderer.arch))
def run_asm(name:str, k:Kernel):
insts = k.finalize()
def fxn(out:UOp) -> UOp:
lidx = UOp.special(1, "lidx0")
gidx = UOp.special(1, "gidx0")
sink = UOp.sink(out.base, lidx, gidx, arg=KernelInfo(name=name))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
out = Tensor.custom_kernel(Tensor.empty(1), fxn=fxn)[0]
out.realize()
ei = out.schedule()[-1].lower()
ei.run()
return ei
@unittest.skipUnless(Device.DEFAULT == "AMD", "only on AMD")
class TestCfg(unittest.TestCase):
@@ -67,9 +57,8 @@ class TestCfg(unittest.TestCase):
k.label("end")
k.emit(s_endpgm())
k.emit(s_code_end())
run_asm("diamond", k)
_, lib = assemble("diamond", k, HIPCompiler(Device[Device.DEFAULT].arch))
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
ei = run_asm("diamond", k)
cfg = amdgpu_cfg(ei.prg.p.lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]
self.assertEqual(len(cfg["blocks"]), 5)
edge_count = sum(len(v) for v in cfg["paths"].values())
self.assertEqual(edge_count, 5)
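With the separate assemble step gone, the CFG test recovers the packed ELF from the lowered ExecItem rather than compiling a second copy; the same access path works anywhere a binary is needed after lowering (a sketch, assuming the attribute chain shown in the diff):

ei = run_asm("diamond", k)  # lowered ExecItem, already run once
lib = ei.prg.p.lib          # CompiledRunner -> ProgramSpec -> ELF bytes
cfg = amdgpu_cfg(lib, Device[Device.DEFAULT].device_props()["gfx_target_version"])["data"]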

View File

@@ -21,6 +21,7 @@ from tinygrad.codegen.opt.postrange import apply_opts, pm_make_images
from tinygrad.codegen.simplify import pm_simplify_ranges, pm_flatten_range, pm_split_ranges, pm_load_collapse
from tinygrad.schedule.rangeify import pm_add_buffers_local, rangeify_codegen, pm_mops, pm_syntactic_sugar
from tinygrad.codegen.late.linearizer import CFGContext, pm_split_ends, pm_add_control_flow, linearize
from tinygrad.renderer.amd.elf import do_assemble_amd
def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -> UOp:
if ren is None: ren = Renderer()
@@ -144,6 +145,7 @@ def do_compile(ctx:Renderer, prg:UOp, source:UOp) -> UOp|None:
pm_to_program = PatternMatcher([
(UPat(Ops.PROGRAM, src=(UPat(Ops.SINK, name="sink"), UPat(Ops.DEVICE)), name="prg"), do_linearize),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, src=UPat(Ops.INS), name="lin")), name="prg"), do_assemble_amd),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR, name="lin")), name="prg"), do_render),
(UPat(Ops.PROGRAM, src=(UPat(), UPat(Ops.DEVICE), UPat(Ops.LINEAR), UPat(Ops.SOURCE, name="source")), name="prg"), do_compile),
])
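Note the rule ordering in pm_to_program: a single UPat as src matches every source of the node, so UPat(Ops.LINEAR, src=UPat(Ops.INS)) only fires for instruction-list programs, and since PatternMatcher applies the first matching rule, the new row has to sit above the generic UPat(Ops.LINEAR) row or INS programs would fall through to do_render.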

View File

@@ -38,7 +38,6 @@ class Estimates:
elif u.op is Ops.IF:
dont_count = dont_count.union(u.src[0].toposort())
for u in uops:
if u.op is Ops.SINK and isinstance(u.arg, KernelInfo) and u.arg.estimates is not None: return u.arg.estimates
if u.op in {Ops.LOAD, Ops.STORE}:
buf = u
while len(buf.src): buf = buf.src[0]
@@ -82,6 +81,7 @@ class ProgramSpec:
@functools.cached_property
def estimates(self) -> Estimates:
if self.ast.op is Ops.SINK and isinstance(self.ast.arg, KernelInfo) and self.ast.arg.estimates is not None: return self.ast.arg.estimates
return Estimates() if self.uops is None else Estimates.from_uops(self.uops, ignore_indexing=True)
@functools.cached_property
@@ -118,7 +118,7 @@ class ProgramSpec:
ins: list[int] = []
global_size: list[int] = [1, 1, 1]
local_size: list[int]|None = [1, 1, 1]
for u in uops:
for u in sink.toposort():
if u.op is Ops.DEFINE_VAR: _vars.append(u)
if u.op is Ops.PARAM: _globals.append(u.arg)
if u.op in (Ops.STORE, Ops.LOAD):

View File

@@ -99,7 +99,7 @@ bits = _Bits()
class BitField:
name: str | None
def __init__(self, hi: int, lo: int, default: int = 0):
def __init__(self, hi: int, lo: int, default = 0):
self.hi, self.lo, self.default, self.name, self.mask = hi, lo, default, None, (1 << (hi - lo + 1)) - 1
def __set_name__(self, owner, name: str): self.name = name
def __eq__(self, other) -> 'FixedBitField': # type: ignore[override]

View File

@@ -0,0 +1,127 @@
# minimal amdgpu elf packer
import ctypes
from tinygrad.helpers import ceildiv, round_up
from tinygrad.uop.ops import UOp, Ops
from tinygrad.runtime.autogen import amdgpu_kd, hsa, libc
from tinygrad.renderer.amd.dsl import Reg, FixedBitField
# instructions used for padding
from tinygrad.runtime.autogen.amd.rdna3.ins import s_code_end # same encoding as RDNA4
from tinygrad.runtime.autogen.amd.cdna.ins import s_nop as s_nop_cdna
def put(dst:bytearray, off:int, data:bytes) -> None:
end = off + len(data)
if end > len(dst): raise ValueError("write past end of buffer")
dst[off:end] = data
def create_elf(prg:bytes, kd:dict, arch:str) -> bytes:
is_cdna, is_rdna4 = arch == "cdna", arch == "rdna4"
padding_inst = (s_nop_cdna(0) if is_cdna else s_code_end()).to_bytes()
text = prg + padding_inst * ((hsa.AMD_ISA_ALIGN_BYTES - len(prg) % hsa.AMD_ISA_ALIGN_BYTES) % hsa.AMD_ISA_ALIGN_BYTES)
text_offset = round_up(ctypes.sizeof(libc.Elf64_Ehdr), hsa.AMD_ISA_ALIGN_BYTES)
rodata_offset = text_offset + len(text)
# ** pack rodata object
desc = amdgpu_kd.llvm_amdhsa_kernel_descriptor_t()
desc.group_segment_fixed_size = kd.get("group_segment_fixed_size", 0)
desc.private_segment_fixed_size = kd.get("private_segment_fixed_size", 0)
desc.kernarg_size = kd.get("kernarg_size", 0)
desc.kernel_code_entry_byte_offset = text_offset-rodata_offset
# rsrc1
vgpr_granule = max(0, (kd["next_free_vgpr"] + 7) // 8 - 1)
# CDNA: add 6 for VCC(2) + FLAT_SCRATCH(2) + XNACK_MASK(2)
# next_free_sgpr is unused in RDNA
# NOTE: CU mode is the default, it seems faster and simpler
sgpr_granule = max(0, ceildiv(kd["next_free_sgpr"] + 6, 8) - 1) if is_cdna else 0
desc.compute_pgm_rsrc1 = (vgpr_granule << amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WORKITEM_VGPR_COUNT_SHIFT |
sgpr_granule << amdgpu_kd.COMPUTE_PGM_RSRC1_GRANULATED_WAVEFRONT_SGPR_COUNT_SHIFT |
kd.get("float_round_mode_32", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_ROUND_MODE_32_SHIFT |
kd.get("float_round_mode_16_64", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_ROUND_MODE_16_64_SHIFT |
kd.get("float_denorm_mode_32", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_DENORM_MODE_32_SHIFT |
kd.get("float_denorm_mode_16_64", 3) << amdgpu_kd.COMPUTE_PGM_RSRC1_FLOAT_DENORM_MODE_16_64_SHIFT |
kd.get("dx10_clamp", 0 if is_rdna4 else 1) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX6_GFX11_ENABLE_DX10_CLAMP_SHIFT |
kd.get("ieee_mode", 0 if is_rdna4 else 1) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX6_GFX11_ENABLE_IEEE_MODE_SHIFT |
kd.get("fp16_overflow", 0) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX9_PLUS_FP16_OVFL_SHIFT |
(0 if is_cdna else kd.get("workgroup_processor_mode", 0)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_WGP_MODE_SHIFT |
(0 if is_cdna else kd.get("memory_ordered", 1)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_MEM_ORDERED_SHIFT |
(0 if is_cdna else kd.get("forward_progress", 0)) << amdgpu_kd.COMPUTE_PGM_RSRC1_GFX10_PLUS_FWD_PROGRESS_SHIFT)
# rsrc2
desc.compute_pgm_rsrc2 = (kd.get("enable_private_segment", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_PRIVATE_SEGMENT_SHIFT |
kd.get("user_sgpr_count", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_USER_SGPR_COUNT_SHIFT |
kd.get("system_sgpr_workgroup_id_x", 1) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_X_SHIFT |
kd.get("system_sgpr_workgroup_id_y", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Y_SHIFT |
kd.get("system_sgpr_workgroup_id_z", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_ID_Z_SHIFT |
kd.get("system_sgpr_workgroup_info", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_SGPR_WORKGROUP_INFO_SHIFT |
kd.get("system_vgpr_workitem_id", 0) << amdgpu_kd.COMPUTE_PGM_RSRC2_ENABLE_VGPR_WORKITEM_ID_SHIFT)
# rsrc3
if is_cdna:
amdhsa_accum_offset = ((kd.get("accum_offset", 4) // 4) - 1) & amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET
desc.compute_pgm_rsrc3 = amdhsa_accum_offset << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX90A_ACCUM_OFFSET_SHIFT
else:
desc.compute_pgm_rsrc3 = kd.get("shared_vgpr_count", 0) << amdgpu_kd.COMPUTE_PGM_RSRC3_GFX10_GFX11_SHARED_VGPR_COUNT_SHIFT
# kernel code properties
desc.kernel_code_properties = (kd.get("user_sgpr_dispatch_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_DISPATCH_PTR_SHIFT |
kd.get("user_sgpr_queue_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_QUEUE_PTR_SHIFT |
kd.get("user_sgpr_kernarg_segment_ptr", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_KERNARG_SEGMENT_PTR_SHIFT |
kd.get("user_sgpr_dispatch_id", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_DISPATCH_ID_SHIFT |
kd.get("user_sgpr_private_segment_size",0) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_SGPR_PRIVATE_SEGMENT_SIZE_SHIFT |
kd.get("wavefront_size32", 0 if is_cdna else 1) << amdgpu_kd.KERNEL_CODE_PROPERTY_ENABLE_WAVEFRONT_SIZE32_SHIFT |
kd.get("uses_dynamic_stack", 0) << amdgpu_kd.KERNEL_CODE_PROPERTY_USES_DYNAMIC_STACK_SHIFT)
rodata = bytes(desc)
# ** pack elf sections
sh_names:list[int] = []
strtab = bytearray(b"\x00")
for name in [".text", ".rodata", ".strtab"]:
sh_names.append(len(strtab))
strtab += name.encode("ascii") + b"\x00"
rodata_offset = round_up(text_offset+(text_size:=len(text)), hsa.AMD_KERNEL_CODE_ALIGN_BYTES)
strtab_offset = rodata_offset+(rodata_size:=len(rodata))
shdr_offset = strtab_offset+(strtab_size:=len(strtab))
sections = [(libc.SHT_PROGBITS, libc.SHF_ALLOC | libc.SHF_EXECINSTR, text_offset, text_offset, text_size),
(libc.SHT_PROGBITS, libc.SHF_ALLOC, rodata_offset, rodata_offset, rodata_size),
(libc.SHT_STRTAB, 0, 0, strtab_offset, strtab_size)]
shdrs = (libc.Elf64_Shdr * len(sections))()
for i,s in enumerate(sections): shdrs[i] = libc.Elf64_Shdr(sh_names[i], *s)
ehdr = libc.Elf64_Ehdr()
ehdr.e_shoff, ehdr.e_shnum, ehdr.e_shstrndx = shdr_offset, len(sections), 2
elf = bytearray(shdr_offset + ctypes.sizeof(shdrs))
put(elf, 0, bytes(ehdr))
put(elf, text_offset, text)
put(elf, rodata_offset, rodata)
put(elf, strtab_offset, strtab)
put(elf, shdr_offset, bytes(shdrs))
return bytes(elf)
_arch_map = {"gfx9": "cdna", "gfx10": "rdna3", "gfx11": "rdna3", "gfx12": "rdna4"}
def do_assemble_amd(ctx, prg:UOp, lin:UOp) -> UOp:
insts = [u.arg for u in lin.src]
# scan for max vgpr/sgpr
max_vgpr, max_sgpr = 0, 0
for inst in insts:
for name, field in inst._fields:
if isinstance(field, FixedBitField): continue
val = getattr(inst, name)
if not isinstance(val, Reg): continue
if 256 <= val.offset < 512: max_vgpr = max(max_vgpr, (val.offset - 256) + val.sz)
elif val.offset < 106: max_sgpr = max(max_sgpr, val.offset + val.sz)
# scan sink for metadata
sink, n_bufs, n_vars, lds_size, gids = prg.src[0], 0, 0, 0, set()
for u in sink.toposort():
if u.op is Ops.PARAM: n_bufs += 1
elif u.op is Ops.DEFINE_VAR: n_vars += 1
elif u.op is Ops.DEFINE_LOCAL: lds_size += u.ptrdtype.size * u.ptrdtype.base.itemsize
elif u.op is Ops.SPECIAL and u.arg.startswith("gidx"): gids.add(int(u.arg[-1]))
src = "\n".join(str(inst) for inst in insts)
code_bytes = b"".join(inst.to_bytes() for inst in insts)
arch = next(v for k, v in _arch_map.items() if ctx.arch.startswith(k))
kd = {"kernarg_size":n_bufs*8+n_vars*4, "group_segment_fixed_size":lds_size,
"user_sgpr_kernarg_segment_ptr":1, "user_sgpr_count":2,
"system_sgpr_workgroup_id_x":int(0 in gids), "system_sgpr_workgroup_id_y":int(1 in gids), "system_sgpr_workgroup_id_z":int(2 in gids),
"next_free_vgpr":round_up(max_vgpr, 8), "next_free_sgpr":round_up(max_sgpr, 8)}
binary = create_elf(code_bytes, kd, arch)
return prg.replace(src=prg.src[:3]+(UOp(Ops.SOURCE, arg=src), UOp(Ops.BINARY, arg=binary)))
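The one piece of the descriptor that is arithmetic rather than a straight bit copy is the register granulation; a hypothetical standalone version of the rules create_elf applies above, with constants taken from the comments in the diff:

from tinygrad.helpers import ceildiv

def granules(next_free_vgpr:int, next_free_sgpr:int, is_cdna:bool) -> tuple[int, int]:
  # VGPRs allocate in blocks of 8; the field holds blocks-1
  vgpr_granule = max(0, (next_free_vgpr + 7) // 8 - 1)
  # CDNA reserves 6 extra SGPRs for VCC(2) + FLAT_SCRATCH(2) + XNACK_MASK(2);
  # RDNA ignores the field, so it packs 0 there
  sgpr_granule = max(0, ceildiv(next_free_sgpr + 6, 8) - 1) if is_cdna else 0
  return vgpr_granule, sgpr_granule

# e.g. the matmul kernel's next_free_vgpr of 179 gives a granule of 22 on RDNA3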

View File

@@ -23,7 +23,7 @@ def elf_loader(blob:bytes, force_section_align:int=1, link_libs:list[str]|None=N
def _to_carray(sh, ctype): return (ctype * (sh.header.sh_size // sh.header.sh_entsize)).from_buffer_copy(sh.content)
rel = [(sh, sh.name[4:], _to_carray(sh, libc.Elf64_Rel)) for sh in sections if sh.header.sh_type == libc.SHT_REL]
rela = [(sh, sh.name[5:], _to_carray(sh, libc.Elf64_Rela)) for sh in sections if sh.header.sh_type == libc.SHT_RELA]
symtab = [_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB][0]
symtab = next((_to_carray(sh, libc.Elf64_Sym) for sh in sections if sh.header.sh_type == libc.SHT_SYMTAB), None)
progbits = [sh for sh in sections if sh.header.sh_type == libc.SHT_PROGBITS]
# Prealloc image for all fixed addresses.
@@ -39,7 +39,7 @@ def elf_loader(blob:bytes, force_section_align:int=1, link_libs:list[str]|None=N
for sh, trgt_sh_name, c_rels in rel + rela:
if trgt_sh_name == ".eh_frame": continue
target_image_off = next(tsh for tsh in sections if tsh.name == trgt_sh_name).header.sh_addr
rels = [(r.r_offset, symtab[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
rels = [(r.r_offset, unwrap(symtab)[libc.ELF64_R_SYM(r.r_info)], libc.ELF64_R_TYPE(r.r_info), getattr(r, "r_addend", 0)) for r in c_rels]
relocs += [(target_image_off + roff, link_sym(_strtab(sh_strtab, sym.st_name), link_libs or []) if sym.st_shndx == 0 else
sections[sym.st_shndx].header.sh_addr + sym.st_value, rtype, raddend) for roff, sym, rtype, raddend in rels]

View File

@@ -76,6 +76,9 @@ class Ops(FastEnum):
# CUSTOM/CUSTOMI are used to output strings into codegen. the I makes the string inline
CUSTOM = auto(); CUSTOMI = auto()
# INS is a machine instruction
INS = auto()
# ** 6 -- ops that don't exist in programs **
# tensor graph ops

View File

@@ -209,7 +209,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
# late ops don't have shape
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT | Ops.SINK | \
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY:
Ops.LINEAR | Ops.PROGRAM | Ops.SOURCE | Ops.BINARY | Ops.INS:
return None
case Ops.INDEX:

View File

@@ -177,6 +177,9 @@ shared_codegen_spec = PatternMatcher([
# CUSTOM (inline and non inline)
(UPat((Ops.CUSTOMI, Ops.CUSTOM)), lambda: True),
# assembly instruction
(UPat(Ops.INS), lambda: True),
# INDEX (2-arg and 3-arg with bool gate)
(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf").index(UPat.var("idx")), validate_index),
(UPat(Ops.INDEX, src=(UPat(GroupOp.Defines|{Ops.AFTER}, name="buf"), UPat.var("idx"), UPat.var("gate", dtype=dtypes.bool))), validate_index),

View File

@@ -186,6 +186,7 @@ const waveColor = (op) => {
: op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op;
ret = WAVE_COLORS[cat] ?? "#ffffff";
if (op.includes("OTHER_") || op.includes("_ALT")) { ret = darkenHex(ret, 75) }
if (op.includes("LDS_")) { ret = darkenHex(ret, 25) }
return ret
};
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),

View File

@@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff",
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.INS: "#eec4ff",
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F", Ops.SOURCE: "#c0c0c0", Ops.LINEAR: "#808080", Ops.BINARY: "#404040",