diff --git a/extra/thunder/tiny/fa.py b/extra/thunder/tiny/fa.py
new file mode 100644
index 0000000000..211b4e069f
--- /dev/null
+++ b/extra/thunder/tiny/fa.py
@@ -0,0 +1,166 @@
+import math
+
+from tinygrad import Tensor, dtypes
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+
+from extra.thunder.tiny.tk import WARP_THREADS
+from extra.thunder.tiny.tk.kernel import Kernel
+from extra.thunder.tiny.tk.tiles import GL, TileLayout
+
+NUM_WORKERS = 1
+Q_BLOCK_SIZE = 16
+KV_BLOCK_SIZE = 16
+
+def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
+  if len(xq.shape) == 3: xq, xk, xv = xq.unsqueeze(0), xk.unsqueeze(0), xv.unsqueeze(0)
+
+  odtype = xq.dtype
+  xq, xk, xv = xq.transpose(1, 2).cast(dtypes.bfloat16), xk.transpose(1, 2).cast(dtypes.bfloat16), xv.transpose(1, 2).cast(dtypes.bfloat16)
+
+  _, N_, _, D_ = xq.shape
+  block_size = max(Q_BLOCK_SIZE, KV_BLOCK_SIZE)
+  assert D_ % block_size == 0, f"embedding dimension must be multiple of block size, got {D_=} {block_size=}"
+
+  # pad to multiple of block size
+  xq = xq.pad(((0, 0), (0, (block_size - (xq.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))
+  xk = xk.pad(((0, 0), (0, (block_size - (xk.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))
+  xv = xv.pad(((0, 0), (0, (block_size - (xv.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))
+
+  B, N, H, D = xq.shape
+  H_KV = xk.shape[2]
+  GROUP_SIZE = H // H_KV
+  print(f"Flash Attention {B=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
+
+  def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, mu:UOp) -> UOp:
+    with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
+      warp = ker.warp
+
+      o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(mu, ker), GL(l_vecu, ker)
+
+      head = ker.blockIdx_x
+      head_kv = head // GROUP_SIZE
+      batch = ker.blockIdx_z
+      q_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
+
+      k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
+      v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
+
+      q_reg_fl = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
+      q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
+      q_reg_transposed = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
+      k_reg_transposed = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
+      o_reg = ker.rt((D, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+      o_reg_transposed = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
+      att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+      att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
+      mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+
+      max_vec_last = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
+      max_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
+      norm_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
+      scale_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
+
+      max_vec = warp.neg_inf(max_vec)
+      norm_vec = warp.zero(norm_vec)
+      o_reg = warp.zero(o_reg)
+      scale_vec = warp.ones(scale_vec)
+
+      # load q tile
+      q_reg_fl = warp.load(q_reg_fl, q, (), (batch, q_seq, head, 0), axis=1)
+      q_reg_fl *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
+      q_reg = warp.copy(q_reg, q_reg_fl)
+      q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)
+
+      for kv_idx in ker.range(N // KV_BLOCK_SIZE):
+        k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
+        v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
+
+        k_reg = warp.load(k_reg, k_smem)
+        v_reg = warp.load(v_reg, v_smem)
+
+        # mma qk^t
+        att_block = warp.zero(att_block.after(kv_idx))
+        k_reg_transposed = warp.transpose(k_reg_transposed, k_reg)
+        att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
+
+        # apply attention mask
+        mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
+        mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
+        att_block += mask_reg_transposed
+
+        # softmax
+        max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
+        max_vec = warp.row_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)
+
+        scale_vec = warp.map(scale_vec.after(max_vec_last, max_vec), lambda _, idx: max_vec_last[*idx] - max_vec[*idx])
+        scale_vec = scale_vec.exp2()
+
+        o_reg *= scale_vec
+        norm_vec *= scale_vec
+
+        att_block -= max_vec
+        att_block = att_block.exp2()
+
+        norm_vec = warp.row_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)
+
+        # mma av
+        att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
+        o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma)
+      o_reg = ker.endrange()
+      norm_vec = norm_vec.after(o_reg)
+      max_vec = max_vec.after(o_reg)
+
+      o_reg /= norm_vec
+
+      o_reg_transposed = warp.transpose(o_reg_transposed, o_reg)
+      o = warp.store(o, o_reg_transposed, (batch, q_seq, head, 0), (), axis=1)
+
+      norm_vec = norm_vec.after(o)
+      max_vec = max_vec.after(o)
+
+      max_vec *= math.log(2)
+      norm_vec = norm_vec.log2() * math.log(2)
+      norm_vec += max_vec
+      l_vec = warp.store(l_vec, norm_vec, (batch, head, 0, q_seq), (), axis=2)
+      o = o.after(l_vec)
+
+    return ker.finish()
+
+  def custom_backward_q(out_qu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    return UOp.sink(arg=KernelInfo(name="fa_custom_backward_q"))
+
+  def custom_backward_kv(out_ku:UOp, out_vu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    return UOp.sink(arg=KernelInfo(name="fa_custom_backward_kv"))
+
+  if is_causal:
+    if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
+    attn_mask = Tensor.ones((B, 1, N, N), requires_grad=False, device=xq.device, dtype=dtypes.bool).tril()
+  if attn_mask is not None:
+    if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
+  else:
+    attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=xq.device, dtype=dtypes.float32)
+
+  attn = Tensor.empty_like(xq)
+  l_vec = Tensor.empty(B, H, 1, N, requires_grad=False, device=xq.device, dtype=dtypes.float32).detach()
+
+  def grad(grad:UOp, kernel:UOp) -> tuple[None, None, UOp, UOp, UOp, None]:
+    grad_q = Tensor.empty_like(q := Tensor(kernel.src[2]))
+    grad_k = Tensor.empty_like(k := Tensor(kernel.src[3]))
+    grad_v = Tensor.empty_like(v := Tensor(kernel.src[4]))
+    mask = Tensor(kernel.src[5])
+
+    delta_vec = (Tensor(grad) * attn).sum(-1).unsqueeze(-2).detach()
+
+    print(l_vec.numpy())
+
+    grad_q = Tensor.custom_kernel(grad_q, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
+    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
+    return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)
+
+  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
+  attn = attn[:, :N_, :, :D_]
+
+  return attn.transpose(1, 2).cast(odtype)
diff --git a/test/testextra/test_tk.py b/test/testextra/test_tk.py
index d310c5a439..6cf9ef46f8 100644
--- a/test/testextra/test_tk.py
+++ b/test/testextra/test_tk.py
@@ -1,8 +1,9 @@
-import unittest, math
+import unittest, math, time
 from tinygrad import Tensor, Device, dtypes, Context
 from tinygrad.uop.ops import UOp, Ops
 from tinygrad.engine.realize import ExecItem, get_runner
+from tinygrad.engine.jit import TinyJit
 from tinygrad.helpers import CI
 
 import numpy as np
 
@@ -629,7 +630,12 @@ class TestTK(unittest.TestCase):
     Tensor.realize(q, k, v, out)
 
     ei = ExecItem(get_runner(Device.DEFAULT, sink), [t.uop.buffer for t in (out, q, k, v)])
-    for _ in range(5): ei.run(wait=True)
+    for _ in range(5):
+      et = ei.run(wait=True)
+      attn_flops = 2 * B * H * N * N * D + \
+                   4 * B * H * N * N + \
+                   2 * B * H * N * N * D
+      print(f"{attn_flops/(et*1e9):2f} GFLOPS")
 
     out = out.float()
     q_permuted = q.permute(0, 2, 1, 3)
@@ -640,5 +646,34 @@
 
     np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
 
+  def test_fast_fa(self):
+    from extra.thunder.tiny.fa import flash_attention
+
+    B, N, H, H_KV, D = 2, 8192, 32, 8, 128
+
+    with Context(DEBUG=0):
+      q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16).contiguous()
+      k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16).contiguous()
+      Tensor.realize(q, k, v)
+
+    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+    fa_jitted = TinyJit(flash_attention)
+
+    for _ in range(10):
+      st = time.perf_counter()
+      out = fa_jitted(q, k, v, is_causal=True)
+      et = time.perf_counter() - st
+      attn_flops = 2 * B * H * N * N * D + \
+                   4 * B * H * N * N + \
+                   2 * B * H * N * N * D
+      print(f"{attn_flops/(et*1e9):2f} GFLOPS")
+    out = out.float().transpose(1, 2)
+
+    ref = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True).float().transpose(1, 2)
+
+    np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 5dfa3e9891..da7b6d15ca 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -3712,6 +3712,11 @@ class Tensor(OpMixin):
     """
     # NOTE: it also works when `key` and `value` have symbolic shape.
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
+
+    if getenv("FLASH_ATTENTION"):
+      from extra.thunder.tiny.fa import flash_attention
+      return flash_attention(self, key, value, attn_mask=attn_mask, is_causal=is_causal)
+
     # GQA: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
     if enable_gqa:
       key = key.repeat_interleave(self.shape[-3] // key.shape[-3], dim=-3)
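
As a semantic reference for fa_custom_forward: the plain NumPy sketch below spells out what the kernel is meant to compute, softmax(QK^T / sqrt(D) + mask) @ V with GQA head grouping, plus the log-sum-exp the kernel stores in l_vec. The kernel works in base 2 (Q is pre-scaled by 1/(sqrt(D) * ln 2), exponentials use exp2, and l_vec is converted back to natural log at the end), which is equivalent for the 0/-inf additive masks built in flash_attention. The helper name reference_attention is ours, not part of the diff, and it materializes the full N x N scores, so it is only a small-size semantic check, not an implementation.

import numpy as np

def reference_attention(q, k, v, mask):
  # q: (B, N, H, D), k/v: (B, N, H_KV, D) float arrays, matching the layout inside flash_attention
  # after its transpose; mask: (B, 1, N, N) additive with entries in {0, -inf}.
  # GQA: query head h reads kv head h // GROUP_SIZE, with GROUP_SIZE = H // H_KV.
  B, N, H, D = q.shape
  H_KV = k.shape[2]
  group = H // H_KV
  out = np.zeros((B, N, H, D), dtype=np.float32)
  l_vec = np.zeros((B, H, 1, N), dtype=np.float32)
  for b in range(B):
    for h in range(H):
      s = q[b, :, h].astype(np.float32) @ k[b, :, h // group].astype(np.float32).T / np.sqrt(D)
      s = s + mask[b, 0]                       # 0/-inf, so scaling into log2 space before or after is equivalent
      m = s.max(-1, keepdims=True)             # the running max_vec, taken over all of N at once here
      p = np.exp(s - m)
      l = p.sum(-1, keepdims=True)             # the running norm_vec
      out[b, :, h] = (p / l) @ v[b, :, h // group].astype(np.float32)
      l_vec[b, h, 0] = (m + np.log(l))[:, 0]   # what the kernel stores after its log2 -> ln conversion
  return out, l_vec

Unlike this reference, the kernel keeps only the per-tile running max and denominator (the max_vec / norm_vec rescaling inside the kv_idx loop), so it never holds the full score matrix.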
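
The backward kernels are stubs in this diff (custom_backward_q and custom_backward_kv return empty sinks), but the grad wrapper already forwards l_vec and computes delta_vec = rowsum(dO * O). Assuming the stubs end up following the standard FlashAttention backward recurrence, the per-(batch, query head) math they would implement is sketched below; reference_attention_backward is a hypothetical helper for illustration, not part of the diff.

import numpy as np

def reference_attention_backward(q, k, v, do, o, lse, mask):
  # q: (N, D); k/v: (N, D) for the matching kv head; do/o: (N, D) upstream grad and forward output;
  # lse: (N,) natural log-sum-exp saved in l_vec; mask: (N, N) additive.
  D = q.shape[-1]
  s = q @ k.T / np.sqrt(D) + mask
  p = np.exp(s - lse[:, None])             # recompute the softmax from l_vec instead of storing it
  dv = p.T @ do
  dp = do @ v.T
  delta = (do * o).sum(-1, keepdims=True)  # the delta_vec term computed in grad()
  ds = p * (dp - delta)
  dq = ds @ k / np.sqrt(D)
  dk = ds.T @ q / np.sqrt(D)
  return dq, dk, dv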
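
With the tensor.py hook, the kernel is reachable through the normal scaled_dot_product_attention entry point whenever FLASH_ATTENTION is set. A minimal usage sketch, assuming a device supported by the extra/thunder/tiny/tk backend; note the hook does not forward enable_gqa, but flash_attention derives GROUP_SIZE = H // H_KV from the head counts on its own.

# Run with: FLASH_ATTENTION=1 python script.py
from tinygrad import Tensor, dtypes

B, N, H, H_KV, D = 2, 8192, 32, 8, 128
q = Tensor.randn(B, H, N, D, dtype=dtypes.bfloat16)     # SDPA layout: (B, H, N, D)
k = Tensor.randn(B, H_KV, N, D, dtype=dtypes.bfloat16)
v = Tensor.randn(B, H_KV, N, D, dtype=dtypes.bfloat16)
out = q.scaled_dot_product_attention(k, v, is_causal=True, enable_gqa=True)
print(out.shape)  # (B, H, N, D), cast back to q's original dtype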