fa: faster (#14453)

This commit is contained in:
wozeparrot
2026-02-02 21:34:17 -08:00
committed by GitHub
parent e579613b90
commit bbcd3d67a3
4 changed files with 122 additions and 31 deletions

View File

@@ -750,6 +750,35 @@ class TestTK(unittest.TestCase):
fa_jitted = TinyJit(flash_attention)
for _ in range(10):
st = time.perf_counter()
out = fa_jitted(q, k, v, is_causal=False)
et = time.perf_counter() - st
attn_flops = 2 * B * H * N * N * D + \
4 * B * H * N * N + \
2 * B * H * N * N * D
print(f"{attn_flops/(et*1e9):2f} GFLOPS")
out = out.float().transpose(1, 2)
ref = q.scaled_dot_product_attention(k, v, is_causal=False, enable_gqa=True).float().transpose(1, 2)
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
def test_fast_fa_causal(self):
from extra.thunder.tiny.fa import flash_attention
B, N, H, H_KV, D = 2, 8192, 32, 8, 128
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
Tensor.realize(q, k, v)
q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
fa_jitted = TinyJit(flash_attention)
for _ in range(10):
st = time.perf_counter()
out = fa_jitted(q, k, v, is_causal=True)
@@ -838,7 +867,7 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=6e-2, rtol=2e-2)
def test_fast_fa_bwd_causal_jitted(self):
from extra.thunder.tiny.fa import flash_attention