fa: change bwd grid dim + unshuffle using mops (#15068)
@@ -47,19 +47,24 @@ def flash_attention(xq, xk, xv, attn_mask:Tensor|None=None, is_causal:bool=False
 def grad(dou:UOp, _) -> tuple[None, None, UOp, UOp, UOp]:
 do = Tensor(dou, device=dou.device)
-dq_in = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
-dq = _sharded_empty_like(xq, axis=shard_axis)
-dk = _sharded_empty_like(xk, axis=shard_axis)
-dv = _sharded_empty_like(xv, axis=shard_axis)
+dq = _sharded_empty((B, H, N, D), xq, axis=shard_axis_t)
+GROUP_SIZE = H_local // H_KV_local
+dk_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xk, axis=shard_axis)
+dv_partial = _sharded_empty((B * GROUP_SIZE, N, H_KV, D), xv, axis=shard_axis)

 # delta_vec = (do * attn).sum(-1, dtype=dtypes.float32).transpose(1, 2).unsqueeze(-2).detach()
 delta_vec = _sharded_empty((B, H, 1, N), xq, dtype=dtypes.float32, axis=shard_axis_t)
-delta_vec, dq_in = Tensor.custom_kernel(delta_vec, dq_in, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]
+delta_vec, dq = Tensor.custom_kernel(delta_vec, dq, attn, do, fxn=functools.partial(custom_fa_backward_pre, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:2]

-dq_in, dk, dv = Tensor.custom_kernel(dq_in, dk, dv, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]
+dq, dk_partial, dv_partial = Tensor.custom_kernel(dq, dk_partial, dv_partial, do, xq, xk, xv, l_vec, delta_vec, fxn=functools.partial(custom_fa_backward, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[:3]

-# unshuffle dq
-dq = Tensor.custom_kernel(dq, dq_in, fxn=functools.partial(custom_fa_backward_post, device=single_device, arch=arch, B=B_local, N=N, H=H_local, H_KV=H_KV_local, D=D))[0]
+# unshuffle dq: atomic_pk_add_bf16_with_warpid creates a shuffled layout within each 16x128 tile
+# decompose each tile into (j=4, a=2, b=2, d=4, e=4, k=4, c=2) and permute to (e, k, j, a, d, b, c) = standard row-major
+dq = dq.reshape(B, H, N//16, 4, 2, 2, 4, 4, 4, 2).permute(0, 1, 2, 7, 8, 3, 4, 6, 5, 9).reshape(B, H, N, D).transpose(1, 2)
+
+# reduce partial dK/dV across GROUP_SIZE query heads
+dk = dk_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)
+dv = dv_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)

 return None, None, dq.uop, dk.uop, dv.uop
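As a quick illustration of the new mops-based unshuffle (which replaces the old custom_fa_backward_post kernel), here is a hedged NumPy sketch with toy shapes. It only demonstrates that the reshape/permute is a pure rearrangement within each 16x128 tile; it does not reproduce the exact layout atomic_pk_add_bf16_with_warpid writes.

import numpy as np

# Toy shapes for illustration only; the real B, H, N, D come from the attention tensors.
B, H, N, D = 1, 1, 32, 128

# Stand-in for the shuffled dq buffer in (B, H, N, D) layout.
dq = np.arange(B * H * N * D).reshape(B, H, N, D)

# Same movement ops as the grad path above (NumPy's transpose plays the role of
# tinygrad's permute): split every 16x128 tile into (j=4, a=2, b=2, d=4, e=4, k=4, c=2)
# and reorder it to (e, k, j, a, d, b, c).
dq_unshuffled = (dq.reshape(B, H, N // 16, 4, 2, 2, 4, 4, 4, 2)
                   .transpose(0, 1, 2, 7, 8, 3, 4, 6, 5, 9)
                   .reshape(B, H, N, D))

# The permute only moves values around inside their own 16-row tile,
# so every tile still holds exactly the same set of elements.
for t in range(N // 16):
    before = dq[0, 0, t*16:(t+1)*16].ravel()
    after = dq_unshuffled[0, 0, t*16:(t+1)*16].ravel()
    assert np.array_equal(np.sort(before), np.sort(after))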
@@ -89,7 +94,6 @@ def custom_fa_forward(o:UOp, l_vec:UOp, q:UOp, k:UOp, v:UOp, device:str, arch:st
 arg=KernelInfo(name="custom_fa_forward", estimates=estimates))

 lib = HIPCCCompiler(arch, compile_args).compile_cached(code)

 lib = bytearray(lib)
 rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
 struct.pack_into('<I', lib, rodata_off, 160000)
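The rodata patch that closes each of these builder hunks overwrites four bytes of the compiled kernel, at the offset of its .rodata section, with the little-endian uint32 160000. A minimal standalone sketch of that byte-patching pattern (toy buffer and a hypothetical offset, not the real ELF):

import struct

# Toy stand-in for the compiled kernel bytes; in the real code the offset comes from
# the .rodata section header returned by elf_loader.
lib = bytearray(16)
rodata_off = 8  # hypothetical offset for illustration
struct.pack_into('<I', lib, rodata_off, 160000)
assert struct.unpack_from('<I', lib, rodata_off)[0] == 160000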
@@ -120,7 +124,6 @@ def custom_fa_backward_pre(delta_vec:UOp, dq:UOp, o:UOp, do:UOp, device:str, arc
 arg=KernelInfo(name="custom_fa_backward_pre", estimates=estimates))

 lib = HIPCCCompiler(arch, compile_args).compile_cached(code)

 lib = bytearray(lib)
 rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
 struct.pack_into('<I', lib, rodata_off, 160000)
@@ -138,7 +141,7 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
 BLOCK_SIZE_KV = 256
 NUM_WARPS = 4
 NUM_THREADS = 64 * NUM_WARPS
-gsz = (H_KV, N // BLOCK_SIZE_KV, B)
+gsz = (H, N // BLOCK_SIZE_KV, B)
 lsz = (NUM_THREADS, 1, 1)
 threadIdx_x = UOp.special(lsz[0], "lidx0")
 blockIdx_x, blockIdx_y, blockIdx_z = UOp.special(gsz[0], "gidx0"), UOp.special(gsz[1], "gidx1"), UOp.special(gsz[2], "gidx2")
@@ -151,7 +154,6 @@ def custom_fa_backward(dq:UOp, dk:UOp, dv:UOp, do:UOp, q:UOp, k:UOp, v:UOp, l_ve
 arg=KernelInfo(name="custom_fa_backward", estimates=estimates))

 lib = HIPCCCompiler(arch, compile_args).compile_cached(code)

 lib = bytearray(lib)
 rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
 struct.pack_into('<I', lib, rodata_off, 160000)
@@ -182,7 +184,6 @@ def custom_fa_backward_post(dq_out:UOp, dq_in:UOp, device:str, arch:str, B:int,
 arg=KernelInfo(name="custom_fa_backward_post", estimates=estimates))

 lib = HIPCCCompiler(arch, compile_args).compile_cached(code)

 lib = bytearray(lib)
 rodata_off = next(sh.header.sh_offset for sh in elf_loader(bytes(lib))[1] if sh.name == ".rodata")
 struct.pack_into('<I', lib, rodata_off, 160000)
@@ -37,7 +37,7 @@ using namespace kittens;
 using _gl_QdO = gl<bf16, ATTN_B, ATTN_N, ATTN_H, ATTN_D>;
 using _gl_KV = gl<bf16, ATTN_B, ATTN_N, ATTN_H_KV, ATTN_D>;
 using _gl_dQ = gl<bf16, ATTN_B, ATTN_H, ATTN_N, ATTN_D>;
-using _gl_dKV = gl<bf16, ATTN_B, ATTN_N, ATTN_H_KV, ATTN_D>;
+using _gl_dKV = gl<bf16, ATTN_B * GROUP_SIZE, ATTN_N, ATTN_H_KV, ATTN_D>;
 using _gl_Lvec = gl<float, ATTN_B, ATTN_H, 1, ATTN_N>;

 template<int D> struct attn_bwd_combined_globals {
@@ -47,7 +47,7 @@ template<int D> struct attn_bwd_combined_globals {
 _gl_dQ dQg;
 _gl_dKV dKg, dVg;
 _gl_Lvec L_vec, delta_vec;
-dim3 grid() { return dim3(ATTN_H_KV, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
+dim3 grid() { return dim3(ATTN_H, (ATTN_N / BLOCK_SIZE_KV), ATTN_B); }
 dim3 block() { return dim3(NUM_THREADS); }
 size_t dynamic_shared_memory() { return MAX_SHARED_MEMORY; }
 };
@@ -55,10 +55,12 @@ template<int D> struct attn_bwd_combined_globals {
 template<int D> __launch_bounds__(NUM_THREADS, 1)
 __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr, bf16 *dO_ptr, bf16 *Q_ptr, bf16 *K_ptr, bf16 *V_ptr, float *L_vec_ptr, float *delta_vec_ptr) {

-const int kv_head_idx = blockIdx.x; // This is the KV head index
+const int q_head_idx_fixed = blockIdx.x; // This is the query head index [0, ATTN_H)
+const int kv_head_idx = q_head_idx_fixed / GROUP_SIZE;
+const int q_head_in_group = q_head_idx_fixed % GROUP_SIZE;
 const int seq_idx = blockIdx.y;
 const int batch_idx = blockIdx.z;
-const int first_q_head = kv_head_idx * GROUP_SIZE;
+const int first_q_head = q_head_idx_fixed;

 const int warpid = kittens::warpid();
 const int j = seq_idx * NUM_WARPS + warpid;
@@ -70,7 +72,7 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
 // first Q step that can overlap this K_span:
 const int first_step = max(0, k_start_min / STEP_QO);
 const int num_steps_per_head = total_steps_per_head - first_step;
-const int num_steps = num_steps_per_head * GROUP_SIZE;
+const int num_steps = num_steps_per_head;
 const int k_pos = j * WARP_SIZE_KV;

 constexpr float L_SCALE_FACTOR = 1.44269504089f;
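Taken together, the two kernel hunks above mean each block now owns exactly one query head: blockIdx.x indexes the query head, the KV head it belongs to is derived by integer division, and the step loop runs num_steps_per_head once instead of GROUP_SIZE times. A small hedged sketch of that index arithmetic, using toy head counts:

# Toy head counts for illustration; in the kernel GROUP_SIZE = ATTN_H / ATTN_H_KV.
ATTN_H, ATTN_H_KV = 8, 2
GROUP_SIZE = ATTN_H // ATTN_H_KV

for block_x in range(ATTN_H):  # new grid: one block per query head
    q_head_idx_fixed = block_x
    kv_head_idx = q_head_idx_fixed // GROUP_SIZE     # KV head this query head attends to
    q_head_in_group = q_head_idx_fixed % GROUP_SIZE  # slot inside that KV head's group
    assert kv_head_idx * GROUP_SIZE + q_head_in_group == q_head_idx_fixed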
@@ -3355,14 +3357,14 @@ __global__ void attend_bwd_combined_ker(bf16 *dQ_ptr, bf16 *dK_ptr, bf16 *dV_ptr
 }
 }

-store<1>(g.dVg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
+store<1>(g.dVg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});
 __builtin_amdgcn_s_waitcnt(0);
 __builtin_amdgcn_s_barrier();

 // We first copy dV_j_T from accumulator GPRs to vector GPRs and then perform the store
 accvgpr_read(dV_j_T, dK_j_T);
 mul(dV_j_T, dV_j_T, dP_SCALE_FACTOR);
-store<1>(g.dKg, dV_j, {batch_idx, 0, kv_head_idx, 0}, {0, j, 0, 0});
+store<1>(g.dKg, dV_j, {batch_idx * GROUP_SIZE + q_head_in_group, 0, kv_head_idx, 0}, {0, j, 0, 0});

 // Write out final dQ_i slice
 mul(dQ_i_T, dQ_i_T, dP_SCALE_FACTOR);
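These store-index changes are what populate the (ATTN_B * GROUP_SIZE, ATTN_N, ATTN_H_KV, ATTN_D) partial dK/dV buffers declared earlier: every query head writes its contribution at batch index batch_idx * GROUP_SIZE + q_head_in_group, and the grad function then sums over the group axis. A hedged NumPy sketch of that layout and reduction, with toy sizes and random stand-in gradients:

import numpy as np

# Toy sizes for illustration; the real shapes come from the attention tensors.
B, H, H_KV, N, D = 2, 8, 2, 16, 4
GROUP_SIZE = H // H_KV  # query heads per KV head

# Each query head writes its dK contribution at batch index
# b * GROUP_SIZE + q_head_in_group, mirroring the kernel's store call above.
per_head = np.random.randn(B, H, N, D)
dk_partial = np.zeros((B * GROUP_SIZE, N, H_KV, D))
for b in range(B):
    for q_head in range(H):
        kv_head, q_in_group = q_head // GROUP_SIZE, q_head % GROUP_SIZE
        dk_partial[b * GROUP_SIZE + q_in_group, :, kv_head] = per_head[b, q_head]

# The grad path folds the group axis back in and sums over it, so every KV head
# accumulates the gradients of its GROUP_SIZE query heads.
dk = dk_partial.reshape(B, GROUP_SIZE, N, H_KV, D).sum(1)

# Sanity check against a direct per-KV-head sum.
expected = per_head.reshape(B, H_KV, GROUP_SIZE, N, D).sum(2).transpose(0, 2, 1, 3)
assert np.allclose(dk, expected)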