From 9f082e8e25d1a22bb330ab70ea95014aab4109a7 Mon Sep 17 00:00:00 2001
From: wozeparrot
Date: Fri, 2 Jan 2026 21:45:51 -0500
Subject: [PATCH] fa: split kv bwd into 2 kernels (#13981)

---
 extra/thunder/tiny/fa.py       | 131 ++++++++++++++++++++++++++-------
 extra/thunder/tiny/tk/group.py |   4 +-
 2 files changed, 105 insertions(+), 30 deletions(-)

diff --git a/extra/thunder/tiny/fa.py b/extra/thunder/tiny/fa.py
index 7e7bb70658..fa07f7e690 100644
--- a/extra/thunder/tiny/fa.py
+++ b/extra/thunder/tiny/fa.py
@@ -45,7 +45,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
       v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)

-      q_reg_fl = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
       q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
       q_reg_transposed = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
       k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
@@ -69,9 +68,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       scale_vec = warp.ones(scale_vec)

       # load q tile
-      q_reg_fl = warp.load(q_reg_fl, q, (), (batch, q_seq, head, 0), axis=1)
-      q_reg_fl *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
-      q_reg = warp.copy(q_reg, q_reg_fl)
+      q_reg = warp.load(q_reg, q, (), (batch, q_seq, head, 0), axis=1)
       q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)

       for kv_idx in ker.range(N // KV_BLOCK_SIZE):
@@ -85,6 +82,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
         att_block = warp.zero(att_block.after(kv_idx))
         k_reg_transposed = warp.transpose(k_reg_transposed, k_reg)
         att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
+        att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))

         # apply attention mask
         mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
@@ -217,11 +215,11 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False

       return ker.finish()

-  def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
-    with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
+  def custom_backward_k(dku:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    with Kernel("fa_custom_backward_k", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
       warp = ker.warp

-      dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
+      dk, do, q, k, v, mask = GL(dku, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
       l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)

       head_kv = ker.blockIdx_x
@@ -242,7 +240,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)

       dk_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
-      dv_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
       do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
       do_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)

@@ -256,6 +253,98 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
       delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)

       dk_reg = warp.zero(dk_reg)
+
+      # load kv tile
+      k_reg = warp.load(k_reg, k, (), (batch, kv_seq, head_kv, 0), axis=1)
+      k_reg_t = warp.transpose(k_reg_t, k_reg)
+      v_reg = warp.load(v_reg, v, (), (batch, kv_seq, head_kv, 0), axis=1)
+
+      for q_idx in ker.range(N // Q_BLOCK_SIZE):
+        for g in ker.range(GROUP_SIZE):
+          head_q = head_kv * GROUP_SIZE + g
+
+          # load q and do
+          q_smem = warp.load(q_smem, q, (), (batch, q_idx, head_q, 0), axis=1)
+          do_smem = warp.load(do_smem, do, (), (batch, q_idx, head_q, 0), axis=1)
+
+          q_reg = warp.load(q_reg, q_smem)
+          q_reg_t = warp.transpose(q_reg_t, q_reg)
+          q_reg_col = warp.load(q_reg_col, q_smem)
+          do_reg = warp.load(do_reg, do_smem)
+          do_reg_col = warp.load(do_reg_col, do_smem)
+
+          # load l_vec and delta_vec
+          l_vec_reg = warp.load(l_vec_reg, l_vec, (), (batch, head_q, 0, q_idx), axis=2)
+          l_vec_reg *= 1.0 / math.log(2)
+          delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head_q, 0, q_idx), axis=2)
+
+          # mma qk^t
+          att_block = warp.zero(att_block.after(g))
+          att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
+
+          # apply attention mask
+          mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
+          mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
+          att_block += mask_reg_transposed
+
+          att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
+          att_block -= l_vec_reg
+          att_block = att_block.exp2()
+
+          dp_block = warp.zero(dp_block.after(g, q_idx))
+          dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
+          dp_block -= delta_vec_reg
+          att_block *= dp_block
+
+          att_block_mma = warp.copy(att_block_mma, att_block)
+          att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
+          att_smem = warp.store(att_smem, att_block_transposed)
+          att_block_row = warp.load(att_block_row, att_smem)
+          dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
+      dk_reg = ker.endrange(2)
+
+      dk_reg *= 1.0 / math.sqrt(D)
+
+      dk = warp.store(dk, dk_reg, (batch, kv_seq, head_kv, 0), axis=1)
+
+      return ker.finish()
+
+  def custom_backward_v(dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
+    with Kernel("fa_custom_backward_v", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
+      warp = ker.warp
+
+      dv, do, q, k, v, mask = GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
+      l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
+
+      head_kv = ker.blockIdx_x
+      batch = ker.blockIdx_z
+      kv_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
+
+      q_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
+      do_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
+      att_smem = ker.st((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
+
+      q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
+      q_reg_t = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      q_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
+      k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
+      k_reg_t = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
+      mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
+      mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+
+      dv_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
+      do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
+      do_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
+
+      att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
+      att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      att_block_transposed = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
+      att_block_row = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
+
+      l_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
+      delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
+
       dv_reg = warp.zero(dv_reg)

       # load kv tile
@@ -299,27 +388,12 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
           att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
           att_smem = warp.store(att_smem, att_block_transposed)
           att_block_row = warp.load(att_block_row, att_smem)
-          dv_reg_ = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
+          dv_reg = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
+      dv_reg = ker.endrange(2)

-          dp_block = warp.zero(dp_block.after(g, q_idx, dv_reg_))
-          dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
-          dp_block -= delta_vec_reg
-          att_block *= dp_block
-
-          att_block_mma = warp.copy(att_block_mma, att_block)
-          att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
-          att_smem = warp.store(att_smem, att_block_transposed)
-          att_block_row = warp.load(att_block_row, att_smem)
-          dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
-      dk_reg = ker.endrange(2)
-      dv_reg = dv_reg.after(dk_reg)
-
-      dk_reg *= 1.0 / math.sqrt(D)
-
-      dk = warp.store(dk, dk_reg, (batch, kv_seq, head_kv, 0), axis=1)
       dv = warp.store(dv, dv_reg, (batch, kv_seq, head_kv, 0), axis=1)

-      return ker.finish(2)
+      return ker.finish()

   if is_causal:
     if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
@@ -339,10 +413,11 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
     grad_v = Tensor.empty_like(v := Tensor(kernel.src[4]))
     mask = Tensor(kernel.src[5])

-    delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
+    delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()

     grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
-    grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
+    grad_k = Tensor.custom_kernel(grad_k, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_k)[0]
+    grad_v = Tensor.custom_kernel(grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_v)[0]
     return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)

   attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
diff --git a/extra/thunder/tiny/tk/group.py b/extra/thunder/tiny/tk/group.py
index 9a7391db17..9fd75579c9 100644
--- a/extra/thunder/tiny/tk/group.py
+++ b/extra/thunder/tiny/tk/group.py
@@ -24,7 +24,7 @@ class Group:


   # ops that only work on a single warp
-  clear_rid = 1000
+  clear_rid = 1000000
   def clear(self, reg:ALL_TILES, value:float=0):
     reg = cast(UOp, reg)
     assert self.warps == 1
@@ -41,7 +41,7 @@ class Group:
   def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
   def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)

-  copy_rid = 300
+  copy_rid = 3000000
   def copy(self, dst:ALL_TILES, src:ALL_TILES):
     dst, src = cast(UOp, dst), cast(UOp, src)
     assert self.warps == 1
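For reference, and not part of the patch: the two new kernels follow the standard flash-attention backward recurrences, with fa_custom_backward_v accumulating dV = P^T dO and fa_custom_backward_k accumulating dK = dS^T Q / sqrt(D), where dS = P * (dP - delta_vec). Below is a minimal per-(batch, head) NumPy sketch of that math; the helper name is hypothetical, the probabilities p are assumed to be materialized directly (the kernels instead rebuild them blockwise from l_vec), and the per-kv-head sum over GROUP_SIZE query heads is omitted.

# Illustrative reference only (hypothetical helper, not from the repo): per-(batch, head)
# NumPy version of what fa_custom_backward_v and fa_custom_backward_k accumulate,
# given the softmax probabilities p = softmax(q @ k.T / sqrt(D) + mask).
import numpy as np

def backward_kv_reference(q, k, v, do, p):
  # q, k, v, do: (N, D) float arrays; p: (N, N) attention probabilities
  D = q.shape[-1]
  dv = p.T @ do                            # dV = P^T dO            (fa_custom_backward_v)
  dp = do @ v.T                            # dP = dO V^T
  delta = np.sum(do * (p @ v), axis=-1)    # delta_vec = rowsum(dO * O), with O = P V
  ds = p * (dp - delta[:, None])           # dS = P * (dP - delta)
  dk = (ds.T @ q) / np.sqrt(D)             # dK = dS^T Q / sqrt(D)  (fa_custom_backward_k)
  return dk, dv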