wozeparrot
2025-12-17 23:56:37 -08:00
committed by GitHub
parent aeb7516c8a
commit 99e667bdcd
2 changed files with 236 additions and 12 deletions


@@ -31,11 +31,11 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
GROUP_SIZE = H // H_KV
print(f"Flash Attention {B=} {N=} {H=} {D=} {H_KV=} {GROUP_SIZE=}")
-def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, mu:UOp) -> UOp:
def custom_forward(ou:UOp, l_vecu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp) -> UOp:
with Kernel("fa_custom_forward", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
warp = ker.warp
-o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(mu, ker), GL(l_vecu, ker)
o, q, k, v, mask, l_vec = GL(ou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker), GL(l_vecu, ker)
head = ker.blockIdx_x
head_kv = head // GROUP_SIZE
@@ -129,11 +129,197 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
return ker.finish()
-def custom_backward_q(out_qu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
-return UOp.sink(arg=KernelInfo(name="fa_custom_backward_q"))
def custom_backward_q(dqu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
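# standard flash attention backward for the query side: recompute P from the saved
# logsumexp, form dS = P * (dP - delta) with dP = dO @ V^T, accumulate dQ = dS @ K / sqrt(D)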
with Kernel("fa_custom_backward_q", (H, N // (Q_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
warp = ker.warp
-def custom_backward_kv(out_ku:UOp, out_vu:UOp, gradu:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
-return UOp.sink(arg=KernelInfo(name="fa_custom_backward_kv"))
dq, do, q, k, v, mask = GL(dqu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
head = ker.blockIdx_x
head_kv = head // GROUP_SIZE
batch = ker.blockIdx_z
q_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
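# k/v tiles are staged through shared memory; everything else stays in registers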
k_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
v_smem = ker.st((KV_BLOCK_SIZE, D), dtypes.bfloat16)
q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
q_reg_t = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
k_reg_t = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
k_reg_col = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
k_reg_col_t = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16)
v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
dq_reg = ker.rt((D, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
dq_reg_transposed = ker.rt((Q_BLOCK_SIZE, D), dtypes.float32)
do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
dp_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
l_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
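# dq accumulates across all kv blocks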
dq_reg = warp.zero(dq_reg)
# load q tile
q_reg = warp.load(q_reg, q, (), (batch, q_seq, head, 0), axis=1)
q_reg_t = warp.transpose(q_reg_t, q_reg)
# load do tile
do_reg = warp.load(do_reg, do, (), (batch, q_seq, head, 0), axis=1)
# load l_vec
l_vec_reg = warp.load(l_vec_reg, l_vec, (), (batch, head, 0, q_seq), axis=2)
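# rescale l_vec into the base-2 domain used by exp2 below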
l_vec_reg *= 1.0 / math.log(2)
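# delta = rowsum(dO * O), precomputed in grad()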
delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head, 0, q_seq), axis=2)
for kv_idx in ker.range(N // KV_BLOCK_SIZE):
k_smem = warp.load(k_smem, k, (), (batch, kv_idx, head_kv, 0), axis=1)
v_smem = warp.load(v_smem, v, (), (batch, kv_idx, head_kv, 0), axis=1)
k_reg = warp.load(k_reg, k_smem)
k_reg_t = warp.transpose(k_reg_t, k_reg)
k_reg_col = warp.load(k_reg_col, k_smem)
k_reg_col_t = warp.transpose(k_reg_col_t, k_reg_col)
v_reg = warp.load(v_reg, v_smem)
# mma qk^t
att_block = warp.zero(att_block.after(kv_idx))
att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
# apply attention mask
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_seq, kv_idx), axis=2)
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
att_block += mask_reg_transposed
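# recompute the softmax probabilities P^T from the saved logsumexp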
att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
att_block -= l_vec_reg
att_block = att_block.exp2()
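# dP^T = V @ dO^T, then dS^T = P^T * (dP^T - delta)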
dp_block = warp.zero(dp_block.after(kv_idx, att_block))
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
dp_block -= delta_vec_reg
att_block *= dp_block
att_block_mma = warp.copy(att_block_mma.after(att_block), att_block)
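# accumulate dQ^T += K^T @ dS^T (i.e. dQ += dS @ K)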
dq_reg = warp.mma_AB(dq_reg, k_reg_col_t, att_block_mma)
dq_reg = ker.endrange()
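# apply the softmax scale and write the dq tile back out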
dq_reg *= 1.0 / math.sqrt(D)
dq_reg_transposed = warp.transpose(dq_reg_transposed, dq_reg)
dq = warp.store(dq, dq_reg_transposed, (batch, q_seq, head, 0), axis=1)
return ker.finish()
def custom_backward_kv(dku:UOp, dvu:UOp, dou:UOp, qu:UOp, ku:UOp, vu:UOp, masku:UOp, l_vecu:UOp, delta_vecu:UOp) -> UOp:
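# key/value side of the backward: dV = P^T @ dO and dK = dS^T @ Q / sqrt(D),
# accumulated over every q block and each query head in this kv head's group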
with Kernel("fa_custom_backward_kv", (H_KV, N // (KV_BLOCK_SIZE*NUM_WORKERS), B), NUM_WORKERS * WARP_THREADS) as ker:
warp = ker.warp
dk, dv, do, q, k, v, mask = GL(dku, ker), GL(dvu, ker), GL(dou, ker), GL(qu, ker), GL(ku, ker), GL(vu, ker), GL(masku, ker)
l_vec, delta_vec = GL(l_vecu, ker), GL(delta_vecu, ker)
head_kv = ker.blockIdx_x
batch = ker.blockIdx_z
kv_seq = ker.blockIdx_y * NUM_WORKERS + ker.warpid
q_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
do_smem = ker.st((Q_BLOCK_SIZE, D), dtypes.bfloat16)
att_smem = ker.st((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
q_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
q_reg_t = ker.rt((D, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
q_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
k_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
k_reg_t = ker.rt((D, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
v_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.bfloat16)
mask_reg = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.float32)
mask_reg_transposed = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
dk_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
dv_reg = ker.rt((KV_BLOCK_SIZE, D), dtypes.float32, TileLayout.COL)
do_reg = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16)
do_reg_col = ker.rt((Q_BLOCK_SIZE, D), dtypes.bfloat16, TileLayout.COL)
dp_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
att_block = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.float32, TileLayout.COL)
att_block_mma = ker.rt((KV_BLOCK_SIZE, Q_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
att_block_transposed = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16, TileLayout.COL)
att_block_row = ker.rt((Q_BLOCK_SIZE, KV_BLOCK_SIZE), dtypes.bfloat16)
l_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
delta_vec_reg = ker.rv(Q_BLOCK_SIZE, dtypes.float32)
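# dk/dv accumulate across the q block and group loops below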
dk_reg = warp.zero(dk_reg)
dv_reg = warp.zero(dv_reg)
# load kv tile
k_reg = warp.load(k_reg, k, (), (batch, kv_seq, head_kv, 0), axis=1)
k_reg_t = warp.transpose(k_reg_t, k_reg)
v_reg = warp.load(v_reg, v, (), (batch, kv_seq, head_kv, 0), axis=1)
for q_idx in ker.range(N // Q_BLOCK_SIZE):
for g in ker.range(GROUP_SIZE):
head_q = head_kv * GROUP_SIZE + g
# load q and do
q_smem = warp.load(q_smem, q, (), (batch, q_idx, head_q, 0), axis=1)
do_smem = warp.load(do_smem, do, (), (batch, q_idx, head_q, 0), axis=1)
q_reg = warp.load(q_reg, q_smem)
q_reg_t = warp.transpose(q_reg_t, q_reg)
q_reg_col = warp.load(q_reg_col, q_smem)
do_reg = warp.load(do_reg, do_smem)
do_reg_col = warp.load(do_reg_col, do_smem)
# load l_vec and delta_vec
l_vec_reg = warp.load(l_vec_reg, l_vec, (), (batch, head_q, 0, q_idx), axis=2)
l_vec_reg *= 1.0 / math.log(2)
delta_vec_reg = warp.load(delta_vec_reg, delta_vec, (), (batch, head_q, 0, q_idx), axis=2)
# mma qk^t
att_block = warp.zero(att_block.after(g))
att_block = warp.mma_AtB(att_block, k_reg_t, q_reg_t)
# apply attention mask
mask_reg = warp.load(mask_reg, mask, (), (batch, 0, q_idx, kv_seq), axis=2)
mask_reg_transposed = warp.transpose(mask_reg_transposed, mask_reg)
att_block += mask_reg_transposed
att_block *= (1.0 / math.sqrt(D)) * (1.0 / math.log(2))
att_block -= l_vec_reg
att_block = att_block.exp2()
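# dV += P^T @ dO; P takes a round trip through shared memory to swap tile layouts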
att_block_mma = warp.copy(att_block_mma, att_block)
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
att_smem = warp.store(att_smem, att_block_transposed)
att_block_row = warp.load(att_block_row, att_smem)
dv_reg_ = warp.mma_AB(dv_reg, att_block_row, do_reg_col)
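# dP^T = V @ dO^T, then dS^T = P^T * (dP^T - delta)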
dp_block = warp.zero(dp_block.after(g, q_idx, dv_reg_))
dp_block = warp.mma_ABt(dp_block, v_reg, do_reg)
dp_block -= delta_vec_reg
att_block *= dp_block
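# dK += dS^T @ Q, with dS routed through shared memory the same way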
att_block_mma = warp.copy(att_block_mma, att_block)
att_block_transposed = warp.transpose(att_block_transposed, att_block_mma)
att_smem = warp.store(att_smem, att_block_transposed)
att_block_row = warp.load(att_block_row, att_smem)
dk_reg = warp.mma_AB(dk_reg, att_block_row, q_reg_col)
dk_reg = ker.endrange(2)
dv_reg = dv_reg.after(dk_reg)
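# only dk carries the softmax scale; write both tiles back out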
dk_reg *= 1.0 / math.sqrt(D)
dk = warp.store(dk, dk_reg, (batch, kv_seq, head_kv, 0), axis=1)
dv = warp.store(dv, dv_reg, (batch, kv_seq, head_kv, 0), axis=1)
return ker.finish(2)
if is_causal:
if attn_mask is not None: raise RuntimeError("cannot set attn_mask when is_causal=True")
@@ -146,18 +332,18 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
attn = Tensor.empty_like(xq)
l_vec = Tensor.empty(B, H, 1, N, requires_grad=False, device=xq.device, dtype=dtypes.float32).detach()
-def grad(grad:UOp, kernel:UOp) -> tuple[None, None, UOp, UOp, UOp, None]:
def grad(gradu:UOp, kernel:UOp) -> tuple[None, None, UOp, UOp, UOp, None]:
grad = Tensor(gradu)
grad_q = Tensor.empty_like(q := Tensor(kernel.src[2]))
grad_k = Tensor.empty_like(k := Tensor(kernel.src[3]))
grad_v = Tensor.empty_like(v := Tensor(kernel.src[4]))
mask = Tensor(kernel.src[5])
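# delta = rowsum(dO * O), passed to both backward kernels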
-delta_vec = (Tensor(grad) * attn).sum(-1).unsqueeze(-2).detach()
delta_vec = (grad * attn).sum(-1).transpose(1, 2).unsqueeze(-2).detach()
print(l_vec.shape, delta_vec.shape, grad.shape, attn.shape, grad_q.shape, grad_k.shape, grad_v.shape)
print(l_vec.numpy())
-grad_q = Tensor.custom_kernel(grad_q, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
-grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, Tensor(grad), q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
grad_q = Tensor.custom_kernel(grad_q, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_q)[0]
grad_k, grad_v = Tensor.custom_kernel(grad_k, grad_v, grad, q, k, v, mask, l_vec, delta_vec, fxn=custom_backward_kv)[:2]
return (None, None, grad_q.uop, grad_k.uop, grad_v.uop, None)
attn, l_vec = Tensor.custom_kernel(attn, l_vec, xq, xk, xv, attn_mask, fxn=custom_forward, grad_fxn=grad)[:2]


@@ -763,5 +763,43 @@ class TestTK(unittest.TestCase):
np.testing.assert_allclose(out.numpy(), ref.numpy(), atol=2e-2, rtol=2e-2)
def test_fast_fa_bwd(self):
from extra.thunder.tiny.fa import flash_attention
Tensor.manual_seed(42)
B, N, H, H_KV, D = 1, 32, 2, 1, 32
with Context(DEBUG=0):
q = Tensor.randn(B, N, H, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
k = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
v = Tensor.randn(B, N, H_KV, D, dtype=dtypes.bfloat16, requires_grad=True).contiguous()
Tensor.realize(q, k, v)
do = Tensor.ones(B, N, H, D, dtype=dtypes.float32).contiguous()
Tensor.realize(do)
q_, k_, v_ = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
out = flash_attention(q_, k_, v_)
out = out.float().transpose(1, 2)
out.backward(do)
Tensor.realize(q.grad, k.grad, v.grad)
with Context(DEBUG=0):
q_ref = q.detach().clone().requires_grad_(True)
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
Tensor.realize(q_ref, k_ref, v_ref)
q_ref_, k_ref_, v_ref_ = q_ref.transpose(1, 2), k_ref.transpose(1, 2), v_ref.transpose(1, 2)
ref = q_ref_.scaled_dot_product_attention(k_ref_, v_ref_)
ref = ref.float().transpose(1, 2)
ref.backward(do)
Tensor.realize(q_ref.grad, k_ref.grad, v_ref.grad)
np.testing.assert_allclose(q.grad.numpy(), q_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(v.grad.numpy(), v_ref.grad.numpy(), atol=2e-2, rtol=2e-2)
np.testing.assert_allclose(k.grad.numpy(), k_ref.grad.numpy(), atol=5e-2, rtol=2e-2)
if __name__ == "__main__":
unittest.main()