diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py index 2b569f0d9e..6b7bfc7e8d 100644 --- a/extra/gemm/asm/cdna/gemm.py +++ b/extra/gemm/asm/cdna/gemm.py @@ -11,12 +11,12 @@ from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG WORKGROUP_SIZE = 256 @functools.cache -def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp: +def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp: batch, M, K = A.shape K2, N = B.shape[(1 if B.ndim == 3 else 0):] assert K == K2 lidx = UOp.special(WORKGROUP_SIZE, "lidx0") - gidx = UOp.special(wg, "gidx0") + gidx = UOp.special(NUM_WG, "gidx0") insts = build_kernel(batch, M, N, K, A.dtype.base) lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds') sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx, @@ -94,7 +94,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: renderer = Device[a.device[0] if is_multi else a.device].renderer dname, arch = renderer.device, getattr(renderer, "arch", "") if arch.startswith("gfx950") and getenv("USE_ASM", 1): - out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=NUM_WG, arch=arch), grad_fxn=custom_gemm_bw)[0] + out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0] else: out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0] if k_sharded: out = out.sum(0)