mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[TUTORIAL] more attention cleanup (#1958)
This commit is contained in:
@@ -26,7 +26,7 @@ def max_fn(x, y):
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
L, M,
|
||||
L,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
@@ -35,7 +35,7 @@ def _fwd_kernel(
|
||||
Z, H, N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
MODE: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
):
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
@@ -64,6 +64,51 @@ def _fwd_kernel(
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = (q * qk_scale).to(tl.float16)
|
||||
# loop over k, v and update accumulator
|
||||
lo = 0
|
||||
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
# -- load k, v --
|
||||
k = tl.load(K_block_ptr)
|
||||
v = tl.load(V_block_ptr)
|
||||
# -- compute qk ---
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
if IS_CAUSAL:
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
qk += tl.dot(q, k)
|
||||
# -- compute scaling constant ---
|
||||
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
|
||||
alpha = tl.math.exp2(m_i - m_i_new)
|
||||
p = tl.math.exp2(qk - m_i_new[:, None])
|
||||
# -- scale and update acc --
|
||||
acc_scale = l_i * 0 + alpha # workaround some compiler bug
|
||||
acc *= acc_scale[:, None]
|
||||
acc += tl.dot(p.to(tl.float16), v)
|
||||
# -- update m_i and l_i --
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
m_i = m_i_new
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
# write back m_i + log2(l_i) per row into L
|
||||
acc = acc / l_i[:, None]
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
@@ -72,82 +117,12 @@ def _fwd_kernel(
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
# initialize offsets
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# causal check on every loop iteration can be expensive
|
||||
# and peeling the last iteration of the loop does not work well with ptxas
|
||||
# so we have a mode to do the causal check in a separate kernel entirely
|
||||
if MODE == 0: # entire non-causal attention
|
||||
lo, hi = 0, N_CTX
|
||||
if MODE == 1: # entire causal attention
|
||||
lo, hi = 0, (start_m + 1) * BLOCK_M
|
||||
if MODE == 2: # off band-diagonal
|
||||
lo, hi = 0, start_m * BLOCK_M
|
||||
if MODE == 3: # on band-diagonal
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
m_i = tl.load(m_ptrs)
|
||||
l_i = tl.load(l_ptrs)
|
||||
acc += tl.load(O_block_ptr).to(tl.float32)
|
||||
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
|
||||
# scale sm_scale by log_2(e) and use
|
||||
# 2^x instead of exp in the loop because CSE and LICM
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
q = (q * qk_scale).to(tl.float16)
|
||||
# advance block pointers to first iteration of the loop
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(lo, hi, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(K_block_ptr)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, k)
|
||||
if MODE == 1 or MODE == 3:
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, 1))
|
||||
p = tl.math.exp2(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
alpha = tl.math.exp2(m_i - m_ij)
|
||||
l_i *= alpha
|
||||
l_i_new = l_i + l_ij
|
||||
# scale acc
|
||||
acc_scale = l_i * 0 + alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
# update acc
|
||||
v = tl.load(V_block_ptr)
|
||||
p = p.to(tl.float16)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_ij
|
||||
# update pointers
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
|
||||
# write back l and m
|
||||
acc = acc / l_i[:, None]
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
# write back O
|
||||
tl.store(O_block_ptr, acc.to(tl.float16))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bwd_preprocess(
|
||||
Out, DO, L,
|
||||
Out, DO,
|
||||
NewDO, Delta,
|
||||
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
|
||||
):
|
||||
@@ -156,9 +131,7 @@ def _bwd_preprocess(
|
||||
# load
|
||||
o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
|
||||
denom = tl.load(L + off_m).to(tl.float32)
|
||||
# compute
|
||||
do = do / denom[:, None]
|
||||
delta = tl.sum(o * do, axis=1)
|
||||
# write-back
|
||||
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
|
||||
@@ -169,7 +142,7 @@ def _bwd_preprocess(
|
||||
def _bwd_kernel(
|
||||
Q, K, V, sm_scale, Out, DO,
|
||||
DQ, DK, DV,
|
||||
L, M,
|
||||
L,
|
||||
D,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
@@ -178,7 +151,7 @@ def _bwd_kernel(
|
||||
num_block,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
MODE: tl.constexpr,
|
||||
CAUSAL: tl.constexpr,
|
||||
):
|
||||
off_hz = tl.program_id(0)
|
||||
off_z = off_hz // H
|
||||
@@ -193,10 +166,10 @@ def _bwd_kernel(
|
||||
DK += off_z * stride_qz + off_h * stride_qh
|
||||
DV += off_z * stride_qz + off_h * stride_qh
|
||||
for start_n in range(0, num_block):
|
||||
if MODE == 0:
|
||||
lo = 0
|
||||
else:
|
||||
if CAUSAL:
|
||||
lo = start_n * BLOCK_M
|
||||
else:
|
||||
lo = 0
|
||||
# initialize row/col offsets
|
||||
offs_qm = lo + tl.arange(0, BLOCK_M)
|
||||
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
@@ -210,7 +183,7 @@ def _bwd_kernel(
|
||||
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
|
||||
# pointer to row-wise quantities in value-like data
|
||||
D_ptrs = D + off_hz * N_CTX
|
||||
m_ptrs = M + off_hz * N_CTX
|
||||
l_ptrs = L + off_hz * N_CTX
|
||||
# initialize dv and dk
|
||||
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
@@ -223,16 +196,14 @@ def _bwd_kernel(
|
||||
# load q, k, v, do on-chip
|
||||
q = tl.load(q_ptrs)
|
||||
# recompute p = softmax(qk, dim=-1).T
|
||||
# NOTE: `do` is pre-divided by `l`; no normalization here
|
||||
# if MODE == 1:
|
||||
if MODE == 1:
|
||||
if CAUSAL:
|
||||
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), float(0.), float("-inf"))
|
||||
else:
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk *= qk_scale
|
||||
m = tl.load(m_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - m[:, None])
|
||||
l_i = tl.load(l_ptrs + offs_m_curr)
|
||||
p = tl.math.exp2(qk - l_i[:, None])
|
||||
# compute dv
|
||||
do = tl.load(do_ptrs)
|
||||
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
|
||||
@@ -275,29 +246,23 @@ class _attention(torch.autograd.Function):
|
||||
BLOCK_N = 64
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
if causal:
|
||||
modes = [1] if q.shape[2] <= 2048 else [2, 3]
|
||||
else:
|
||||
modes = [0]
|
||||
for mode in modes:
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
MODE=mode,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
L,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
num_stages=4)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, L, m)
|
||||
ctx.save_for_backward(q, k, v, o, L)
|
||||
ctx.grid = grid
|
||||
ctx.sm_scale = sm_scale
|
||||
ctx.BLOCK_DMODEL = Lk
|
||||
@@ -307,19 +272,15 @@ class _attention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, do):
|
||||
BLOCK = 128
|
||||
q, k, v, o, l, m = ctx.saved_tensors
|
||||
q, k, v, o, L = ctx.saved_tensors
|
||||
do = do.contiguous()
|
||||
dq = torch.zeros_like(q, dtype=torch.float32)
|
||||
dk = torch.empty_like(k)
|
||||
dv = torch.empty_like(v)
|
||||
do_scaled = torch.empty_like(do)
|
||||
delta = torch.empty_like(l)
|
||||
if ctx.causal:
|
||||
mode = 1
|
||||
else:
|
||||
mode = 0
|
||||
delta = torch.empty_like(L)
|
||||
_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
|
||||
o, do, l,
|
||||
o, do,
|
||||
do_scaled, delta,
|
||||
BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
|
||||
)
|
||||
@@ -327,8 +288,7 @@ class _attention(torch.autograd.Function):
|
||||
q, k, v, ctx.sm_scale,
|
||||
o, do_scaled,
|
||||
dq, dk, dv,
|
||||
l, m,
|
||||
delta,
|
||||
L, delta,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
@@ -336,7 +296,7 @@ class _attention(torch.autograd.Function):
|
||||
ctx.grid[0],
|
||||
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
|
||||
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,
|
||||
MODE=mode,
|
||||
CAUSAL=ctx.causal,
|
||||
num_stages=1,
|
||||
)
|
||||
return dq, dk, dv, None, None
|
||||
@@ -397,7 +357,7 @@ configs = [triton.testing.Benchmark(
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
|
||||
) for mode in ['fwd'] for causal in [False]]
|
||||
) for mode in ['bwd'] for causal in [False]]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
|
||||
Reference in New Issue
Block a user