mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
fa: split kv bwd into 2 kernels (#13981)
This commit is contained in:
@@ -45,7 +45,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
|
||||
q_reg_fl = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
|
||||
q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
q_reg_transposed = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
@@ -69,9 +68,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
scale_vec = warp.ones(scale_vec)
|
||||
|
||||
# load q tile
|
||||
q_reg_fl = warp.load(q_reg_fl, q, (), (batch, q_seq, head, 0), axis=1)
|
||||
q_reg_fl *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
|
||||
q_reg = warp.copy(q_reg, q_reg_fl)
|
||||
q_reg = warp.load(q_reg, q, (), (batch, q_seq, head, 0), axis=1)
|
||||
q_reg_transposed = warp.transpose(q_reg_transposed, q_reg)
|
||||
|
||||
for kv_idx in ker.range(N // KV_BLOCK_SIZE):
|
||||
@@ -85,6 +82,7 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
att_block = warp.zero(att_block.after(kv_idx))
|
||||
k_reg_transposed = warp.transpose(k_reg_transposed, k_reg)
|
||||
att_block = warp.mma_AtB(att_block, k_reg_transposed, q_reg_transposed)
|
||||
att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
|
||||
|
||||
# apply attention mask
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
|
||||
@@ -217,11 +215,11 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
|
||||
return ker.finish()
|
||||
|
||||
def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
def custom_backward_k(dku:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
with Kernel("fa_custom_backward_k", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
dk, do, q, k, v, mask = GL(dku, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
|
||||
|
||||
head_kv = ker.blockIdx_x
|
||||
@@ -242,7 +240,6 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
|
||||
dk_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
|
||||
dv_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
|
||||
do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
do_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
|
||||
|
||||
@@ -256,6 +253,98 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
|
||||
dk_reg = warp.zero(dk_reg)
|
||||
|
||||
# load kv tile
|
||||
k_reg = warp.load(k_reg, k, (), (batch, kv_seq, head_kv, 0), axis=1)
|
||||
k_reg_t = warp.transpose(k_reg_t, k_reg)
|
||||
v_reg = warp.load(v_reg, v, (), (batch, kv_seq, head_kv, 0), axis=1)
|
||||
|
||||
for q_idx in ker.range(N // Q_BLOCK_SIZE):
|
||||
for g in ker.range(GROUP_SIZE):
|
||||
head_q = head_kv * GROUP_SIZE + g
|
||||
|
||||
# load q and do
|
||||
q_smem = warp.load(q_smem, q, (), (batch, q_idx, head_q, 0), axis=1)
|
||||
do_smem = warp.load(do_smem, do, (), (batch, q_idx, head_q, 0), axis=1)
|
||||
|
||||
q_reg = warp.load(q_reg, q_smem)
|
||||
q_reg_t = warp.transpose(q_reg_t, q_reg)
|
||||
q_reg_col = warp.load(q_reg_col, q_smem)
|
||||
do_reg = warp.load(do_reg, do_smem)
|
||||
do_reg_col = warp.load(do_reg_col, do_smem)
|
||||
|
||||
# load l_vec and delta_vec
|
||||
l_vec_reg = warp.load(l_vec_reg, l_vec, (), (batch, head_q, 0, q_idx), axis=2)
|
||||
l_vec_reg *= 1.0 / math.log(2)
|
||||
delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head_q, 0, q_idx), axis=2)
|
||||
|
||||
# mma qk^t
|
||||
att_block = warp.zero(att_block.after(g))
|
||||
att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
|
||||
|
||||
# apply attention mask
|
||||
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
|
||||
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
|
||||
att_block += mask_reg_transposed
|
||||
|
||||
att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
|
||||
att_block -= l_vec_reg
|
||||
att_block = att_block.exp2()
|
||||
|
||||
dp_block = warp.zero(dp_block.after(g, q_idx))
|
||||
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
|
||||
dp_block -= delta_vec_reg
|
||||
att_block *= dp_block
|
||||
|
||||
att_block_mma = warp.copy(att_block_mma, att_block)
|
||||
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
|
||||
att_smem = warp.store(att_smem, att_block_transposed)
|
||||
att_block_row = warp.load(att_block_row, att_smem)
|
||||
dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
|
||||
dk_reg = ker.endrange(2)
|
||||
|
||||
dk_reg *= 1.0 / math.sqrt(D)
|
||||
|
||||
dk = warp.store(dk, dk_reg, (batch, kv_seq, head_kv, 0), axis=1)
|
||||
|
||||
return ker.finish()
|
||||
|
||||
def custom_backward_v(dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
|
||||
with Kernel("fa_custom_backward_v", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
|
||||
warp = ker.warp
|
||||
|
||||
dv, do, q, k, v, mask = GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
|
||||
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
|
||||
|
||||
head_kv = ker.blockIdx_x
|
||||
batch = ker.blockIdx_z
|
||||
kv_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
|
||||
|
||||
q_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
do_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
att_smem = ker.st((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
|
||||
|
||||
q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
q_reg_t = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
q_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
|
||||
k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
k_reg_t = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
|
||||
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
|
||||
dv_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
|
||||
do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
|
||||
do_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
|
||||
|
||||
att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
|
||||
att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
att_block_transposed = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
|
||||
att_block_row = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
|
||||
|
||||
l_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
|
||||
|
||||
dv_reg = warp.zero(dv_reg)
|
||||
|
||||
# load kv tile
|
||||
@@ -299,27 +388,12 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
|
||||
att_smem = warp.store(att_smem, att_block_transposed)
|
||||
att_block_row = warp.load(att_block_row, att_smem)
|
||||
dv_reg_ = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
|
||||
dv_reg = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
|
||||
dv_reg = ker.endrange(2)
|
||||
|
||||
dp_block = warp.zero(dp_block.after(g, q_idx, dv_reg_))
|
||||
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
|
||||
dp_block -= delta_vec_reg
|
||||
att_block *= dp_block
|
||||
|
||||
att_block_mma = warp.copy(att_block_mma, att_block)
|
||||
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
|
||||
att_smem = warp.store(att_smem, att_block_transposed)
|
||||
att_block_row = warp.load(att_block_row, att_smem)
|
||||
dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
|
||||
dk_reg = ker.endrange(2)
|
||||
dv_reg = dv_reg.after(dk_reg)
|
||||
|
||||
dk_reg *= 1.0 / math.sqrt(D)
|
||||
|
||||
dk = warp.store(dk, dk_reg, (batch, kv_seq, head_kv, 0), axis=1)
|
||||
dv = warp.store(dv, dv_reg, (batch, kv_seq, head_kv, 0), axis=1)
|
||||
|
||||
return ker.finish(2)
|
||||
return ker.finish()
|
||||
|
||||
if is_causal:
|
||||
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
|
||||
@@ -339,10 +413,11 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
|
||||
grad_v = Tensor.empty_like(v := Tensor(kernel.src[4]))
|
||||
mask = Tensor(kernel.src[5])
|
||||
|
||||
delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
|
||||
delta_vec = (grad * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
|
||||
|
||||
grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
|
||||
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
|
||||
grad_k = Tensor.custom_kernel(grad_k, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_k)[0]
|
||||
grad_v = Tensor.custom_kernel(grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_v)[0]
|
||||
return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)
|
||||
|
||||
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]
|
||||
|
||||
@@ -24,7 +24,7 @@ class Group:
|
||||
|
||||
# ops that only work on a single warp
|
||||
|
||||
clear_rid = 1000
|
||||
clear_rid = 1000000
|
||||
def clear(self, reg:ALL_TILES, value:float=0):
|
||||
reg = cast(UOp, reg)
|
||||
assert self.warps == 1
|
||||
@@ -41,7 +41,7 @@ class Group:
|
||||
def ones(self, reg:ALL_TILES): return self.clear(reg, 1)
|
||||
def neg_inf(self, reg:ALL_TILES): return self.clear(reg, -math.inf)
|
||||
|
||||
copy_rid = 300
|
||||
copy_rid = 3000000
|
||||
def copy(self, dst:ALL_TILES, src:ALL_TILES):
|
||||
dst, src = cast(UOp, dst), cast(UOp, src)
|
||||
assert self.warps == 1
|
||||
|
||||
Reference in New Issue
Block a user