mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge changes from upstream FA bwd kernel (#444)
* Add optimized FA bwd from upstream
* Add autotuning
* Change loads and stores to use block ptrs
* Cleanup
This commit is contained in:
@@ -186,230 +186,361 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,
|
||||
tl.store(m_ptrs, m_i + tl.math.log2(l_i))
|
||||
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _attn_bwd_preprocess(O, DO, #
|
||||
NewDO, Delta, #
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
|
||||
def _attn_bwd_preprocess(O, DO,
|
||||
Delta,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
|
||||
):
|
||||
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
off_hz = tl.program_id(1)
|
||||
off_n = tl.arange(0, D_HEAD)
|
||||
# load
|
||||
o = tl.load(O + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
|
||||
do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
tl.store(Delta + off_m, delta)
|
||||
tl.store(Delta + off_hz * N_CTX + off_m, delta)
|
||||
|
||||
|
||||
# The main inner-loop logic for computing dK and dV.
|
||||
@triton.jit
|
||||
def _bwd_kernel_dk_dv(Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DK, DV,
|
||||
L, D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
# Q is consumed depending on block ID. Every block uses
|
||||
# previous block offset by BLOCK_M x D_HEAD.
|
||||
qvk_offset = off_hz * stride_qh
|
||||
qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
# Initialize pointers to Q, K, V
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qdo_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + qvk_offset,
|
||||
def _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX, BLOCK_M1: tl.constexpr,
|
||||
BLOCK_N1: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
# Filled in by the wrapper.
|
||||
start_n, start_m, num_steps,
|
||||
MASK: tl.constexpr):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M1)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
QT_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, start_m * BLOCK_M),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, start_m * BLOCK_M),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
strides=(stride_d, stride_tok),
|
||||
offsets=(0, start_m),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_M1),
|
||||
order=(0,1)
|
||||
)
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO + qdo_offset,
|
||||
base=DO,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_m, 0),
|
||||
block_shape=(BLOCK_M1, BLOCK_DMODEL),
|
||||
order=(1,0)
|
||||
)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load k and v: they will stay in SRAM throughout
|
||||
k = tl.load(K_block_ptr)
|
||||
k = (k * qk_scale).to(k.dtype)
|
||||
v = tl.load(V_block_ptr)
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# This lower loop bound is because of the causal mask. We create a lower triangular
|
||||
# result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
|
||||
# be ignored in the GEMM.
|
||||
lo = start_m * BLOCK_M
|
||||
hi = N_CTX
|
||||
# loop over q, do
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
offs_m_curr = offs_n[:, None] + start_n
|
||||
# -- load q, do --
|
||||
q = tl.load(Q_block_ptr)
|
||||
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
||||
curr_m = start_m
|
||||
step_m = BLOCK_M1
|
||||
for blk_idx in range(num_steps):
|
||||
qT = tl.load(QT_block_ptr)
|
||||
# Load m before computing qk to reduce pipeline stall.
|
||||
offs_m = curr_m + tl.arange(0, BLOCK_M1)
|
||||
m = tl.load(M + offs_m)
|
||||
qkT = tl.dot(k, qT)
|
||||
pT = tl.math.exp2(qkT - m[None, :])
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
mask = (offs_m[None, :] >= offs_n[:, None])
|
||||
pT = tl.where(mask, pT, 0.0)
|
||||
do = tl.load(DO_block_ptr)
|
||||
# -- compute qk ----
|
||||
qk = tl.dot(q, k)
|
||||
qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i)
|
||||
# -- compute dv ----
|
||||
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
|
||||
# compute dp = dot(v, do)
|
||||
Di = tl.load(D_ptrs + offs_m_curr)
|
||||
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
|
||||
dp += tl.dot(do, v)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp
|
||||
# compute dk
|
||||
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
|
||||
# update pointers
|
||||
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
|
||||
# initialize pointers to output
|
||||
DK_block_ptr = tl.make_block_ptr(
|
||||
base=DK + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_kn, stride_kk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
# Compute dV.
|
||||
ppT = pT
|
||||
ppT = ppT.to(tl.float16)
|
||||
dv += tl.dot(ppT, do)
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do))
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
dsT = dsT.to(tl.float16)
|
||||
dk += tl.dot(dsT, tl.trans(qT))
|
||||
# Increment pointers.
|
||||
curr_m += step_m
|
||||
QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
|
||||
DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
|
||||
return dk, dv
|
||||
|
||||
|
||||
# the main inner-loop logic for computing dQ
|
||||
@triton.jit
|
||||
def _attn_bwd_dq(dq, q, K, V,
|
||||
do, m, D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2: tl.constexpr,
|
||||
BLOCK_N2: tl.constexpr,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
# Filled in by the wrapper.
|
||||
start_m, start_n, num_steps,
|
||||
MASK: tl.constexpr):
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N2)
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
KT_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_d, stride_tok),
|
||||
offsets=(0, start_n),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N2),
|
||||
order=(0, 1)
|
||||
)
|
||||
DV_block_ptr = tl.make_block_ptr(
|
||||
base=DV + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
VT_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_d, stride_tok),
|
||||
offsets=(0, start_n),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N2),
|
||||
order=(0, 1)
|
||||
)
|
||||
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
|
||||
tl.store(DV_block_ptr, dv.to(tl.float16))
|
||||
# D (= delta) is pre-divided by ds_scale.
|
||||
Di = tl.load(D + offs_m)
|
||||
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
||||
curr_n = start_n
|
||||
step_n = BLOCK_N2
|
||||
for blk_idx in range(num_steps):
|
||||
kT = tl.load(KT_block_ptr)
|
||||
qk = tl.dot(q, kT)
|
||||
p = tl.math.exp2(qk - m)
|
||||
# Autoregressive masking.
|
||||
if MASK:
|
||||
offs_n = curr_n + tl.arange(0, BLOCK_N2)
|
||||
mask = (offs_m[:, None] >= offs_n[None, :])
|
||||
p = tl.where(mask, p, 0.0)
|
||||
# Compute dP and dS.
|
||||
vT = tl.load(VT_block_ptr)
|
||||
dp = tl.dot(do, vT).to(tl.float32)
|
||||
ds = p * (dp - Di[:, None])
|
||||
ds = ds.to(tl.float16)
|
||||
# Compute dQ.
|
||||
# NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
|
||||
dq += tl.dot(ds, tl.trans(kT))
|
||||
# Increment pointers.
|
||||
curr_n += step_n
|
||||
KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
|
||||
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
|
||||
return dq
|
||||
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
||||
num_stages=1, num_warps=4),
|
||||
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
|
||||
num_stages=1, num_warps=8),
|
||||
],
|
||||
key=['H', 'N_CTX', 'BLOCK_DMODEL'],
|
||||
)
|
||||
|
||||
@triton.jit
|
||||
def _bwd_kernel_dq(Q, K, V, sm_scale,
|
||||
Out, DO,
|
||||
DQ,
|
||||
L,D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
stride_vz, stride_vh, stride_vk, stride_vn,
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
# Initialize pointers to Q, K, V
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
def _attn_bwd(Q, K, V, sm_scale,
|
||||
DO,
|
||||
DQ, DK, DV,
|
||||
M, D,
|
||||
# shared by Q/K/V/DO.
|
||||
stride_z, stride_h, stride_tok, stride_d,
|
||||
# H = 16, N_CTX = 1024
|
||||
H, N_CTX,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_M1: tl.constexpr,
|
||||
BLOCK_N1: tl.constexpr,
|
||||
BLOCK_M2: tl.constexpr,
|
||||
BLOCK_N2: tl.constexpr,
|
||||
BLK_SLICE_FACTOR: tl.constexpr):
|
||||
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
|
||||
|
||||
bhid = tl.program_id(2)
|
||||
off_chz = (bhid * N_CTX).to(tl.int64)
|
||||
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
|
||||
pid = tl.program_id(0)
|
||||
|
||||
# offset pointers for batch/head
|
||||
Q += adj
|
||||
K += adj
|
||||
V += adj
|
||||
DO += adj
|
||||
DQ += adj
|
||||
DK += adj
|
||||
DV += adj
|
||||
M += off_chz
|
||||
D += off_chz
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
|
||||
start_n = pid * BLOCK_N1
|
||||
# This assignment is important. It is what allows us to pick the diagonal
|
||||
# blocks. Later, when we want to do the lower triangular, we update start_m
|
||||
# after the first dkdv call.
|
||||
start_m = start_n
|
||||
|
||||
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
|
||||
offs_n = start_n + tl.arange(0, BLOCK_N1)
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
base=K,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_n, 0),
|
||||
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
base=V,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_n, 0),
|
||||
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
|
||||
num_steps = BLOCK_N1 // MASK_BLOCK_M1
|
||||
|
||||
dk, dv = _attn_bwd_dkdv(dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=True
|
||||
)
|
||||
|
||||
start_m += num_steps * MASK_BLOCK_M1
|
||||
num_steps = (N_CTX - start_m) // BLOCK_M1
|
||||
|
||||
# Compute dK and dV for non-masked blocks.
|
||||
dk, dv = _attn_bwd_dkdv(
|
||||
dk, dv,
|
||||
Q, k, v, sm_scale,
|
||||
DO,
|
||||
M, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
|
||||
start_n, start_m, num_steps,
|
||||
MASK=False
|
||||
)
|
||||
|
||||
DV_block_ptrs = tl.make_block_ptr(
|
||||
base=DV,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_n, 0),
|
||||
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
||||
order=(1,0)
|
||||
)
|
||||
tl.store(DV_block_ptrs, dv.to(tl.float16))
|
||||
|
||||
# Write back dK.
|
||||
dk *= sm_scale
|
||||
DK_block_ptrs = tl.make_block_ptr(
|
||||
base=DK,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_n, 0),
|
||||
block_shape=(BLOCK_N1, BLOCK_DMODEL),
|
||||
order=(1,0)
|
||||
)
|
||||
tl.store(DK_block_ptrs, dk.to(tl.float16))
|
||||
|
||||
# THIS BLOCK DOES DQ:
|
||||
start_m = pid * BLOCK_M2
|
||||
end_n = start_m + BLOCK_M2
|
||||
|
||||
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
|
||||
offs_m = start_m + tl.arange(0, BLOCK_M2)
|
||||
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_m, 0),
|
||||
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
|
||||
DO_block_ptr = tl.make_block_ptr(
|
||||
base=DO + qvk_offset,
|
||||
base=DO,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_m, 0),
|
||||
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q and do: they will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = (q * qk_scale).to(q.dtype)
|
||||
do = tl.load(DO_block_ptr)
|
||||
Di = tl.load(D_ptrs + offs_m)
|
||||
l_i = tl.load(l_ptrs + offs_m)
|
||||
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# loop over k, v
|
||||
lo = 0
|
||||
hi = (start_m + 1) * BLOCK_M
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# -- compute qk ----
|
||||
qk = tl.dot(q, k)
|
||||
qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
# compute dp = dot(v, do)
|
||||
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
|
||||
dp += tl.dot(do, v)
|
||||
# compute ds = p * (dp - delta[:, None])
|
||||
ds = p * dp
|
||||
# compute dq. Unfortunately we cannot avoid transpose here as this loop
|
||||
# uses k both normal and transpose.
|
||||
dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
|
||||
# initialize pointers to output
|
||||
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
|
||||
|
||||
m = tl.load(M + offs_m)
|
||||
m = m[:, None]
|
||||
|
||||
# Compute dQ for masked (diagonal) blocks.
|
||||
# NOTE: This code scans each row of QK^T backward (from right to left,
|
||||
# but inside each call to _attn_bwd_dq, from left to right), but that's
|
||||
# not due to anything important. I just wanted to reuse the loop
|
||||
# structure for dK & dV above as much as possible.
|
||||
num_steps = BLOCK_M2 // MASK_BLOCK_N2
|
||||
dq = _attn_bwd_dq(dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
|
||||
MASK=True
|
||||
)
|
||||
end_n -= num_steps * MASK_BLOCK_N2
|
||||
# stage 2
|
||||
num_steps = end_n // BLOCK_N2
|
||||
dq = _attn_bwd_dq(dq, q, K, V,
|
||||
do, m, D,
|
||||
stride_tok, stride_d,
|
||||
H, N_CTX,
|
||||
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
|
||||
start_m, end_n - num_steps * BLOCK_N2, num_steps,
|
||||
MASK=False
|
||||
)
|
||||
# Write back dQ.
|
||||
DQ_block_ptr = tl.make_block_ptr(
|
||||
base=DQ + qvk_offset,
|
||||
base=DQ,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
strides=(stride_tok, stride_d),
|
||||
offsets=(start_m, 0),
|
||||
block_shape=(BLOCK_M2, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
|
||||
dq *= LN2
|
||||
tl.store(DQ_block_ptr, dq.to(tl.float16))
|
||||
|
||||
|
||||
empty = torch.empty(128, device="cuda")
|
||||
|
||||
@@ -417,12 +548,12 @@ empty = torch.empty(128, device="cuda")
|
||||
class _attention(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
|
||||
def forward(ctx, q, k, v, causal, sm_scale):
|
||||
# shape constraints
|
||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||
assert Lq == Lk and Lk == Lv
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
o = torch.empty_like(q)
|
||||
if torch.version.hip is None:
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 64 if Lk <= 64 else 32
|
||||
@@ -432,7 +563,6 @@ class _attention(torch.autograd.Function):
|
||||
if torch.cuda.get_device_capability()[0] == 9:
|
||||
num_warps = 8
|
||||
num_stages = 7 if Lk >= 64 else 3
|
||||
|
||||
stage = 3 if causal else 1
|
||||
grid = lambda META: (
|
||||
triton.cdiv(q.shape[2], META['BLOCK_M']),
|
||||
@@ -440,7 +570,6 @@ class _attention(torch.autograd.Function):
|
||||
1
|
||||
)
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
_attn_fwd[grid](
|
||||
q, k, v, sm_scale, M, o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
@@ -463,85 +592,51 @@ class _attention(torch.autograd.Function):
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
ctx.causal = causal
|
||||
ctx.split_kernel = split_kernel
|
||||
return o
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
# configuration is not supported
|
||||
assert(not (ctx.split_kernel and not ctx.causal))
|
||||
if torch.version.hip is not None:
|
||||
BLOCK = 64
|
||||
else:
|
||||
BLOCK = 128
|
||||
q, k, v, o, L = ctx.saved_tensors
|
||||
q, k, v, o, M = ctx.saved_tensors
|
||||
assert do.is_contiguous()
|
||||
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q)
|
||||
dq = torch.empty_like(q)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
||||
delta = torch.empty_like(L)
|
||||
do_scaled = torch.empty_like(do)
|
||||
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
|
||||
# If the two are the same, we don't need this but the bwd pass block size
|
||||
# is smaller than the fwd so we need this scaling to ensure we loop over all
|
||||
# values and don't skip some blocks.
|
||||
# Alternatively we could compute a new grid but this keeps it consistent
|
||||
# with fwd and easier to reason about.
|
||||
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
|
||||
_attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, #
|
||||
do_scaled, delta, #
|
||||
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
|
||||
PRE_BLOCK = 128
|
||||
NUM_WARPS, NUM_STAGES = 4, 1
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
|
||||
BLK_SLICE_FACTOR = 2
|
||||
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
||||
arg_k = k
|
||||
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
|
||||
assert N_CTX % PRE_BLOCK == 0
|
||||
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
|
||||
delta = torch.empty_like(M)
|
||||
_attn_bwd_preprocess[pre_grid](
|
||||
o, do,
|
||||
delta,
|
||||
BATCH, N_HEAD, N_CTX,
|
||||
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
|
||||
)
|
||||
if not ctx.split_kernel:
|
||||
_bwd_kernel[(ctx.grid[1],)](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
block_scale * ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
else :
|
||||
dq = torch.zeros_like(q)
|
||||
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dk, dv,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
|
||||
num_stages=1,
|
||||
)
|
||||
_bwd_kernel_dq[ctx.grid](
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
|
||||
num_stages=1,
|
||||
)
|
||||
# print(h.asm["ttgir"])
|
||||
return dq, dk, dv, None, None, None
|
||||
grid = lambda META: (
|
||||
triton.cdiv(N_CTX, META['BLOCK_N1']),
|
||||
1,
|
||||
BATCH * N_HEAD
|
||||
)
|
||||
_attn_bwd[grid](
|
||||
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
|
||||
M, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
N_HEAD, N_CTX,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL
|
||||
)
|
||||
|
||||
return dq, dk, dv, None, None
|
||||
|
||||
attention = _attention.apply
|
||||
|
||||
@@ -594,7 +689,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
|
||||
|
||||
sm_scale = 0.5
|
||||
split_kernel = True
|
||||
dout = torch.randn_like(q)
|
||||
# reference implementation
|
||||
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
|
||||
@@ -608,7 +702,7 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
|
||||
ref_dk, k.grad = k.grad.clone(), None
|
||||
ref_dq, q.grad = q.grad.clone(), None
|
||||
# # triton implementation
|
||||
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
tri_out = attention(q, k, v, causal, sm_scale)
|
||||
tri_out.backward(dout)
|
||||
tri_dv, v.grad = v.grad.clone(), None
|
||||
tri_dk, k.grad = k.grad.clone(), None
|
||||
@@ -636,14 +730,12 @@ except BaseException:
|
||||
configs = []
|
||||
for mode in ['fwd', 'bwd']:
|
||||
for D_HEAD in [128, 64]:
|
||||
if mode == 'bwd' and D_HEAD == 128:
|
||||
continue
|
||||
for causal in [False, True]:
|
||||
if mode == 'bwd' and causal == False:
|
||||
continue
|
||||
configs.append(triton.testing.Benchmark(
|
||||
x_names=['BATCH', 'H','N_CTX'],
|
||||
x_vals=[(16, 16, 1024),
|
||||
x_names=['BATCH', 'H', 'N_CTX'],
|
||||
x_vals=[(4, 16, 1024),
|
||||
(8, 16, 2048),
|
||||
(4, 16, 4096),
|
||||
(2, 16, 8192),
|
||||
@@ -673,12 +765,10 @@ for mode in ['fwd', 'bwd']:
|
||||
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
|
||||
assert mode in ["fwd", "bwd"]
|
||||
warmup = 25
|
||||
rep = 100
|
||||
split_kernel = False
|
||||
rep = 10
|
||||
# Bwd pass only supports causal=True right now
|
||||
if mode == 'bwd':
|
||||
causal = True
|
||||
split_kernel = True
|
||||
if provider == "triton":
|
||||
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
|
||||
@@ -686,8 +776,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
|
||||
if mode == "fwd" and TORCH_HAS_FP8:
|
||||
q = q.to(torch_dtype)
|
||||
k = k.to(torch_dtype)
|
||||
sm_scale = 1.3
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
|
||||
sm_scale = D_HEAD ** -0.5
|
||||
fn = lambda: attention(q, k, v, causal, sm_scale)
|
||||
if mode == 'bwd':
|
||||
o = fn()
|
||||
do = torch.randn_like(o)
|
||||
@@ -709,6 +799,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
|
||||
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
|
||||
return total_flops / ms * 1e-9
|
||||
|
||||
|
||||
# only works on post-Ampere GPUs right now
|
||||
bench_flash_attention.run(save_path=".", print_data=True)
|
||||
|
||||
Reference in New Issue
Block a user