[OPTIMIZATION] Fix performance for attention backward path with mma v3 (#2411)

Support having chain of mma with mixed size.
Serialize the different block calculation in backward attention to
workaround problem with ptxas and wgmma.
This commit is contained in:
Thomas Raoux
2023-09-28 10:29:08 -07:00
committed by GitHub
parent 1e093fbfff
commit 721bdebee1
11 changed files with 164 additions and 195 deletions

View File

@@ -323,106 +323,102 @@ def _attn_bwd(
# load scales
offs_k = tl.arange(0, BLOCK_DMODEL)
if (tl.program_id(1) == 0):
# THIS BLOCK DOES DK/DV/DR:
# THIS BLOCK DOES DK/DV/DR:
start_n = pid * BLOCK_N1
start_m = start_n
start_n = pid * BLOCK_N1
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
# load K and V: they stay in SRAM throughout the inner loop.
k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True,
)
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True,
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False,
)
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False,
)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dv_ptrs, dv)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# Write back dK.
dk *= sm_scale
dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
tl.store(dk_ptrs, dk)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
else:
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
m = tl.load(M + offs_m)
m = m[:, None]
q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True,
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False,
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True,
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(
dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False,
)
# Write back dQ.
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
dq *= LN2
tl.store(dq_ptrs, dq)
empty = torch.empty(128, device="cuda")
@@ -491,7 +487,7 @@ class _attention(torch.autograd.Function):
BATCH, N_HEAD, N_CTX,
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
)
grid = (N_CTX // BLOCK_N1, 2, BATCH * N_HEAD)
grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
M, delta,