diff --git a/extra/thunder/amd/fa.py b/extra/thunder/amd/fa.py
index e9501d9a82..2a919cc507 100644
--- a/extra/thunder/amd/fa.py
+++ b/extra/thunder/amd/fa.py
@@ -47,19 +47,24 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
   def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
     do = Tensor(dou, device=dou.device)
-    dq_in = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
-    dq = _sharded_empty_like(xq, axis=shard_axis)
-    dk = _sharded_empty_like(xk, axis=shard_axis)
-    dv = _sharded_empty_like(xv, axis=shard_axis)
+    dq = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
+    GROUP_SIZE = H_local // H_KV_local
+    dk_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xk, axis=shard_axis)
+    dv_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xv, axis=shard_axis)
     # delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
     delta_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
-    delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
+    delta_vec, dq = Tensor.custom_kernel(delta_vec, dq, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
-    dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
+    dq, dk_partial, dv_partial = Tensor.custom_kernel(dq, dk_partial, dv_partial, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
-    # unshuffle dq
-    dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[0]
+    # unshuffle dq: atomic_pk_add_bf16_with_warpid creates a shuffled layout within each 16x128 tile
+    # decompose each tile into (j=4, a=2, b=2, d=4, e=4, k=4, c=2) and permute to (e, k, j, a, d, b, c) = standard row-major
+    dq = dq.reshape(B, H, N//16, 4, 2, 2, 4, 4, 4, 2).permute(0, 1, 2, 7, 8, 3, 4, 6, 5, 9).reshape(B, H, N, D).transpose(1, 2)
+
+    # reduce partial dK/dV across GROUP_SIZE query heads
+    dk = dk_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
+    dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
     return None, None, dq.uop, dk.uop, dv.uop
@@ -89,7 +94,6 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
     arg=KernelInfo(name="custom_fa_forward", estimates=estimates))
   lib = HIPCCCompiler(arch, compile_args).compile_cached(code)
-  lib = bytearray(lib)
   rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
   struct.pack_into('
 using _gl_KV = gl;
 using _gl_dQ = gl;
-using _gl_dKV = gl;
+using _gl_dKV = gl;
 using _gl_Lvec = gl;
 template struct attn_bwd_combined_globals {
@@ -47,7 +47,7 @@ template struct attn_bwd_combined_globals {
   _gl_dQ dQg;
   _gl_dKV dKg, dVg;
   _gl_Lvec L_vec, delta_vec;
-  dim3 grid() { return dim3(ATTN_H_KV, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
+  dim3 grid() { return dim3(ATTN_H, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
   dim3 block() { return dim3(NUM_THREADS); }
   size_t dynamic_shared_memory() { return MAX_SHARED_MEMORY; }
 };
@@ -55,10 +55,12 @@ template struct attn_bwd_combined_globals {
 template
 __launch_bounds__(NUM_THREADS, 1)
 __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr, bf16 *dO_ptr, bf16 *Q_ptr, bf16 *K_ptr, bf16 *V_ptr, float *L_vec_ptr, float *delta_vec_ptr) {
-  const int kv_head_idx = blockIdx.x; // This is the KV head index
+  const int q_head_idx_fixed = blockIdx.x; // This is the query head index [0, ATTN_H)
+  const int kv_head_idx = q_head_idx_fixed / GROUP_SIZE;
+  const int q_head_in_group = q_head_idx_fixed % GROUP_SIZE;
   const int seq_idx = blockIdx.y;
   const int batch_idx = blockIdx.z;
-  const int first_q_head = kv_head_idx * GROUP_SIZE;
+  const int first_q_head = q_head_idx_fixed;
   const int warpid = kittens::warpid();
   const int j = seq_idx * NUM_WARPS + warpid;
@@ -70,7 +72,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
   // first Q step that can overlap this K_span:
   const int first_step = max(0, k_start_min / STEP_QO);
   const int num_steps_per_head = total_steps_per_head - first_step;
-  const int num_steps = num_steps_per_head * GROUP_SIZE;
+  const int num_steps = num_steps_per_head;
   const int k_pos = j * WARP_SIZE_KV;
   constexpr float L_SCALE_FACTOR = 1.44269504089f;
@@ -3355,14 +3357,14 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
     }
   }
-  store<1>(g.dVg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
+  store<1>(g.dVg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});
   __builtin_amdgcn_s_waitcnt(0);
   __builtin_amdgcn_s_barrier();
   // We first copy dV_j_T from accumulator GPRs to vector GPRs and then perform the store
   accvgpr_read(dV_j_T, dK_j_T);
   mul(dV_j_T, dV_j_T, dP_SCALE_FACTOR);
-  store<1>(g.dKg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
+  store<1>(g.dKg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});
   // Write out final dQ_i slice
   mul(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR);
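
A quick sanity check on the dq unshuffle in grad(): the reshape/permute is the host-side inverse of the shuffled per-tile layout that atomic_pk_add_bf16_with_warpid writes, so applied to an index ramp it must touch every element of each 16x128 tile exactly once. A minimal NumPy sketch (toy B/H/N sizes, illustrative only; only N % 16 == 0 and D == 128 matter here):

import numpy as np

# Toy sizes, illustrative only; the kernel fixes D=128 and tiles N by 16.
B, H, N, D = 1, 2, 64, 128

# Index ramp standing in for the shuffled dq buffer the backward kernel writes.
x = np.arange(B * H * N * D).reshape(B, H, N, D)

# Same decomposition as the diff: split each 16x128 tile into
# (j=4, a=2, b=2, d=4, e=4, k=4, c=2) and permute to (e, k, j, a, d, b, c).
y = (x.reshape(B, H, N // 16, 4, 2, 2, 4, 4, 4, 2)
      .transpose(0, 1, 2, 7, 8, 3, 4, 6, 5, 9)  # NumPy transpose == tinygrad permute
      .reshape(B, H, N, D)
      .swapaxes(1, 2))                          # matches the trailing .transpose(1, 2)

# A valid unshuffle must be a bijection: every source element lands exactly once.
assert np.array_equal(np.sort(y.ravel()), np.arange(y.size))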
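Likewise for the new dK/dV path: instead of one block per KV head looping over all GROUP_SIZE query heads (the old num_steps = num_steps_per_head * GROUP_SIZE), each query head's block now stores its contribution into a private slice of the (B * GROUP_SIZE, N, H_KV, D) buffer, indexed by batch_idx * GROUP_SIZE + q_head_in_group, and grad() reduces over the group on the host. A minimal NumPy sketch of that reduction (toy GQA shapes, illustrative only):

import numpy as np

# Toy GQA shapes, illustrative only.
B, H, H_KV, N, D = 2, 8, 2, 16, 4
GROUP_SIZE = H // H_KV  # query heads that share one KV head

# Stand-in for dk_partial: each block writes its dK contribution into the
# slice batch_idx * GROUP_SIZE + q_head_in_group, so no block overwrites another.
dk_partial = np.random.randn(B * GROUP_SIZE, N, H_KV, D).astype(np.float32)

# Host-side reduction from grad(): sum the per-query-head partials into dK.
dk = dk_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(axis=1)
assert dk.shape == (B, N, H_KV, D)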