From 002ea39da7beeec74c9c26b67c16e721140477c0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 13 Jan 2026 18:29:25 -0500 Subject: [PATCH] assembly/amd: use Tensor.custom_kernel to run assembly (#14125) * assembly/amd: use Tensor.custom_kernel to run assembly * PRINT_ASM=1 is DEBUG=4 --- extra/gemm/amd_asm_matmul.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/extra/gemm/amd_asm_matmul.py b/extra/gemm/amd_asm_matmul.py index 2ef09f491e..cf9f8af18a 100644 --- a/extra/gemm/amd_asm_matmul.py +++ b/extra/gemm/amd_asm_matmul.py @@ -11,8 +11,9 @@ import numpy as np from pathlib import Path from tinygrad import Tensor, Device, Context, GlobalCounters +from tinygrad.uop.ops import UOp, Ops, KernelInfo from tinygrad.helpers import getenv, colored -from tinygrad.engine.realize import Runner, Estimates, ExecItem +from tinygrad.engine.realize import Estimates from extra.assembly.amd.dsl import s, v, VCC_LO, NULL from extra.assembly.amd.autogen.rdna3.ins import * @@ -590,7 +591,6 @@ def test_matmul(): print(f"Loaded stock kernel from {stock_path}") else: asm = build_kernel(dev.arch) - if getenv("PRINT_ASM", 0): print(asm) binary = dev.compiler.compile(asm) print(f"Compiled! Binary size: {len(binary)} bytes") @@ -604,15 +604,15 @@ def test_matmul(): grid, local = (N // BLOCK_N, N // BLOCK_M, 1), (THREADS, 1, 1) print(f"Grid: {grid}, Local: {local}") - _prg = dev.runtime("kernel", binary) - class AsmRunner(Runner): - def __init__(self): - super().__init__(colored("kernel", "cyan"), Device.DEFAULT, Estimates(ops=N*N*N*2, mem=N*N*4*3)) - def __call__(self, rawbufs, var_vals, wait=False): - c_buf, a_buf, b_buf = [x.ensure_allocated()._buf for x in rawbufs] - return _prg(a_buf, b_buf, c_buf, global_size=grid, local_size=local, wait=wait) - - ei = ExecItem(None, [c.uop.buffer, a.uop.buffer, b.uop.buffer], prg=AsmRunner()) + dname:str = Device.DEFAULT + def asm_kernel(A:UOp, B:UOp, C:UOp) -> UOp: + gidxs = [UOp.special(n, f"gidx{i}") for i,n in enumerate(grid)] + lidxs = [UOp.special(n, f"lidx{i}") for i,n in enumerate(local)] + sink = UOp.sink(A.base, B.base, C.base, *gidxs, *lidxs, arg=KernelInfo(name=colored("kernel", "cyan"))) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=dname), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=asm), + UOp(Ops.BINARY, arg=binary)), arg=()) + c = Tensor.custom_kernel(a, b, c, fxn=asm_kernel)[2] + ei = c.schedule()[0].lower() ets = [] with Context(DEBUG=2):