From 43e7eda4e7a338a3c3213136601d7320ab439f77 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Thu, 5 Feb 2026 14:22:27 +0800
Subject: [PATCH] grad_b uses custom gemm (#14550)

* grad_b uses custom gemm

* fix multi backward, acc is in float32

* test_gemm_batched

* square gemm

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal
---
 extra/gemm/asm/cdna/gemm.py     | 5 +----
 test/testextra/test_asm_gemm.py | 3 ++-
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/extra/gemm/asm/cdna/gemm.py b/extra/gemm/asm/cdna/gemm.py
index 463f38238c..a6beb7bacf 100644
--- a/extra/gemm/asm/cdna/gemm.py
+++ b/extra/gemm/asm/cdna/gemm.py
@@ -66,10 +66,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
   assert all_same([gradient.device, a.device, b.device, out.device])
   a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
   grad_a = (g_t @ b_t.T).uop
-  a_T = a_t.transpose(-2, -1)
-  a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
-  g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
-  grad_b = (a_T * g_r).sum((-1, 0)).uop
+  grad_b = (a_t.permute(2, 0, 1).reshape(a_t.shape[2], -1) @ g_t.reshape(-1, g_t.shape[-1])).uop
   return (None, grad_a, grad_b)
 
 # ** main gemm function
diff --git a/test/testextra/test_asm_gemm.py b/test/testextra/test_asm_gemm.py
index a0607e9cb0..13bb252ed4 100644
--- a/test/testextra/test_asm_gemm.py
+++ b/test/testextra/test_asm_gemm.py
@@ -34,8 +34,9 @@ SCALE = 128 if CI else 1
 
 @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
 class TestGemm(unittest.TestCase):
-  def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096)//SCALE, N//SCALE, N//SCALE, dtype=dtypes.half)
+  def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half)
   def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
+  def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE)
   def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
 
 class TestGemmLarge(unittest.TestCase):