mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411)
Support having a chain of mma operations with mixed sizes. Serialize the different block calculations in backward attention to work around a problem with ptxas and wgmma.
This commit is contained in:
@@ -323,106 +323,102 @@ def _attn_bwd(
|
||||
# load scales
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
if (tl.program_id(1) == 0):
|
||||
# THIS BLOCK DOES DK/DV/DR:
|
||||
|
||||
# THIS BLOCK DOES DK/DV/DR:
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
start_m = start_n
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
|
||||
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dv_ptrs, dv)
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
tl.store(dk_ptrs, dk)
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
else:
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
||||
MASK=True,
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(
|
||||
dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
||||
MASK=False,
|
||||
)
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
|
||||
dq *= LN2
|
||||
tl.store(dq_ptrs, dq)
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
@@ -491,7 +487,7 @@ class _attention(torch.autograd.Function):
|
||||
BATCH, N_HEAD, N_CTX,
|
||||
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD)
|
||||
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
|
||||
_attn_bwd[grid](
|
||||
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
|
||||
M, delta,
|
||||
|
||||
Reference in New Issue
Block a user