mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
gemm/asm: cleanup custom function args (#15007)
This commit is contained in:
@@ -11,12 +11,12 @@ from extra.gemm.asm.cdna.asm import build_kernel, TILE_M, TILE_N, TILE_K, NUM_WG
|
||||
WORKGROUP_SIZE = 256
|
||||
|
||||
@functools.cache
|
||||
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str, arch:str, wg:int) -> UOp:
|
||||
def custom_asm_gemm(C:UOp, A:UOp, B:UOp, dname:str) -> UOp:
|
||||
batch, M, K = A.shape
|
||||
K2, N = B.shape[(1 if B.ndim == 3 else 0):]
|
||||
assert K == K2
|
||||
lidx = UOp.special(WORKGROUP_SIZE, "lidx0")
|
||||
gidx = UOp.special(wg, "gidx0")
|
||||
gidx = UOp.special(NUM_WG, "gidx0")
|
||||
insts = build_kernel(batch, M, N, K, A.dtype.base)
|
||||
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=133_120, addrspace=AddrSpace.LOCAL), (), 'lds')
|
||||
sink = UOp.sink(C.base, A.base, B.base, lds, lidx, gidx,
|
||||
@@ -94,7 +94,7 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor:
|
||||
renderer = Device[a.device[0] if is_multi else a.device].renderer
|
||||
dname, arch = renderer.device, getattr(renderer, "arch", "")
|
||||
if arch.startswith("gfx950") and getenv("USE_ASM", 1):
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=NUM_WG, arch=arch), grad_fxn=custom_gemm_bw)[0]
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname), grad_fxn=custom_gemm_bw)[0]
|
||||
else:
|
||||
out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0]
|
||||
if k_sharded: out = out.sum(0)
|
||||
|
||||
Reference in New Issue
Block a user