feat: tk fa in tensor (#13580)

This commit is contained in:
wozeparrot
2025-12-05 14:36:29 -08:00
committed by GitHub
parent cb4c6324ef
commit 93f1baca77
3 changed files with 208 additions and 2 deletions

View File

@@ -1,8 +1,9 @@
import unittest, math
import unittest, math, time
from tinygrad import Tensor, Device, dtypes, Context
from tinygrad.uop.ops import UOp, Ops
from tinygrad.engine.realize import ExecItem, get_runner
from tinygrad.engine.jit import TinyJit
from tinygrad.helpers import CI
import numpy as np
@@ -629,7 +630,12 @@ class TestTK(unittest.TestCase):
Tensor.realize(q, k, v, out)
ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
for _ in range(5): ei.run(wait=True)
for _ in range(5):
et = ei.run(wait=True)
attn_flops = 2 * B * H * N * N * D + \
4 * B * H * N * N + \
2 * B * H * N * N * D
print(f"{attn_flops/(et*1e9):2f} GFLOPS")
out = out.float()
q_permuted = q.permute(0, 2, 1, 3)
@@ -640,5 +646,34 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
def test_fast_fa(self):
from extra.thunder.tiny.fa import flash_attention
B, N, H, H_KV, D = 2, 8192, 32, 8, 128
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
Tensor.realize(q, k, v)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
fa_jitted = TinyJit(flash_attention)
for _ in range(10):
st = time.perf_counter()
out = fa_jitted(q, k, v, is_causal=True)
et = time.perf_counter() - st
attn_flops = 2 * B * H * N * N * D + \
4 * B * H * N * N + \
2 * B * H * N * N * D
print(f"{attn_flops/(et*1e9):2f} GFLOPS")
out = out.float().transpose(1, 2)
ref = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True).float().transpose(1, 2)
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
if __name__ == "__main__":
unittest.main()