mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
test flash attention backward
This commit is contained in:
@@ -28,18 +28,43 @@ class TestRangeifyEdgeCase(unittest.TestCase):
|
||||
res = Tensor.cat(a, c, dim=0)
|
||||
self.assertEqual(res.numpy()[-1, :16].tolist(), [512] * 16)
|
||||
|
||||
# Problem size shared by the attention tests, selected by the BIG environment variable.
BS, HEADS, SEQLEN, EMB = (
  (4, 32, 2048, 128) if getenv("BIG") > 1  # llama 8B
  else (4, 32, 1024, 64) if getenv("BIG") > 0  # bigger
  else (4, 2, 16, 8)
)
|
||||
|
||||
@unittest.skipIf(CPU_LVP, "broken in LVP")
|
||||
class TestPcontig(unittest.TestCase):
|
||||
def test_flash_attention(self):
|
||||
if getenv("BIG") > 1:
|
||||
# llama 8B
|
||||
BS, HEADS, SEQLEN, EMB = 4, 32, 2048, 128
|
||||
elif getenv("BIG") > 0:
|
||||
# bigger
|
||||
BS, HEADS, SEQLEN, EMB = 4, 32, 1024, 64
|
||||
else:
|
||||
BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
|
||||
def test_flash_attention_bw(self):
  """Compare q/k/v gradients of attention computed with PCONTIG=2 against a run without that override.

  Both runs seed the RNG identically, so the inputs match; the summed squared
  difference over the three gradients must not exceed 1e-6.
  """
  def attention_grads(**ctx_vars):
    # One full forward+backward pass inside the requested Context; returns the realized (q, k, v) gradients.
    with Context(**ctx_vars):
      GlobalCounters.reset()
      Tensor.manual_seed(1337)
      # input creation is done with DEBUG=0, separate from the DEBUG level of the measured run
      with Context(DEBUG=0):
        query, key, value = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
      query.scaled_dot_product_attention(key, value).sum().backward()
      result = query.grad, key.grad, value.grad
      Tensor.realize(*result)
    return result

  grads = attention_grads(PCONTIG=2, DEBUG=2)
  cmp_grads = attention_grads(DEBUG=2)

  with Context(DEBUG=0):
    # accumulate the squared error of each gradient pair in q, k, v order
    mse = 0
    for got, want in zip(grads, cmp_grads):
      mse += ((got - want) ** 2).sum().item()
  print(f"mse: {mse}")
  self.assertLessEqual(mse, 1e-6)
|
||||
|
||||
def test_flash_attention(self):
|
||||
def fa():
|
||||
Tensor.manual_seed(1337)
|
||||
with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
|
||||
|
||||
Reference in New Issue
Block a user