gemm/asm: cleanup custom function args (#15007)

This commit is contained in:
qazal
2026-02-25 21:05:56 +08:00
committed by GitHub
parent c58e91942c
commit 448e997be4

View File

@@ -11,12 +11,12 @@ from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG
WORKGROUP_SIZE = 256
@functools.cache
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
batch, M, K = A.shape
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
assert K == K2
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
gidx = UOp.special(wg, "gidx0")
gidx = UOp.special(NUM_WG, "gidx0")
insts = build_kernel(batch, M, N, K, A.dtype.base)
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
@@ -94,7 +94,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
renderer = Device[a.device[0] if is_multi else a.device].renderer
dname, arch = renderer.device, getattr(renderer, "arch", "")
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=NUM_WG, arch=arch), grad_fxn=custom_gemm_bw)[0]
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
else:
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
if k_sharded: out = out.sum(0)