grad_b uses custom gemm (#14550)

* grad_b uses custom gemm

* fix multi backward, acc is in float32

* test_gemm_batched

* square gemm

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
George Hotz
2026-02-05 14:22:27 +08:00
committed by GitHub
parent f9cfb64cd9
commit 43e7eda4e7
2 changed files with 3 additions and 5 deletions

View File

@@ -66,10 +66,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop
a_T = a_t.transpose(-2, -1)
a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
grad_b = (a_T * g_r).sum((-1, 0)).uop
grad_b = (a_t.permute(2, 0, 1).reshape(a_t.shape[2], -1) @ g_t.reshape(-1, g_t.shape[-1])).uop
return (None, grad_a, grad_b)
# ** main gemm function

View File

@@ -34,8 +34,9 @@ SCALE = 128 if CI else 1
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096)//SCALE, N//SCALE, N//SCALE, dtype=dtypes.half)
def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE)
def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
class TestGemmLarge(unittest.TestCase):