diff --git a/extra/gemm/asm/cdna/test.py b/extra/gemm/asm/cdna/test.py index 80ad5d9123..1370f5879b 100644 --- a/extra/gemm/asm/cdna/test.py +++ b/extra/gemm/asm/cdna/test.py @@ -3,6 +3,7 @@ import pathlib from tinygrad import Tensor, Device, dtypes, Context from tinygrad.uop.ops import UOp, Ops, KernelInfo +from tinygrad.engine.realize import Estimates from tinygrad.helpers import getenv fp = pathlib.Path(__file__).parent/"gemm.s" @@ -44,7 +45,7 @@ def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp: sz = UOp.variable("SZ", 256, 8192) - sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm")) + sink = UOp.sink(C.base, A.base, B.base, sz, lidx, gidx, arg=KernelInfo(name="gemm", estimates=Estimates(ops=N*N*N*2, mem=N*N*4*3))) return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src))) C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]