fa: change bwd grid dim + unshuffle using mops (#15068)

This commit is contained in:
wozeparrot
2026-03-04 17:23:40 +08:00
committed by GitHub
parent 5623cea7b1
commit df23057984
2 changed files with 23 additions and 20 deletions

View File

@@ -47,19 +47,24 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
do = Tensor(dou, device=dou.device)
dq_in = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
dq = _sharded_empty_like(xq, axis=shard_axis)
dk = _sharded_empty_like(xk, axis=shard_axis)
dv = _sharded_empty_like(xv, axis=shard_axis)
dq = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
GROUP_SIZE = H_local // H_KV_local
dk_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xk, axis=shard_axis)
dv_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xv, axis=shard_axis)
# delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
delta_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
delta_vec, dq = Tensor.custom_kernel(delta_vec, dq, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
dq, dk_partial, dv_partial = Tensor.custom_kernel(dq, dk_partial, dv_partial, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
# unshuffle dq
dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[0]
# unshuffle dq: atomic_pk_add_bf16_with_warpid creates a shuffled layout within each 16x128 tile
# decompose each tile into (j=4, a=2, b=2, d=4, e=4, k=4, c=2) and permute to (e, k, j, a, d, b, c) = standard row-major
dq = dq.reshape(B, H, N//16, 4, 2, 2, 4, 4, 4, 2).permute(0, 1, 2, 7, 8, 3, 4, 6, 5, 9).reshape(B, H, N, D).transpose(1, 2)
# reduce partial dK/dV across GROUP_SIZE query heads
dk = dk_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
return None, None, dq.uop, dk.uop, dv.uop
@@ -89,7 +94,6 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
arg=KernelInfo(name="custom_fa_forward", estimates=estimates))
lib = HIPCCCompiler(arch, compile_args).compile_cached(code)
lib = bytearray(lib)
rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
struct.pack_into('<I', lib, rodata_off, 160000)
@@ -120,7 +124,6 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
arg=KernelInfo(name="custom_fa_backward_pre", estimates=estimates))
lib = HIPCCCompiler(arch, compile_args).compile_cached(code)
lib = bytearray(lib)
rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
struct.pack_into('<I', lib, rodata_off, 160000)
@@ -138,7 +141,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
BLOCK_SIZE_KV = 256
NUM_WARPS = 4
NUM_THREADS = 64 * NUM_WARPS
gsz = (H_KV, N // BLOCK_SIZE_KV, B)
gsz = (H, N // BLOCK_SIZE_KV, B)
lsz = (NUM_THREADS, 1, 1)
threadIdx_x = UOp.special(lsz[0], "lidx0")
blockIdx_x, blockIdx_y, blockIdx_z = UOp.special(gsz[0], "gidx0"), UOp.special(gsz[1], "gidx1"), UOp.special(gsz[2], "gidx2")
@@ -151,7 +154,6 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
arg=KernelInfo(name="custom_fa_backward", estimates=estimates))
lib = HIPCCCompiler(arch, compile_args).compile_cached(code)
lib = bytearray(lib)
rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
struct.pack_into('<I', lib, rodata_off, 160000)
@@ -182,7 +184,6 @@ def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int,
arg=KernelInfo(name="custom_fa_backward_post", estimates=estimates))
lib = HIPCCCompiler(arch, compile_args).compile_cached(code)
lib = bytearray(lib)
rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
struct.pack_into('<I', lib, rodata_off, 160000)

View File

@@ -37,7 +37,7 @@ using namespace kittens;
using _gl_QdO = gl<bf16, ATTN_B, ATTN_N, ATTN_H, ATTN_D>;
using _gl_KV = gl<bf16, ATTN_B, ATTN_N, ATTN_H_KV, ATTN_D>;
using _gl_dQ = gl<bf16, ATTN_B, ATTN_H, ATTN_N, ATTN_D>;
using _gl_dKV = gl<bf16, ATTN_B, ATTN_N, ATTN_H_KV, ATTN_D>;
using _gl_dKV = gl<bf16, ATTN_B * GROUP_SIZE, ATTN_N, ATTN_H_KV, ATTN_D>;
using _gl_Lvec = gl<float, ATTN_B, ATTN_H, 1, ATTN_N>;
template<int D> struct attn_bwd_combined_globals {
@@ -47,7 +47,7 @@ template<int D> struct attn_bwd_combined_globals {
_gl_dQ dQg;
_gl_dKV dKg, dVg;
_gl_Lvec L_vec, delta_vec;
dim3 grid() { return dim3(ATTN_H_KV, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
dim3 grid() { return dim3(ATTN_H, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
dim3 block() { return dim3(NUM_THREADS); }
size_t dynamic_shared_memory() { return MAX_SHARED_MEMORY; }
};
@@ -55,10 +55,12 @@ template<int D> struct attn_bwd_combined_globals {
template<int D> __launch_bounds__(NUM_THREADS, 1)
__global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr, bf16 *dO_ptr, bf16 *Q_ptr, bf16 *K_ptr, bf16 *V_ptr, float *L_vec_ptr, float *delta_vec_ptr) {
const int kv_head_idx = blockIdx.x; // This is the KV head index
const int q_head_idx_fixed = blockIdx.x; // This is the query head index [0, ATTN_H)
const int kv_head_idx = q_head_idx_fixed / GROUP_SIZE;
const int q_head_in_group = q_head_idx_fixed % GROUP_SIZE;
const int seq_idx = blockIdx.y;
const int batch_idx = blockIdx.z;
const int first_q_head = kv_head_idx * GROUP_SIZE;
const int first_q_head = q_head_idx_fixed;
const int warpid = kittens::warpid();
const int j = seq_idx * NUM_WARPS + warpid;
@@ -70,7 +72,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
// first Q step that can overlap this K_span:
const int first_step = max(0, k_start_min / STEP_QO);
const int num_steps_per_head = total_steps_per_head - first_step;
const int num_steps = num_steps_per_head * GROUP_SIZE;
const int num_steps = num_steps_per_head;
const int k_pos = j * WARP_SIZE_KV;
constexpr float L_SCALE_FACTOR = 1.44269504089f;
@@ -3355,14 +3357,14 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
}
}
store<1>(g.dVg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
store<1>(g.dVg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});
__builtin_amdgcn_s_waitcnt(0);
__builtin_amdgcn_s_barrier();
// We first copy dK_j_T from accumulator GPRs into the vector GPRs of dV_j_T (reused as staging) and then perform the dK store
accvgpr_read(dV_j_T, dK_j_T);
mul(dV_j_T, dV_j_T, dP_SCALE_FACTOR);
store<1>(g.dKg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
store<1>(g.dKg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});
// Write out final dQ_i slice
mul(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR);