test flash attention backward

This commit is contained in:
George Hotz
2025-10-17 17:28:59 +08:00
parent c9a3464f76
commit eb5070786a

View File

@@ -28,18 +28,43 @@ class TestRangeifyEdgeCase(unittest.TestCase):
res = Tensor.cat(a, c, dim=0)
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
if getenv("BIG") > 1:
# llama 8B
BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
elif getenv("BIG") > 0:
# bigger
BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
else:
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
@unittest.skipIf(CPU_LVP, "broken in LVP")
class TestPcontig(unittest.TestCase):
def test_flash_attention(self):
if getenv("BIG") > 1:
# llama 8B
BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
elif getenv("BIG") > 0:
# bigger
BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
else:
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
def test_flash_attention_bw(self):
def fa_bw():
Tensor.manual_seed(1337)
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
q.scaled_dot_product_attention(k, v).sum().backward()
return q,k,v
with Context(PCONTIG=2, DEBUG=2):
GlobalCounters.reset()
q,k,v = fa_bw()
grads = q.grad, k.grad, v.grad
Tensor.realize(*grads)
with Context(DEBUG=2):
GlobalCounters.reset()
q,k,v = fa_bw()
cmp_grads = q.grad, k.grad, v.grad
Tensor.realize(*cmp_grads)
with Context(DEBUG=0):
mses = [((x-y)**2).sum().item() for x,y in zip(grads, cmp_grads)]
mse = sum(mses)
print(f"mse: {mse}")
self.assertLessEqual(mse, 1e-6)
def test_flash_attention(self):
def fa():
Tensor.manual_seed(1337)
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]