add tflops to cdna gemm custom kernel (#14281)

This commit is contained in:
qazal
2026-01-21 22:48:28 -05:00
committed by GitHub
parent 18f408a35a
commit dfefeddeed

View File

@@ -3,6 +3,7 @@
import pathlib
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops, KernelInfo
from tinygrad.engine.realize import Estimates
from tinygrad.helpers import getenv
fp = pathlib.Path(__file__).parent/"gemm.s"
@@ -44,7 +45,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
sz = UOp.variable("SZ", 256, 8192)
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]