Merge changes from upstream FA bwd kernel (#444)

* Add optimized FA bwd from upstream

* Add autotuning

* Change loads and stores to use block ptrs

* Cleanup
This commit is contained in:
Vinayak Gokhale
2024-01-05 15:12:05 -06:00
committed by GitHub
parent bcea3051af
commit c2766bbd5f

View File

@@ -186,230 +186,361 @@ def _attn_fwd(Q, K, V, sm_scale, M, Out,
tl.store(m_ptrs, m_i + tl.math.log2(l_i))
tl.store(O_block_ptr, acc.to(Out.type.element_ty))
@triton.jit
def _attn_bwd_preprocess(O, DO, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
def _attn_bwd_preprocess(O, DO,
Delta,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_hz = tl.program_id(1)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(O + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
tl.store(Delta + off_hz * N_CTX + off_m, delta)
# The main inner-loop logic for computing dK and dV.
@triton.jit
def _bwd_kernel_dk_dv(Q, K, V, sm_scale,
Out, DO,
DK, DV,
L, D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# Q is consumed depending on the block ID: each block starts at the
# previous block's offset advanced by BLOCK_M x D_HEAD.
qvk_offset = off_hz * stride_qh
qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
def _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
# shared by Q/K/V/DO.
stride_tok, stride_d,
H, N_CTX, BLOCK_M1: tl.constexpr,
BLOCK_N1: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
# Filled in by the wrapper.
start_n, start_m, num_steps,
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M1)
offs_n = start_n + tl.arange(0, BLOCK_N1)
offs_k = tl.arange(0, BLOCK_DMODEL)
QT_block_ptr = tl.make_block_ptr(
base=Q,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
strides=(stride_d, stride_tok),
offsets=(0, start_m),
block_shape=(BLOCK_DMODEL, BLOCK_M1),
order=(0,1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qdo_offset,
base=DO,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M1, BLOCK_DMODEL),
order=(1,0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
k = (k * qk_scale).to(k.dtype)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# This lower loop bound is because of the causal mask. We create a lower triangular
# result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
# be ignored in the GEMM.
lo = start_m * BLOCK_M
hi = N_CTX
# loop over q, do
for start_n in range(lo, hi, BLOCK_N):
offs_m_curr = offs_n[:, None] + start_n
# -- load q, do --
q = tl.load(Q_block_ptr)
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
curr_m = start_m
step_m = BLOCK_M1
for blk_idx in range(num_steps):
qT = tl.load(QT_block_ptr)
# Load m before computing qk to reduce pipeline stall.
offs_m = curr_m + tl.arange(0, BLOCK_M1)
m = tl.load(M + offs_m)
qkT = tl.dot(k, qT)
pT = tl.math.exp2(qkT - m[None, :])
# Autoregressive masking.
if MASK:
mask = (offs_m[None, :] >= offs_n[:, None])
pT = tl.where(mask, pT, 0.0)
do = tl.load(DO_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(do, v) - Di (dp is initialized to -Di, then the dot product is accumulated)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
dp += tl.dot(do, v)
# compute ds = p * (dp - delta); delta (Di) was already subtracted when dp was initialized
ds = p * dp
# compute dk
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# update pointers
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
# initialize pointers to output
DK_block_ptr = tl.make_block_ptr(
base=DK + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
# Compute dV.
ppT = pT
ppT = ppT.to(tl.float16)
dv += tl.dot(ppT, do)
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(tl.float16)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
return dk, dv
# the main inner-loop logic for computing dQ
@triton.jit
def _attn_bwd_dq(dq, q, K, V,
do, m, D,
# shared by Q/K/V/DO.
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2: tl.constexpr,
BLOCK_N2: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
# Filled in by the wrapper.
start_m, start_n, num_steps,
MASK: tl.constexpr):
offs_m = start_m + tl.arange(0, BLOCK_M2)
offs_n = start_n + tl.arange(0, BLOCK_N2)
offs_k = tl.arange(0, BLOCK_DMODEL)
KT_block_ptr = tl.make_block_ptr(
base=K,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_d, stride_tok),
offsets=(0, start_n),
block_shape=(BLOCK_DMODEL, BLOCK_N2),
order=(0, 1)
)
DV_block_ptr = tl.make_block_ptr(
base=DV + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
VT_block_ptr = tl.make_block_ptr(
base=V,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_d, stride_tok),
offsets=(0, start_n),
block_shape=(BLOCK_DMODEL, BLOCK_N2),
order=(0, 1)
)
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
tl.store(DV_block_ptr, dv.to(tl.float16))
# D (= delta) is pre-divided by ds_scale.
Di = tl.load(D + offs_m)
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
curr_n = start_n
step_n = BLOCK_N2
for blk_idx in range(num_steps):
kT = tl.load(KT_block_ptr)
qk = tl.dot(q, kT)
p = tl.math.exp2(qk - m)
# Autoregressive masking.
if MASK:
offs_n = curr_n + tl.arange(0, BLOCK_N2)
mask = (offs_m[:, None] >= offs_n[None, :])
p = tl.where(mask, p, 0.0)
# Compute dP and dS.
vT = tl.load(VT_block_ptr)
dp = tl.dot(do, vT).to(tl.float32)
ds = p * (dp - Di[:, None])
ds = ds.to(tl.float16)
# Compute dQ.
# NOTE: We need to de-scale dq at the end, because kT was pre-scaled.
dq += tl.dot(ds, tl.trans(kT))
# Increment pointers.
curr_n += step_n
KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
return dq
@triton.autotune(
configs=[
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 1},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 64, 'BLOCK_N1': 64, 'BLOCK_M2': 64, 'BLOCK_N2': 64, 'BLK_SLICE_FACTOR': 2},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 1},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
num_stages=1, num_warps=4),
triton.Config({'BLOCK_M1': 32, 'BLOCK_N1': 128, 'BLOCK_M2': 128, 'BLOCK_N2': 32, 'BLK_SLICE_FACTOR': 2},
num_stages=1, num_warps=8),
],
key=['H', 'N_CTX', 'BLOCK_DMODEL'],
)
@triton.jit
def _bwd_kernel_dq(Q, K, V, sm_scale,
Out, DO,
DQ,
L,D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
def _attn_bwd(Q, K, V, sm_scale,
DO,
DQ, DK, DV,
M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d,
# H = 16, N_CTX = 1024
H, N_CTX,
BLOCK_DMODEL: tl.constexpr,
BLOCK_M1: tl.constexpr,
BLOCK_N1: tl.constexpr,
BLOCK_M2: tl.constexpr,
BLOCK_N2: tl.constexpr,
BLK_SLICE_FACTOR: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
bhid = tl.program_id(2)
off_chz = (bhid * N_CTX).to(tl.int64)
adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
pid = tl.program_id(0)
# offset pointers for batch/head
Q += adj
K += adj
V += adj
DO += adj
DQ += adj
DK += adj
DV += adj
M += off_chz
D += off_chz
offs_k = tl.arange(0, BLOCK_DMODEL)
start_n = pid * BLOCK_N1
# This assignment is important. It is what allows us to pick the diagonal
# blocks. Later, when we want to do the lower triangular, we update start_m
# after the first dkdv call.
start_m = start_n
MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
offs_n = start_n + tl.arange(0, BLOCK_N1)
dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
base=K,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
base=V,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1, 0),
)
# load K and V: they stay in SRAM throughout the inner loop for dkdv.
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
num_steps = BLOCK_N1 // MASK_BLOCK_M1
dk, dv = _attn_bwd_dkdv(dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=True
)
start_m += num_steps * MASK_BLOCK_M1
num_steps = (N_CTX - start_m) // BLOCK_M1
# Compute dK and dV for non-masked blocks.
dk, dv = _attn_bwd_dkdv(
dk, dv,
Q, k, v, sm_scale,
DO,
M, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M1, BLOCK_N1, BLOCK_DMODEL,
start_n, start_m, num_steps,
MASK=False
)
DV_block_ptrs = tl.make_block_ptr(
base=DV,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1,0)
)
tl.store(DV_block_ptrs, dv.to(tl.float16))
# Write back dK.
dk *= sm_scale
DK_block_ptrs = tl.make_block_ptr(
base=DK,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_n, 0),
block_shape=(BLOCK_N1, BLOCK_DMODEL),
order=(1,0)
)
tl.store(DK_block_ptrs, dk.to(tl.float16))
# THIS BLOCK DOES DQ:
start_m = pid * BLOCK_M2
end_n = start_m + BLOCK_M2
MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
offs_m = start_m + tl.arange(0, BLOCK_M2)
Q_block_ptr = tl.make_block_ptr(
base=Q,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qvk_offset,
base=DO,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(q.dtype)
do = tl.load(DO_block_ptr)
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# loop over k, v
lo = 0
hi = (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
p = tl.math.exp2(qk - l_i[:, None])
# compute dp = dot(do, v) - Di (dp is initialized to -Di, then the dot product is accumulated)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v)
# compute ds = p * (dp - delta); delta (Di) was already subtracted when dp was initialized
ds = p * dp
# compute dq. Unfortunately we cannot avoid transpose here as this loop
# uses k both normal and transpose.
dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
# initialize pointers to output
dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
m = tl.load(M + offs_m)
m = m[:, None]
# Compute dQ for masked (diagonal) blocks.
# NOTE: This code scans each row of QK^T backward (from right to left,
# but inside each call to _attn_bwd_dq, from left to right), but that's
# not due to anything important. I just wanted to reuse the loop
# structure for dK & dV above as much as possible.
num_steps = BLOCK_M2 // MASK_BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps,
MASK=True
)
end_n -= num_steps * MASK_BLOCK_N2
# stage 2
num_steps = end_n // BLOCK_N2
dq = _attn_bwd_dq(dq, q, K, V,
do, m, D,
stride_tok, stride_d,
H, N_CTX,
BLOCK_M2, BLOCK_N2, BLOCK_DMODEL,
start_m, end_n - num_steps * BLOCK_N2, num_steps,
MASK=False
)
# Write back dQ.
DQ_block_ptr = tl.make_block_ptr(
base=DQ + qvk_offset,
base=DQ,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
strides=(stride_tok, stride_d),
offsets=(start_m, 0),
block_shape=(BLOCK_M2, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
dq *= LN2
tl.store(DQ_block_ptr, dq.to(tl.float16))
empty = torch.empty(128, device="cuda")
@@ -417,12 +548,12 @@ empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False):
def forward(ctx, q, k, v, causal, sm_scale):
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q, dtype=v.dtype)
o = torch.empty_like(q)
if torch.version.hip is None:
BLOCK_M = 128
BLOCK_N = 64 if Lk <= 64 else 32
@@ -432,7 +563,6 @@ class _attention(torch.autograd.Function):
if torch.cuda.get_device_capability()[0] == 9:
num_warps = 8
num_stages = 7 if Lk >= 64 else 3
stage = 3 if causal else 1
grid = lambda META: (
triton.cdiv(q.shape[2], META['BLOCK_M']),
@@ -440,7 +570,6 @@ class _attention(torch.autograd.Function):
1
)
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
_attn_fwd[grid](
q, k, v, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
@@ -463,85 +592,51 @@ class _attention(torch.autograd.Function):
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.split_kernel = split_kernel
return o
@staticmethod
def backward(ctx, do):
# configuration is not supported
assert(not (ctx.split_kernel and not ctx.causal))
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
do = do.contiguous()
dq = torch.zeros_like(q)
dq = torch.empty_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
# Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
# If the two were the same, we would not need this. However, the bwd pass
# block size is smaller than the fwd one, so we need this scaling to ensure
# we loop over all values and don't skip any blocks.
# Alternatively, we could compute a new grid, but this keeps it consistent
# with fwd and makes it easier to reason about.
block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
_attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, #
do_scaled, delta, #
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 1
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
arg_k = k
arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
assert N_CTX % PRE_BLOCK == 0
pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
delta = torch.empty_like(M)
_attn_bwd_preprocess[pre_grid](
o, do,
delta,
BATCH, N_HEAD, N_CTX,
BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
CAUSAL=ctx.causal,
num_stages=1,
)
else :
dq = torch.zeros_like(q)
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
_bwd_kernel_dq[ctx.grid](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None, None, None
grid = lambda META: (
triton.cdiv(N_CTX, META['BLOCK_N1']),
1,
BATCH * N_HEAD
)
_attn_bwd[grid](
q, arg_k, v, ctx.sm_scale, do, dq, dk, dv,
M, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
N_HEAD, N_CTX,
BLOCK_DMODEL=ctx.BLOCK_DMODEL
)
return dq, dk, dv, None, None
attention = _attention.apply
@@ -594,7 +689,6 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
@@ -608,7 +702,7 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
# triton implementation
tri_out = attention(q, k, v, causal, sm_scale, split_kernel)
tri_out = attention(q, k, v, causal, sm_scale)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
@@ -636,14 +730,12 @@ except BaseException:
configs = []
for mode in ['fwd', 'bwd']:
for D_HEAD in [128, 64]:
if mode == 'bwd' and D_HEAD == 128:
continue
for causal in [False, True]:
if mode == 'bwd' and causal == False:
continue
configs.append(triton.testing.Benchmark(
x_names=['BATCH', 'H','N_CTX'],
x_vals=[(16, 16, 1024),
x_names=['BATCH', 'H', 'N_CTX'],
x_vals=[(4, 16, 1024),
(8, 16, 2048),
(4, 16, 4096),
(2, 16, 8192),
@@ -673,12 +765,10 @@ for mode in ['fwd', 'bwd']:
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"):
assert mode in ["fwd", "bwd"]
warmup = 25
rep = 100
split_kernel = False
rep = 10
# Bwd pass only supports causal=True right now
if mode == 'bwd':
causal = True
split_kernel = True
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
@@ -686,8 +776,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
if mode == "fwd" and TORCH_HAS_FP8:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel)
sm_scale = D_HEAD ** -0.5
fn = lambda: attention(q, k, v, causal, sm_scale)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
@@ -709,6 +799,5 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path=".", print_data=True)