test flash attention backward (#12762)

* test flash attention backward

* TODO: fix pcontig

* end ranges

* render colors

* very big

* multiout at every level

* reset ending ranges

* fix tests

* ugh
George Hotz
2025-10-17 23:15:59 +08:00
committed by GitHub
parent 33025b99f6
commit 062a6d68d7
5 changed files with 74 additions and 26 deletions
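The hunk below only shows the changed region of the test file, so the names it uses (Tensor, nn, Context, getenv, GlobalCounters, CPU_LVP) are imported or defined earlier in that file. A minimal sketch of the imports the snippet assumes, based on tinygrad's usual helper layout; CPU_LVP is a skip flag defined elsewhere in the test file and is only stubbed here:

import unittest
from tinygrad import Tensor, nn
from tinygrad.helpers import Context, GlobalCounters, getenv
CPU_LVP = False  # placeholder: the real flag is computed elsewhere in the test file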


@@ -28,29 +28,67 @@ class TestRangeifyEdgeCase(unittest.TestCase):
     res = Tensor.cat(a, c, dim=0)
     self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
 
+if getenv("BIG") > 2:
+  # llama 8B (8192)
+  BS, HEADS, SEQLEN, EMB = 4, 32, 8192, 128
+elif getenv("BIG") > 1:
+  # llama 8B
+  BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
+elif getenv("BIG") > 0:
+  # bigger
+  BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
+else:
+  BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
+
 @unittest.skipIf(CPU_LVP, "broken in LVP")
 class TestPcontig(unittest.TestCase):
-  def test_flash_attention(self):
-    if getenv("BIG") > 1:
-      # llama 8B
-      BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
-    elif getenv("BIG") > 0:
-      # bigger
-      BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
-    else:
-      BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
+  def test_flash_attention_bw(self):
+    def fa_bw():
+      Tensor.manual_seed(1337)
+      with Context(DEBUG=0):
+        q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
+        attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
+        attn_output.weight.requires_grad_().realize()
+        target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
+      GlobalCounters.reset()
+      attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
+      attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
+      out = attn_output(attn)
+      loss = (out - target).square().mean()
+      loss.backward()
+      #ret = [out, Tensor.stack(q.grad, k.grad, v.grad)]
+      ret = [out, q.grad, k.grad, v.grad]
+      Tensor.realize(*ret)
+      return ret
+    with Context(PCONTIG=2, REAL_SUBSTITUTE=1, DEBUG=2):
+      grads = fa_bw()
+      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
+    with Context(DEBUG=2):
+      cmp_grads = fa_bw()
+      print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
+    with Context(DEBUG=0):
+      mses = [((x-y)**2).sum().item() for x,y in zip(grads, cmp_grads)]
+      mse = sum(mses)
+      print(f"mse: {mse}")
+    self.assertLessEqual(mse, 1e-6)
+
+  def test_flash_attention(self):
     def fa():
       Tensor.manual_seed(1337)
       with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
       GlobalCounters.reset()
       return q.scaled_dot_product_attention(k, v).realize()
     with Context(PCONTIG=2, DEBUG=2):
       GlobalCounters.reset()
       ret = fa()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=2):
       GlobalCounters.reset()
       cmp = fa()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=0):
       mse = ((cmp-ret)**2).sum().item()
       print(f"mse: {mse}")