mi350x gemm: use Tensor.custom_kernel in asm test (#13969)

* mi350x gemm: use Tensor.custom_kernel in asm test

* A @ B for baseline
qazal
2026-01-02 18:30:50 +09:00
committed by GitHub
parent 5a1a561e0f
commit 5f52266225
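
The diff below drops the hand-built ProgramSpec/CompiledRunner/ExecItem plumbing and routes the assembly kernel through Tensor.custom_kernel instead: a function receives the output and input buffer UOps and returns an Ops.PROGRAM UOp carrying the sink, the device, the linearized sources, and the raw assembly text, so the custom kernel schedules and lowers alongside the ordinary matmul. A minimal sketch of the pattern as used in this diff; the names my_kernel, asm_src, out, a, and b are invented here for illustration:

def my_kernel(out:UOp, a:UOp, b:UOp) -> UOp:
  # launch geometry and a symbolic launch parameter
  lidx = UOp.special(THREADS_PER_WG, "lidx0")   # threads per workgroup
  gidx = UOp.special(NUM_WG, "gidx0")           # number of workgroups
  sz = UOp.variable("SZ", 256, 8192)            # symbolic size, bound at run time
  # tie buffers and variables together under a named sink
  sink = UOp.sink(out.base, a.base, b.base, sz, lidx, gidx, arg=KernelInfo(name="my_kernel"))
  # wrap the hand-written source so the scheduler treats it as a compiled program
  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT),
                               UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=asm_src)))

out = Tensor.custom_kernel(out, a, b, fxn=my_kernel)[0]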

@@ -2,10 +2,8 @@
 # VIZ=2 to profile
 import pathlib
 from tinygrad import Tensor, Device, dtypes, Context
-from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.renderer import ProgramSpec
-from tinygrad.uop.ops import track_rewrites, UOp
-from tinygrad.helpers import TracingKey, getenv
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv
 fp = pathlib.Path(__file__).parent/"gemm.s"
@@ -23,8 +21,8 @@ import torch
 torch.manual_seed(0)
 A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
 B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
-Bt = B.t().contiguous() # transpose B for the baseline gemm
-C_torch = A@Bt
+Bt = B.t().contiguous() # transpose B for the asm gemm
+C_torch = A@B
 # ** copy buffers to AMD
@@ -33,31 +31,33 @@ C_torch = A@Bt
 def from_torch(t:torch.Tensor) -> Tensor:
   return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
-C_tiny = Tensor.matmul(from_torch(A), from_torch(Bt), dtype=dtypes.float32).cast(dtypes.bfloat16)
+C_tiny = from_torch(A) @ from_torch(B)
 C_asm = Tensor.empty_like(C_tiny)
+C_asm.uop.buffer.allocate()
+# ** assembly custom kernel
+def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
+  lidx = UOp.special(THREADS_PER_WG, "lidx0")
+  gidx = UOp.special(NUM_WG, "gidx0")
+  src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
+  sz = UOp.variable("SZ", 256, 8192)
+  wg = UOp.variable("WG", 1, 1024)
+  sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
 # ** run gemms
-# baseline tinygrad
-sched = C_tiny.schedule()
-assert len(sched) == 1
-eis:list[ExecItem] = [sched[-1].lower()]
-ast = sched[-1].ast
-# assembly gemm
-@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
-def get_asm_prg() -> ProgramSpec:
-  src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
-  lib = Device[Device.DEFAULT].compiler.compile(src)
-  return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
-    globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
-eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
-  prg=CompiledRunner(get_asm_prg())))
+sched = Tensor.schedule(C_tiny, C_asm)
+eis = [si.lower() for si in sched]
 with Context(DEBUG=2):
   for ei in eis:
-    et = ei.run(wait=True)
+    et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
     print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
 # ** correctness
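
A note on the symbolic sizes in the run loop: UOp.variable("SZ", 256, 8192) declares an integer that stays symbolic through scheduling, and ei.run binds a concrete value per dispatch, the idea being that the same compiled assembly can serve any size in the declared range. A minimal sketch of that binding, reusing names from this diff (et is elapsed seconds when wait=True, per the TFLOPS arithmetic above):

sz = UOp.variable("SZ", 256, 8192)                # symbolic size, vmin=256, vmax=8192
et = ei.run({"SZ": N, "WG": NUM_WG}, wait=True)   # bind concrete values at launch
print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")  # 2*N^3 FLOPs over et seconds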