Files
tinygrad/test/amd/test_custom_kernel.py
2026-04-15 21:58:27 -04:00

197 lines
8.5 KiB
Python

import unittest
import functools
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.renderer import Estimates
from tinygrad.dtype import AddrSpace
from tinygrad.runtime.autogen.amd.rdna3.ins import *
import tinygrad.runtime.autogen.amd.rdna3.ins as r3
import tinygrad.runtime.autogen.amd.rdna4.ins as r4
from tinygrad.renderer.amd.dsl import s, v
from test.amd.helpers import TARGET_TO_ARCH
from extra.gemm.amd_asm_matmul import Kernel
def custom_add_one(A:UOp) -> UOp:
A = A.flatten()
assert dtypes.is_float(A.dtype.base), f"buffer dtype must be float32, got {A.dtype}"
threads = UOp.special(A.numel(), "lidx0")
insts = [
s_load_b64(s[0:1], s[0:1], soffset=NULL),
s_waitcnt_lgkmcnt(sdst=NULL, simm16=0),
v_lshlrev_b32_e32(v[0], 2, v[0]), # element offset
global_load_b32(v[1], v[0], saddr=s[0:1]),
s_waitcnt_vmcnt(sdst=NULL, simm16=0),
v_mov_b32_e32(v[2], 1.0),
v_add_f32_e32(v[1], v[1], v[2]),
global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]),
s_endpgm(),
]
sink = UOp.sink(A.base, threads, arg=KernelInfo(f"custom_add_one_{A.numel()}", estimates=Estimates(ops=A.numel(), mem=A.numel()*4*2)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_add_var(A:UOp, B:UOp) -> UOp:
A,B = A.flatten(), B.flatten()
assert A.dtype.base == dtypes.uint32, f"buffer dtype must be uint32, got {A.dtype}"
threads = UOp.special(A.numel(), "lidx0")
var = UOp.variable("var", 0, 10)
insts = [
s_load_b128(s[4:7], s[0:1]),
s_load_b32(s[8], s[0:1], offset=0x10), # all threads load the same variable
s_waitcnt_lgkmcnt(sdst=NULL, simm16=0),
v_lshlrev_b32_e32(v[0], 2, v[0]), # element offset, different per thread
global_load_b32(v[1], v[0], saddr=s[6:7]),
s_waitcnt_vmcnt(sdst=NULL, simm16=0),
v_add_nc_u32_e32(v[1], s[8], v[1]),
global_store_b32(addr=v[0], data=v[1], saddr=s[4:5]),
s_endpgm(),
]
sink = UOp.sink(A.base, B.base, var, threads, arg=KernelInfo(f"custom_add_var_{A.numel()}"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_wave_sync(A:UOp, arch:str) -> UOp:
# 4 waves across 1024 WG — enough to saturate a SIMD with many concurrent WGs
# s_sleep yields the SIMD so waves from different WGs interleave, causing barrier packet reordering
threads = UOp.special(128, "lidx0")
wg = UOp.special(1024, "gidx0")
insts = []
for _ in range(4):
insts.append(s_sleep(4))
insts += [s_barrier()] if arch == "rdna3" else [r4.s_barrier_signal(), r4.s_barrier_wait()]
insts += [s_nop(0)]*4
insts.append(s_endpgm())
sink = UOp.sink(A.base, threads, wg, arg=KernelInfo("custom_wave_sync"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_lds_sync(A:UOp, arch:str) -> UOp:
A = A.flatten()
num_threads = A.shape[0]
threads = UOp.special(num_threads, "lidx0")
wg = UOp.special(1, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
isa = r4 if arch == "rdna4" else r3
wait_kmcnt = [isa.s_wait_kmcnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
wait_dscnt = [isa.s_wait_dscnt(simm16=0)] if arch == "rdna4" else [isa.s_waitcnt_lgkmcnt(sdst=NULL, simm16=0)]
barrier = [isa.s_barrier_signal(ssrc0=-1), isa.s_barrier_wait(simm16=-1)] if arch == "rdna4" else [isa.s_barrier()]
global_store = [isa.global_store_b32(vaddr=v[6:7], saddr=s[0:1], vsrc=v[5])] if arch == "rdna4" \
else [isa.global_store_b32(addr=v[6], data=v[5], saddr=s[0:1])]
insts = [
isa.s_load_b64(s[0:1], s[0:1], soffset=NULL),
*wait_kmcnt,
isa.v_lshlrev_b32_e32(v[1], 2, v[0]),
# lds[thread_idx] = thread_idx
isa.ds_store_b32(addr=v[1], data0=v[0]),
*wait_dscnt,
*barrier,
# out[threaed_idx] = thread_idx == num_threads ? -1 : lds[thread_idx + 1]
isa.v_add_nc_u32_e32(v[2], 4, v[1]),
isa.v_cmp_gt_u32_e32(num_threads-1, v[0]),
isa.ds_load_b32(vdst=v[3], addr=v[2]),
*wait_dscnt,
isa.v_mov_b32_e32(v[4], -1),
isa.v_cndmask_b32_e32(v[5], v[4], v[3]),
isa.v_lshlrev_b32_e32(v[6], 2, v[0]),
*global_store,
isa.s_endpgm(),
]
sink = UOp.sink(A.base, lds, threads, wg, arg=KernelInfo("custom_lds_sync"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_handwritten(A:UOp, arch:str) -> UOp:
A = A.flatten()
threads = UOp.special(128, "lidx0")
wg = UOp.special(256, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=512, addrspace=AddrSpace.LOCAL), (), 'lds') # 128 * 4 bytes
k = Kernel(arch)
k.emit(r4.s_nop(0))
k.emit(r4.v_mov_b32_e32(v[1], 4))
def emit_alt():
for i in range(2):
k.emit(r4.v_mov_b32_e32(v[20+i], 4.0))
k.emit(r4.v_rcp_f32_e32(v[22+i], v[20+i]))
k.emit(r4.s_mov_b32(s[20+i], i))
k.emit(r4.s_mul_i32(s[14+i], s[12+i], 32))
def emit_wmma():
for _ in range(2):
k.emit(r4.v_wmma_f32_16x16x16_f16(v[0:7], v[8:11], v[8:11], 1))
k.label("start")
k.emit(s_mov_b32(s[1], 10))
k.label("loop")
# wmma should've overlapped here if it was a different unit?
for _ in range(2):
emit_wmma()
emit_alt()
for _ in range(8): k.emit(s_nop(1))
k.emit(s_add_u32(s[1], s[1], -1))
k.emit(s_cmp_eq_i32(s[1], 0))
k.emit(s_cbranch_scc0(), target="loop")
k.emit(r4.s_endpgm())
insts = k.finalize()
sink = UOp.sink(A.base, threads, wg, lds, arg=KernelInfo("custom_handwritten"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
def custom_data_deps(A:UOp, arch:str) -> UOp:
A = A.flatten()
threads = UOp.special(A.numel(), "lidx0")
k = Kernel(arch)
k.emit(s_load_b64(s[0:1], s[0:1], soffset=NULL))
k.emit(s_waitcnt_lgkmcnt(sdst=NULL, simm16=0))
k.emit(v_lshlrev_b32_e32(v[0], 2, v[0]))
k.emit(global_load_b32(v[1], v[0], saddr=s[0:1]))
k.emit(s_waitcnt_vmcnt(sdst=NULL, simm16=0))
k.emit(v_add_f32_e32(v[1], 1.0, v[1]))
k.emit(global_store_b32(addr=v[0], data=v[1], saddr=s[0:1]))
k.emit(s_endpgm())
insts = k.finalize()
sink = UOp.sink(A.base, threads, arg=KernelInfo("custom_data_deps"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple([UOp(Ops.INS, arg=x) for x in insts]))))
@unittest.skipUnless(Device.DEFAULT == "AMD", "requires AMD device")
class TestCustomKernel(unittest.TestCase):
def setUp(self): self.arch = TARGET_TO_ARCH[Device["AMD"].arch]
def test_simple(self):
if self.arch != "rdna3": self.skipTest("only rdna3")
a = Tensor.full((16, 16), 1.).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=custom_add_one)[0]
ei = a.schedule()[-1].lower()
self.assertEqual(ei.prg.estimates.ops, a.numel())
self.assertEqual(ei.prg.estimates.mem, a.nbytes()*2)
ei.run()
self.assertTrue((a.numpy() == 2.).all())
def test_variable(self):
if self.arch != "rdna3": self.skipTest("only rdna3")
b = Tensor.full((16, 16), 1, dtype=dtypes.uint32).contiguous().realize()
a = Tensor.zeros_like(b).contiguous().realize()
a = Tensor.custom_kernel(a, b, fxn=custom_add_var)[0]
ei = a.schedule()[-1].lower()
for i in range(4):
ei.run({"var":i})
self.assertTrue((a.numpy() == 1+i).all())
def test_lds_sync(self):
if self.arch not in ("rdna3", "rdna4"): self.skipTest("only rdna3/rdna4")
a = Tensor.empty(128, dtype=dtypes.int32).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_lds_sync, arch=self.arch))[0]
a.realize()
ref = Tensor.arange(1, 129, dtype=dtypes.int32)
ref[127] = -1
self.assertListEqual(a.tolist(), ref.tolist())
def test_handwritten(self):
if self.arch != "rdna4": self.skipTest("only tested on rdna4")
a = Tensor.empty(1024, dtype=dtypes.int32).contiguous().realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_handwritten, arch=self.arch))[0]
a.realize()
def test_data_deps(self):
if self.arch != "rdna3": self.skipTest("only tested on rdna3")
a = Tensor(np.full(32, 5.0, dtype=np.float32)).realize()
a = Tensor.custom_kernel(a, fxn=functools.partial(custom_data_deps, arch=self.arch))[0]
a.realize()
self.assertTrue((a.numpy() == 6.0).all())
if __name__ == "__main__":
unittest.main()