From 80b0119cef77211209f85d4da4ea8fcbb47cfa39 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 9 Feb 2026 23:34:29 +0800 Subject: [PATCH] llama: add new asm gemm shape (#14611) * llama: add new asm gemm shape * work * cleanup * half dtype * more comment --- extra/gemm/asm/cdna/asm.py | 5 +++++ extra/gemm/asm/cdna/gemm.py | 14 +++++++++----- test/testextra/test_asm_gemm.py | 32 ++++++++++++++++++++++++-------- 3 files changed, 38 insertions(+), 13 deletions(-) diff --git a/extra/gemm/asm/cdna/asm.py b/extra/gemm/asm/cdna/asm.py index 7d3785f704..ed05646579 100644 --- a/extra/gemm/asm/cdna/asm.py +++ b/extra/gemm/asm/cdna/asm.py @@ -14,7 +14,12 @@ GEMM_ARGS = { (8192, 8192, 8192): (256, 128, 131072), (4096, 4096, 4096): (256, 64, 16384), (4096, 14336, 4096): (256, 64, 57344), + (4096, 14336, 8192): (256, 128, 114688), (4096, 4096, 14336): (256, 224, 57344), + (14336, 4096, 8192): (256, 128, 114688), + (4096, 8192, 14336): (256, 224, 114688), + (4096, 4096, 8192): (256, 128, 32768), + (4096, 8192, 4096): (256, 64, 32768), } ITERS_ARGS = {64: (67108864, 0), 128: (33554432, 0), 224: (613566757, 2147483656)} diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py index 275f5d2113..d276ed7a1a 100644 --- a/extra/gemm/asm/cdna/gemm.py +++ b/extra/gemm/asm/cdna/gemm.py @@ -30,13 +30,13 @@ atexit.register(lambda: print(f'asm_gemm: {counters["used"]} used, {len(counters def can_use_asm_gemm(a:Tensor, b:Tensor) -> bool: if a.dtype != b.dtype: return todo(f"dtypes must match {a.dtype} != {b.dtype}") if a.dtype not in {dtypes.bfloat16, dtypes.float16}: return todo(f"only bfloat16/float16, got {a.dtype}") - # only sharding on the batch is tested, others might work too - if isinstance(a.device, tuple) and not (a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None): - return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}") batch, M, K = (1, *a.shape) if a.ndim == 2 else a.shape N = b.shape[1] + # only sharding on the batch or K is tested, others might work too if isinstance(a.device, tuple): - batch //= len(a.device) + if a.ndim == 2 and a.uop.axis == 1 and b.uop.axis == 0: K //= len(a.device) + elif a.ndim == 3 and a.uop.axis == 0 and b.uop.axis is None: batch //= len(a.device) + else: return todo(f"sharding mismatch a.ndim={a.ndim} a.uop.axis={a.uop.axis} b.uop.axis={b.uop.axis}") dname = a.device[0] else: dname = a.device arch = getattr(Device[dname].renderer, "arch", "") @@ -65,6 +65,8 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp): out, a, b = kernel.src[1:] assert all_same([gradient.device, a.device, b.device, out.device]) a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device) + # TODO: this needs to be cleaned up and done properly, the batch dim of grad and a multi need to align + g_t = g_t[:a.shape[0]] grad_a = (g_t @ b_t.T).uop grad_b = (a_t.permute(2, 0, 1).reshape(a_t.shape[2], -1) @ g_t.reshape(-1, g_t.shape[-1])).uop return (None, grad_a, grad_b) @@ -80,9 +82,10 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: batch, M, K = a.shape N = b.shape[1] is_multi = isinstance(a.device, tuple) + if (k_sharded:=is_multi and a.uop.axis == 2): K //= len(a.device) if is_multi: - out = Tensor(Tensor.empty(batch//len(a.device), M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device) + out = Tensor(Tensor.empty(batch//len(a.device) if a.uop.axis==0 else batch, M, N, dtype=a.dtype, device=a.device).uop.multi(0), device=a.device) else: out = Tensor.empty(batch, M, N, dtype=a.dtype, device=a.device) @@ -93,4 +96,5 @@ def asm_gemm(a:Tensor, b:Tensor) -> Tensor: out = Tensor.custom_kernel(out, a, b, fxn=functools.partial(custom_asm_gemm, dname=dname, wg=numWG, arch=arch), grad_fxn=custom_gemm_bw)[0] else: out = Tensor.custom_kernel(out, a, b, fxn=custom_uop_gemm, grad_fxn=custom_gemm_bw)[0] + if k_sharded: out = out.sum(0) return out.squeeze(0) if squeeze else out diff --git a/test/testextra/test_asm_gemm.py b/test/testextra/test_asm_gemm.py index e1dac741c7..fd6995435b 100644 --- a/test/testextra/test_asm_gemm.py +++ b/test/testextra/test_asm_gemm.py @@ -9,24 +9,24 @@ from test.helpers import needs_second_gpu # Use NULL=1 EMULATE=AMD_CDNA4 to also test the assembly def is_cdna4(): return getattr(Device[Device.DEFAULT].renderer, "arch", "").startswith("gfx950") -def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None: +def run_asm_gemm(a_shape, b_shape, dtype=dtypes.float16, a_shard=None, b_shard=None, gpus:int=1) -> None: Tensor.manual_seed(0) - a_rand = Tensor.randn((batch, M, K), dtype=dtypes.float).sub(0.5).cast(dtype) - b_rand = Tensor.randn((K, N), dtype=dtypes.float).sub(0.5).cast(dtype) + a_rand = Tensor.randn(a_shape, dtype=dtypes.float).sub(0.5).cast(dtype) + b_rand = Tensor.randn(b_shape, dtype=dtypes.float).sub(0.5).cast(dtype) with Context(DEBUG=0): Tensor.realize(a_rand, b_rand) devs = tuple(f"{Device.DEFAULT}:{i}" for i in range(gpus)) if (multi:=gpus>1) else None a, b = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype) - if multi: a, b = a.shard(devs, axis=0), b.shard(devs, axis=None) + if multi: a, b = a.shard(devs, axis=a_shard), b.shard(devs, axis=b_shard) with Context(ASM_GEMM=1): tst = asm_gemm(a, b) tst.sum().backward() Tensor.realize(tst, a.grad, b.grad) a_ref, b_ref = Tensor(a_rand.numpy(), requires_grad=True).cast(dtype), Tensor(b_rand.numpy(), requires_grad=True).cast(dtype) - if multi: a_ref, b_ref = a_ref.shard(devs, axis=0), b_ref.shard(devs, axis=None) + if multi: a_ref, b_ref = a_ref.shard(devs, axis=a_shard), b_ref.shard(devs, axis=b_shard) with Context(ASM_GEMM=0): ref = asm_gemm(a_ref, b_ref) ref.sum().backward() @@ -34,10 +34,18 @@ def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:i # no validation on the NULL device if a_rand.device.startswith("NULL"): return None + atol, rtol = (1e-2, 1e-3) with Context(DEBUG=0): - assert (tst - ref).square().max().float().item() < 1e-6, "forward mismatch" - assert (a.grad - a_ref.grad).square().max().float().item() < 1e-3, "grad_a mismatch" - assert (b.grad - b_ref.grad).square().max().float().item() < 1e-3, "grad_b mismatch" + assert tst.allclose(ref, atol=atol, rtol=rtol), "forward mismatch" + assert a.grad.allclose(a_ref.grad, atol=atol, rtol=rtol), "grad_a mismatch" + assert b.grad.allclose(b_ref.grad, atol=atol, rtol=rtol), "grad_b mismatch" + + +def verify_asm_gemm(batch:int, M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=1) -> None: + run_asm_gemm((batch, M, K), (K, N), dtype=dtype, a_shard=0, b_shard=None, gpus=gpus) + +def verify_asm_gemm_k_sharded(M:int, N:int, K:int, dtype=dtypes.float16, gpus:int=8) -> None: + run_asm_gemm((M, K), (K, N), dtype=dtype, a_shard=1, b_shard=0, gpus=gpus) # 128x smaller than usual # uses the UOp GEMM, runs on non CDNA4 and CI @@ -50,6 +58,8 @@ class TestGemm(unittest.TestCase): def test_gemm_batched(self): verify_asm_gemm(2, 64, 32, 32) @needs_second_gpu def test_gemm_multi(self): verify_asm_gemm(2, 64, 32, 32, gpus=2) + @needs_second_gpu + def test_gemm_k_sharded(self): verify_asm_gemm_k_sharded(64, 64, 2*64, gpus=2) # uses the Asm GEMM on CDNA4 only for speed reasons class TestGemmLarge(unittest.TestCase): @@ -70,6 +80,12 @@ class TestGemmLarge(unittest.TestCase): def test_gemm6(self): verify_asm_gemm(16, 4096, 4096, 14336, dtype=dtypes.bfloat16, gpus=8) @unittest.skip("disabled, asm in this shape is slower than tinygrad") def test_gemm7(self): verify_asm_gemm(1, 8192, 128256, 4096) + def test_gemm8(self): verify_asm_gemm(1, 4096, 14336, 8192) + def test_gemm9(self): verify_asm_gemm(8, 4096, 14336, 8192, dtype=dtypes.bfloat16, gpus=8) + def test_gemm10(self): verify_asm_gemm(1, 4096, 8192, 4096) + def test_k_sharded_1(self): verify_asm_gemm_k_sharded(14336, 4096, 8*8192, gpus=8) + def test_k_sharded_2(self): verify_asm_gemm_k_sharded(4096, 14336, 8*8192, gpus=8) + def test_k_sharded_3(self): verify_asm_gemm_k_sharded(4096, 4096, 8*8192, gpus=8) def test_gemm_unsupported(self): with self.assertRaisesRegex(AssertionError, "shape not supported"): verify_asm_gemm(8, 1024, 1024, 4096, gpus=8)