assembly/amd: use Tensor.custom_kernel to run assembly (#14125)

* assembly/amd: use Tensor.custom_kernel to run assembly

* PRINT_ASM=1 is DEBUG=4
This commit is contained in:
qazal
2026-01-13 18:29:25 -05:00
committed by GitHub
parent fe00682502
commit 002ea39da7

View File

@@ -11,8 +11,9 @@
import numpy as np
from pathlib import Path
from tinygrad import Tensor, Device, Context, GlobalCounters
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.helpers import getenv, colored
from tinygrad.engine.realize import Runner, Estimates, ExecItem
from tinygrad.engine.realize import Estimates
from extra.assembly.amd.dsl import s, v, VCC_LO, NULL
from extra.assembly.amd.autogen.rdna3.ins import *
@@ -590,7 +591,6 @@ def test_matmul():
print(f"Loaded stock kernel from {stock_path}")
else:
asm = build_kernel(dev.arch)
if getenv("PRINT_ASM", 0): print(asm)
binary = dev.compiler.compile(asm)
print(f"Compiled! Binary size: {len(binary)} bytes")
@@ -604,15 +604,15 @@ def test_matmul():
grid, local = (N // BLOCK_N, N // BLOCK_M, 1), (THREADS, 1, 1)
print(f"Grid: {grid}, Local: {local}")
_prg = dev.runtime("kernel", binary)
class AsmRunner(Runner):
  """Runner that dispatches the already-loaded program `_prg` on the enclosing grid/local sizes."""

  def __init__(self):
    # Display name, device and estimates are forwarded positionally to the base Runner.
    display_name = colored("kernel", "cyan")
    device_name = Device.DEFAULT
    estimates = Estimates(ops=N*N*N*2, mem=N*N*4*3)
    super().__init__(display_name, device_name, estimates)

  def __call__(self, rawbufs, var_vals, wait=False):
    """Launch `_prg` on the three buffers in `rawbufs` and return whatever `_prg` returns.

    `rawbufs` arrives ordered (c, a, b); the program itself is called with (a, b, c).
    `var_vals` is accepted but not used here.
    """
    allocated = [buf.ensure_allocated()._buf for buf in rawbufs]
    c_buf, a_buf, b_buf = allocated
    return _prg(a_buf, b_buf, c_buf, global_size=grid, local_size=local, wait=wait)
ei = ExecItem(None, [c.uop.buffer, a.uop.buffer, b.uop.buffer], prg=AsmRunner())
dname:str = Device.DEFAULT
def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp:
  """Build the Ops.PROGRAM UOp that wraps the assembly source `asm` and its compiled `binary`.

  The sink takes the bases of A, B and C followed by one special UOp per entry of the
  enclosing `grid` (named gidx0, gidx1, ...) and `local` (named lidx0, lidx1, ...).
  """
  # One special per launch dimension: grid entries first, then local entries.
  specials = [UOp.special(size, f"gidx{axis}") for axis, size in enumerate(grid)]
  specials += [UOp.special(size, f"lidx{axis}") for axis, size in enumerate(local)]

  operands = (A.base, B.base, C.base)
  info = KernelInfo(name=colored("kernel", "cyan"))
  sink = UOp.sink(*operands, *specials, arg=info)

  # PROGRAM sources, in order: sink, device, linearized list, source text, compiled binary.
  program_src = (
    sink,
    UOp(Ops.DEVICE, arg=dname),
    UOp(Ops.LINEAR, src=(*sink.src, sink)),
    UOp(Ops.SOURCE, arg=asm),
    UOp(Ops.BINARY, arg=binary),
  )
  return UOp(Ops.PROGRAM, src=program_src, arg=())
c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2]
ei = c.schedule()[0].lower()
ets = []
with Context(DEBUG=2):