Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
feat: tk fa in tensor (#13580)
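Judging from the diff, "tk" presumably refers to the ThunderKittens-style kernels under extra/thunder and "fa" to flash attention: the commit adds a GFLOPS printout to the existing TestTK benchmark loop and a new test_fast_fa that runs extra.thunder.tiny.fa.flash_attention under TinyJit and checks it against a scaled_dot_product_attention reference.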
@@ -1,8 +1,9 @@
-import unittest, math
+import unittest, math, time
 from tinygrad import Tensor, Device, dtypes, Context
 from tinygrad.uop.ops import UOp, Ops
 from tinygrad.engine.realize import ExecItem, get_runner
+from tinygrad.engine.jit import TinyJit
 from tinygrad.helpers import CI
 import numpy as np
 
@@ -629,7 +630,12 @@ class TestTK(unittest.TestCase):
     Tensor.realize(q, k, v, out)
 
     ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
-    for _ in range(5): ei.run(wait=True)
+    for _ in range(5):
+      et = ei.run(wait=True)
+      attn_flops = 2 * B * H * N * N * D + \
+                   4 * B * H * N * N + \
+                   2 * B * H * N * N * D
+      print(f"{attn_flops/(et*1e9):.2f} GFLOPS")
     out = out.float()
 
     q_permuted = q.permute(0, 2, 1, 3)
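For reference: ei.run(wait=True) returns the measured kernel time in seconds, and the formula charges 2*B*H*N*N*D FLOPs each for the Q@K^T and @V matmuls plus 4*B*H*N*N for the softmax, counting the full N x N score matrix. A worked example of the accounting, using the shapes from test_fast_fa below (the shapes of this first benchmark sit outside the hunk, so they are assumed here):

    # FLOP accounting from the loop above, evaluated for B=2, H=32, N=8192, D=128
    B, H, N, D = 2, 32, 8192, 128
    attn_flops = 2*B*H*N*N*D + 4*B*H*N*N + 2*B*H*N*N*D  # = 4*B*H*N*N*(D + 1)
    print(f"{attn_flops/1e12:.2f} TFLOPs per call")      # ~2.22 TFLOPs
    # the test prints attn_flops/(et*1e9), i.e. GFLOPS for a kernel time of et seconds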
@@ -640,5 +646,34 @@ class TestTK(unittest.TestCase):
 
     np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
 
+  def test_fast_fa(self):
+    from extra.thunder.tiny.fa import flash_attention
+
+    B, N, H, H_KV, D = 2, 8192, 32, 8, 128
+
+    with Context(DEBUG=0):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      Tensor.realize(q, k, v)
+
+    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+    fa_jitted = TinyJit(flash_attention)
+
+    for _ in range(10):
+      st = time.perf_counter()
+      out = fa_jitted(q, k, v, is_causal=True)
+      et = time.perf_counter() - st
+      attn_flops = 2 * B * H * N * N * D + \
+                   4 * B * H * N * N + \
+                   2 * B * H * N * N * D
+      print(f"{attn_flops/(et*1e9):.2f} GFLOPS")
+    out = out.float().transpose(1, 2)
+
+    ref = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True).float().transpose(1, 2)
+
+    np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
+
 if __name__ == "__main__":
   unittest.main()
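For context, here is a minimal sketch of the grouped-query attention that the reference line checks against. This is NOT the flash_attention kernel under test and not tinygrad's internal scaled_dot_product_attention; it assumes only public Tensor ops. With H=32 and H_KV=8, each KV head serves a group of 4 query heads:

    from tinygrad import Tensor

    def sdpa_gqa_ref(q: Tensor, k: Tensor, v: Tensor, is_causal: bool = True) -> Tensor:
      # q: (B, H, N, D); k and v: (B, H_KV, N, D), with H a multiple of H_KV
      B, H, N, D = q.shape
      H_KV = k.shape[1]
      # expand each shared KV head across its group of H // H_KV query heads
      k = k.repeat_interleave(H // H_KV, dim=1)
      v = v.repeat_interleave(H // H_KV, dim=1)
      scores = (q @ k.transpose(-2, -1)) * (D ** -0.5)  # (B, H, N, N)
      if is_causal:
        pos = Tensor.arange(N)
        keep = pos.reshape(N, 1) >= pos.reshape(1, N)   # query i attends to keys j <= i
        scores = keep.where(scores, float("-inf"))
      return scores.softmax(-1) @ v                     # (B, H, N, D)

The bf16 tolerances in the test (atol=rtol=2e-2) leave room for the different accumulation order between the fused kernel and a materialized O(N^2) reference of this kind.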