Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
cleanups from flash attention branch (#12897)
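The diff hoists the fa() forward helper and the fa_bw() backward helper out of the individual TestPcontig tests to module level so both flash-attention tests share them; the tests then call .realize() on the result explicitly. For orientation only (not part of the commit), below is a minimal NumPy sketch of the math q.scaled_dot_product_attention(k, v) evaluates on the test shapes, assuming the default no-mask, no-dropout path; sdpa_ref is an illustrative name, not something from the repo.

import numpy as np

def sdpa_ref(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
  # q, k, v have shape (BS, HEADS, SEQLEN, EMB); plain softmax(q @ k^T / sqrt(EMB)) @ v
  scores = q @ k.swapaxes(-2, -1) / np.sqrt(q.shape[-1])  # (BS, HEADS, SEQLEN, SEQLEN)
  scores = scores - scores.max(axis=-1, keepdims=True)    # numerically stable softmax
  weights = np.exp(scores)
  weights = weights / weights.sum(axis=-1, keepdims=True)
  return weights @ v                                      # (BS, HEADS, SEQLEN, EMB)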
@@ -42,29 +42,35 @@ elif getenv("BIG") > 0:
 else:
   BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
 
+def fa():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
+  GlobalCounters.reset()
+  return q.scaled_dot_product_attention(k, v)
+
+def fa_bw():
+  Tensor.manual_seed(1337)
+  with Context(DEBUG=0):
+    q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
+    attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
+    attn_output.weight.requires_grad_().realize()
+    target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
+
+  GlobalCounters.reset()
+  attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
+  attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
+  out = attn_output(attn)
+  loss = (out - target).square().mean()
+  loss.backward()
+  #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
+  #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
+  ret = [out, q.grad, k.grad, v.grad]
+  Tensor.realize(*ret)
+  return ret
+
 @unittest.skipIf(isinstance(Device[Device.DEFAULT].renderer, (NIRRenderer, PTXRenderer)), "broken in LVP and PTX")
 class TestPcontig(unittest.TestCase):
   def test_flash_attention_bw(self):
-    def fa_bw():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0):
-        q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize().requires_grad_() for _ in range(3)]
-        attn_output = nn.Linear(HEADS*EMB, HEADS*EMB, bias=False)
-        attn_output.weight.requires_grad_().realize()
-        target = Tensor.rand(BS, SEQLEN, HEADS*EMB).contiguous().realize()
-
-      GlobalCounters.reset()
-      attn = q.scaled_dot_product_attention(k, v).contiguous().contiguous_backward()
-      attn = attn.transpose(1, 2).reshape(BS, SEQLEN, -1)
-      out = attn_output(attn)
-      loss = (out - target).square().mean()
-      loss.backward()
-      #ret = [out, Tensor.stack(q.grad, k.grad, v.grad, dim=-1)]
-      #ret = [out, Tensor.stack(q.grad, k.grad, dim=-1), v.grad]
-      ret = [out, q.grad, k.grad, v.grad]
-      Tensor.realize(*ret)
-      return ret
-
     with Context(PCONTIG=max(2, PCONTIG.value), DEBUG=2):
       grads = fa_bw()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
@@ -80,17 +86,11 @@ class TestPcontig(unittest.TestCase):
     self.assertLessEqual(mse, 1e-6)
 
   def test_flash_attention(self):
-    def fa():
-      Tensor.manual_seed(1337)
-      with Context(DEBUG=0): q,k,v = [Tensor.rand(BS, HEADS, SEQLEN, EMB).contiguous().realize() for _ in range(3)]
-      GlobalCounters.reset()
-      return q.scaled_dot_product_attention(k, v).realize()
-
     with Context(PCONTIG=2, DEBUG=2):
-      ret = fa()
+      ret = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=2):
-      cmp = fa()
+      cmp = fa().realize()
       print(f"{GlobalCounters.global_ops/1e9:.2f} GFLOPS")
     with Context(DEBUG=0):
       mse = ((cmp-ret)**2).sum().item()
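As a quick standalone sanity check outside the test harness (not part of the commit), tinygrad's SDPA can be compared against the NumPy reference sketched above; this assumes the top-level Tensor export and reuses the illustrative sdpa_ref from earlier.

import numpy as np
from tinygrad import Tensor

BS, HEADS, SEQLEN, EMB = 4, 2, 16, 8
q, k, v = [np.random.rand(BS, HEADS, SEQLEN, EMB).astype(np.float32) for _ in range(3)]
# run tinygrad's scaled dot-product attention and compare with the reference defined above
out_tg = Tensor(q).scaled_dot_product_attention(Tensor(k), Tensor(v)).numpy()
out_ref = sdpa_ref(q, k, v)  # sdpa_ref: the NumPy sketch from earlier
print(f"max abs diff: {np.abs(out_tg - out_ref).max():.2e}")  # should be small (float32 rounding)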