mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
feat: tk fa in tensor (#13580)
extra/thunder/tiny/fa.py  (Normal file, 166 lines added)
@@ -0,0 +1,166 @@
import math

from tinygrad import Tensor, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo

from extra.thunder.tiny.tk import WARP_THREADS
from extra.thunder.tiny.tk.kernel import Kernel
from extra.thunder.tiny.tk.tiles import GL, TileLayout

NUM_WORKERS = 1
Q_BLOCK_SIZE = 16
KV_BLOCK_SIZE = 16

def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False):
  if len(xq.shape) == 3: xq, xk, xv = xq.unsqueeze(0), xk.unsqueeze(0), xv.unsqueeze(0)

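  # work internally in (batch, seqlen, heads, head_dim) bfloat16; the original layout and dtype are restored at the end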
  odtype = xq.dtype
  xq, xk, xv = xq.transpose(1, 2).cast(dtypes.bfloat16), xk.transpose(1, 2).cast(dtypes.bfloat16), xv.transpose(1, 2).cast(dtypes.bfloat16)

  _, N_, _, D_ = xq.shape
  block_size = max(Q_BLOCK_SIZE, KV_BLOCK_SIZE)
  assert D_ % block_size == 0, f"embedding dimension must be multiple of block size, got {D_=} {block_size=}"

  # pad to multiple of block size
  xq = xq.pad(((0, 0), (0, (block_size - (xq.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))
  xk = xk.pad(((0, 0), (0, (block_size - (xk.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))
  xv = xv.pad(((0, 0), (0, (block_size - (xv.shape[1] % block_size)) % block_size), (0, 0), (0, 0)))

  B, N, H, D = xq.shape
  H_KV = xk.shape[2]
  GROUP_SIZE = H // H_KV
  print(f"Flash Attention {B=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")

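  # forward kernel: grid is (heads, query blocks, batch); each warp handles one Q_BLOCK_SIZE x D
  # query tile and streams K/V in KV_BLOCK_SIZE chunks with an online softmax, also storing the
  # per-row logsumexp into l_vec for the backward pass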
  def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, mu:UOp) -> UOp:
    with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
      warp = ker.warp

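      # GL wraps the kernel's input/output UOp buffers as global-memory layouts for tile loads/stores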
      o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(mu, ker), GL(l_vecu, ker)

      head = ker.blockIdx_x
      head_kv = head // GROUP_SIZE
      batch = ker.blockIdx_z
      q_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid

      k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
      v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)

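      # per-warp register tiles; the TileLayout.COL variants hold transposed data for the mma_AtB calls below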
      q_reg_fl = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
      q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
      q_reg_transposed = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
      k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
      k_reg_transposed = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
      v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
      o_reg = ker.rt((D, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
      o_reg_transposed = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
      att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
      att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
      mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
      mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)

      max_vec_last = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
      max_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
      norm_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)
      scale_vec = ker.rv(KV_BLOCK_SIZE, dtypes.float32)

      max_vec = warp.neg_inf(max_vec)
      norm_vec = warp.zero(norm_vec)
      o_reg = warp.zero(o_reg)
      scale_vec = warp.ones(scale_vec)

      # load q tile
      q_reg_fl = warp.load(q_reg_fl, q, (), (batch, q_seq, head, 0), axis=1)
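      # fold the 1/sqrt(D) softmax scale and the exp -> exp2 base change (1/ln 2) into q up front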
      q_reg_fl *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
      q_reg = warp.copy(q_reg, q_reg_fl)
      q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)

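      # stream over KV blocks, keeping a running row max and normalizer (online softmax)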
      for kv_idx in ker.range(N // KV_BLOCK_SIZE):
        k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
        v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)

        k_reg = warp.load(k_reg, k_smem)
        v_reg = warp.load(v_reg, v_smem)

        # mma qk^t
        att_block = warp.zero(att_block.after(kv_idx))
        k_reg_transposed = warp.transpose(k_reg_transposed, k_reg)
        att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)

        # apply attention mask
        mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
        mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
        att_block += mask_reg_transposed

        # softmax
        max_vec_last = warp.copy(max_vec_last.after(kv_idx), max_vec)
        max_vec = warp.row_reduce(max_vec.after(max_vec_last), att_block, lambda a, b: a.maximum(b), init_value=-math.inf)

        scale_vec = warp.map(scale_vec.after(max_vec_last, max_vec), lambda _, idx: max_vec_last[*idx] - max_vec[*idx])
        scale_vec = scale_vec.exp2()

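        # rescale the previously accumulated output and normalizer by exp2(old_max - new_max)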
        o_reg *= scale_vec
        norm_vec *= scale_vec

        att_block -= max_vec
        att_block = att_block.exp2()

        norm_vec = warp.row_reduce(norm_vec.after(scale_vec), att_block, lambda a, b: a + b)

        # mma av
        att_block_mma = warp.copy(att_block_mma.after(kv_idx, norm_vec), att_block)
        o_reg = warp.mma_AtB(o_reg, v_reg, att_block_mma)
      o_reg = ker.endrange()
      norm_vec = norm_vec.after(o_reg)
      max_vec = max_vec.after(o_reg)

      o_reg /= norm_vec

      o_reg_transposed = warp.transpose(o_reg_transposed, o_reg)
      o = warp.store(o, o_reg_transposed, (batch, q_seq, head, 0), (), axis=1)

      norm_vec = norm_vec.after(o)
      max_vec = max_vec.after(o)

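      # convert the running max / normalizer from the base-2 domain back to natural log; l_vec gets the per-row logsumexp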
      max_vec *= math.log(2)
      norm_vec = norm_vec.log2() * math.log(2)
      norm_vec += max_vec
      l_vec = warp.store(l_vec, norm_vec, (batch, head, 0, q_seq), (), axis=2)
      o = o.after(l_vec)

    return ker.finish()

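  # backward kernels are placeholders for now: they return empty named sinks and do not yet compute gradients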
  def custom_backward_q(out_qu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
    return UOp.sink(arg=KernelInfo(name="fa_custom_backward_q"))

  def custom_backward_kv(out_ku:UOp, out_vu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
    return UOp.sink(arg=KernelInfo(name="fa_custom_backward_kv"))

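  # build the additive float mask: causal -> lower triangular, bool mask -> 0 / -inf, no mask -> zeros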
  if is_causal:
    if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
    attn_mask = Tensor.ones((B, 1, N, N), requires_grad=False, device=xq.device, dtype=dtypes.bool).tril()
  if attn_mask is not None:
    if attn_mask.dtype == dtypes.bool: attn_mask = attn_mask.where(0, -float("inf"))
  else:
    attn_mask = Tensor.zeros((B, 1, N, N), requires_grad=False, device=xq.device, dtype=dtypes.float32)

  attn = Tensor.empty_like(xq)
  l_vec = Tensor.empty(B, H, 1, N, requires_grad=False, device=xq.device, dtype=dtypes.float32).detach()

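  # custom gradient hook: pulls q/k/v from the forward kernel's sources and dispatches the (stub) backward kernels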
  def grad(grad:UOp, kernel:UOp) -> tuple[None, None, UOp, UOp, UOp, None]:
    grad_q = Tensor.empty_like(q := Tensor(kernel.src[2]))
    grad_k = Tensor.empty_like(k := Tensor(kernel.src[3]))
    grad_v = Tensor.empty_like(v := Tensor(kernel.src[4]))
    mask = Tensor(kernel.src[5])

    delta_vec = (Tensor(grad) * attn).sum(-1).unsqueeze(-2).detach()

    print(l_vec.numpy())

    grad_q = Tensor.custom_kernel(grad_q, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
    return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)

  attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
  attn = attn[:, :N_, :, :D_]

  return attn.transpose(1, 2).cast(odtype)
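
Example usage (a minimal sketch; the shapes, head count, and causal setting below are illustrative assumptions, and running it assumes a device where the tk kernel path is supported):

    from tinygrad import Tensor
    from extra.thunder.tiny.fa import flash_attention

    # (batch, heads, seqlen, head_dim); head_dim must be a multiple of the 16-wide block size
    q = Tensor.randn(1, 8, 128, 64)
    k = Tensor.randn(1, 8, 128, 64)
    v = Tensor.randn(1, 8, 128, 64)

    out = flash_attention(q, k, v, is_causal=True)  # same shape and dtype as q
    print(out.shape)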