mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add tflops to cdna gemm custom kernel (#14281)
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
import pathlib
|
||||
from tinygrad import Tensor, Device, dtypes, Context
|
||||
from tinygrad.uop.ops import UOp, Ops, KernelInfo
|
||||
from tinygrad.engine.realize import Estimates
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
fp = pathlib.Path(__file__).parent/"gemm.s"
|
||||
@@ -44,7 +45,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
|
||||
|
||||
sz = UOp.variable("SZ", 256, 8192)
|
||||
|
||||
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm"))
|
||||
sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3)))
|
||||
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
|
||||
|
||||
C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
|
||||
|
||||
Reference in New Issue
Block a user