mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
grad_b uses custom gemm (#14550)
* grad_b uses custom gemm
* fix multi backward, acc is in float32
* test_gemm_batched
* square gemm

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Co-authored-by: qazal <qazal.software@gmail.com>
This commit is contained in:
@@ -66,10 +66,7 @@ def custom_gemm_bw(gradient:UOp, kernel:UOp):
assert all_same([gradient.device, a.device, b.device, out.device])
a_t, b_t, g_t = Tensor(a, device=a.device), Tensor(b, device=a.device), Tensor(gradient, device=a.device)
grad_a = (g_t @ b_t.T).uop
a_T = a_t.transpose(-2, -1)
a_T = a_T.reshape(*a_T.shape[:-1], 1, a_T.shape[-1])
g_r = g_t.reshape(*g_t.shape[:-2], 1, *g_t.shape[-2:]).transpose(-1, -2)
grad_b = (a_T * g_r).sum((-1, 0)).uop
grad_b = (a_t.permute(2, 0, 1).reshape(a_t.shape[2], -1) @ g_t.reshape(-1, g_t.shape[-1])).uop
return (None, grad_a, grad_b)
# ** main gemm function
@@ -34,8 +34,9 @@ SCALE = 128 if CI else 1
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
class TestGemm(unittest.TestCase):
def test_simple(self): verify_asm_gemm(1, N:=getenv("N", 4096)//SCALE, N//SCALE, N//SCALE, dtype=dtypes.half)
def test_simple(self): verify_asm_gemm(1, N:=(getenv("N", 4096)//SCALE), N, N, dtype=dtypes.half)
def test_gemm(self): verify_asm_gemm(1, 8192//SCALE, 4096//SCALE, 14336//SCALE)
def test_gemm_batched(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE)
def test_gemm_multi(self): verify_asm_gemm(2, 8192//SCALE, 4096//SCALE, 4096//SCALE, gpus=2)
class TestGemmLarge(unittest.TestCase):
Reference in New Issue
Block a user