# mirror of https://github.com/ROCm/ROCm.git
# synced 2026-02-21 03:00:39 -05:00
#
# This reverts commit 1fec965c06.
# This change used pre_hook to edit a kernel arg. However, pre-hook does
# not make the changes made within it visible to the kernel in all cases.
#
# 886 lines, 33 KiB, Python
"""
Fused Attention
===============

This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)

Credits: OpenAI kernel team, AMD ML Frameworks Triton team

Features supported:

1) Fwd + bwd kernel with causal masking
2) Vector and matrix bias (currently fwd kernel only, no causal masking)
3) Any sequence lengths without padding (currently fwd kernel only, no causal masking)
4) fp8 (e5m2fnuz, QK GEMM in fwd kernel only)

Not currently supported:

1) Nested / ragged tensors ("varlen")

"""
|
|
|
|
import pytest
|
|
import random
|
|
import torch
|
|
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
# Probe once whether this torch build exposes the fp8 e5m2fnuz dtype. When it
# does, that dtype is selected as `torch_dtype`; otherwise float16 is used.
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
torch_dtype: tl.constexpr = (
    torch.float8_e5m2fnuz if TORCH_HAS_FP8E5 else torch.float16
)
|
|
|
|
def prepare_bias(bias, batch, nheads, seqlen):
    """Classify an attention bias and broadcast it to the full score shape.

    Args:
        bias: 4-D CUDA tensor whose last two dimensions are either
            ``(1, seqlen)`` or ``(seqlen, seqlen)``.
        batch: Size the first dimension is expanded to.
        nheads: Size the second dimension is expanded to.
        seqlen: Sequence length used both for classification and expansion.

    Returns:
        A tuple ``(expanded, bias_type)`` where ``expanded`` is ``bias``
        expanded (not copied) to ``(batch, nheads, seqlen, seqlen)`` and
        ``bias_type`` is ``"vector"`` or ``"matrix"``.

    Raises:
        AssertionError: If ``bias`` is not on a CUDA device or is not 4-D.
        RuntimeError: If the trailing two dimensions match neither layout.
    """
    assert bias.is_cuda
    assert bias.dim() == 4

    trailing_dims = bias.shape[2:]
    # The vector layout is tested first, so seqlen == 1 is reported as "vector".
    if trailing_dims == (1, seqlen):
        kind = "vector"
    elif trailing_dims == (seqlen, seqlen):
        kind = "matrix"
    else:
        raise RuntimeError(
            "Last 2 dimensions of bias must be (1, seqlen) or (seqlen, seqlen)"
        )

    expanded = bias.expand(batch, nheads, seqlen, seqlen)
    return expanded, kind
|
|
|
|
@triton.jit
def _attn_fwd_inner(
    acc, l_i, m_i, q,
    K_block_ptr, V_block_ptr,
    start_m,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    N_CTX,
    pre_load_v: tl.constexpr,
    padded_block: tl.constexpr,
    total_tokens: tl.constexpr,
    bias_ptr
):
    """Run one stage of the online-softmax loop over K/V blocks.

    The column range walked depends on the arguments:
      * ``STAGE == 1``: columns ``[0, start_m * BLOCK_M)``.
      * ``STAGE == 2``: the ``BLOCK_M``-wide band starting at
        ``start_m * BLOCK_M``, with a causal mask applied.
      * otherwise, ``padded_block`` true: the single block starting at
        ``N_CTX``; columns at or beyond ``total_tokens`` are masked out.
      * otherwise: columns ``[0, N_CTX)``.

    Args:
        acc, l_i, m_i: Running output accumulator, softmax denominator and
            row maximum carried across calls.
        q: Query tile. Scores are exponentiated with ``exp2``, so the caller
            is expected to have folded ``log2(e)`` into the scaling.
        K_block_ptr, V_block_ptr: Block pointers to K and V; advanced locally.
        start_m: Index of the query block handled by this program.
        offs_m, offs_n: Row / column index vectors used for masking.
        N_CTX: Loop bound (or start of the padded block, see above).
        pre_load_v: Load V before the QK product instead of after it.
        padded_block: Whether loads must be bounds-checked along the key axis.
        total_tokens: Number of valid key columns (used for padded masking).
        bias_ptr: ``None``, a block pointer (matrix bias) or a tensor of
            pointers (vector bias).

    Returns:
        The updated ``(acc, l_i, m_i)``.
    """
    # range of values handled by this stage
    if STAGE == 1:
        lo, hi = 0, start_m * BLOCK_M
    elif STAGE == 2:
        lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
        lo = tl.multiple_of(lo, BLOCK_M)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    elif padded_block:
        # N_CTX is the seqlen to the nearest block (round down).
        # So here, we are computing the elements for that last irregular block.
        # In the loop, we will mask the elements of BLOCK_N that do not exist.
        lo, hi = N_CTX, N_CTX + BLOCK_N
        lo = tl.multiple_of(lo, BLOCK_N)
        K_block_ptr = tl.advance(K_block_ptr, (0, lo))
        V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
        if bias_ptr is not None:
            if bias_ptr.type.element_ty.is_block():
                bias_ptr = tl.advance(bias_ptr, (0, lo))
            else:
                bias_ptr += lo
    else:
        # causal = False
        lo, hi = 0, N_CTX
    # loop over k, v and update accumulator
    for start_n in range(lo, hi, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # For padded blocks, we will overrun the tensor size if
        # we load all BLOCK_N. For others, the blocks are all within range.
        if padded_block:
            k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
        else:
            k = tl.load(K_block_ptr)
        if pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)
        # -- compute qk ----
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        if STAGE == 2:
            mask = offs_m[:, None] >= (start_n + offs_n[None, :])
            qk = tl.where(mask, qk, float("-inf"))
        qk += tl.dot(q, k)
        if bias_ptr is not None:
            if padded_block:
                if bias_ptr.type.element_ty.is_block():
                    # Matrix bias: check rows as well as columns. The last
                    # query block can extend past the end of the bias rows,
                    # just as it can for Q.
                    bias = tl.load(bias_ptr, boundary_check=(0, 1), padding_option="zero")
                else:
                    size_n = start_n + offs_n
                    boundary_n = tl.full([BLOCK_N], total_tokens, dtype=tl.float32)
                    bias_padding = tl.full([BLOCK_N], 0, dtype=tl.float32)
                    bias = tl.load(bias_ptr, mask=size_n < boundary_n, other=bias_padding)
            else:
                if bias_ptr.type.element_ty.is_block():
                    # Columns are in range here, but rows of the last query
                    # block may not be; guard them to avoid reading out of bounds.
                    bias = tl.load(bias_ptr, boundary_check=(0,), padding_option="zero")
                else:
                    bias = tl.load(bias_ptr)
            # While bias is added after multiplying qk with sm_scale,
            # our optimization to use 2^x instead of e^x results in an additional
            # scale factor of log2(e) which we must also multiply the bias with.
            qk += (bias * 1.44269504)
        if padded_block:
            # Columns at or beyond total_tokens do not exist: force them to
            # -inf so they contribute nothing to the softmax.
            boundary_m = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
            size_n = start_n + offs_n[None, :]
            mask = size_n < boundary_m[:, None]
            qk = tl.where(mask, qk, float("-inf"))
        m_ij = tl.maximum(m_i, tl.max(qk, 1))
        qk = qk - m_ij[:, None]
        p = tl.math.exp2(qk)
        # -- update output accumulator --
        alpha = tl.math.exp2(m_i - m_ij)
        acc = acc * alpha[:, None]
        if not pre_load_v:
            if padded_block:
                v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
            else:
                v = tl.load(V_block_ptr)
        acc += tl.dot(p.to(v.dtype), v)
        # -- update m_i and l_i
        l_ij = tl.sum(p, 1)
        l_i = l_i * alpha + l_ij
        m_i = m_ij
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        if bias_ptr is not None:
            if bias_ptr.type.element_ty.is_block():
                bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
            else:
                bias_ptr += BLOCK_N
    return acc, l_i, m_i
|
|
|
|
|
|
@triton.jit
def _attn_fwd(
    Q, K, V, bias, sm_scale, M, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    stride_oz, stride_oh, stride_om, stride_on,
    stride_bz, stride_bh, stride_bm, stride_bn,
    Z, H,
    N_CTX,
    BLOCK_DMODEL: tl.constexpr,
    STAGE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    pre_load_v: tl.constexpr,
    need_padding: tl.constexpr,
    extra_tokens_n: tl.constexpr,
    bias_type: tl.constexpr
):
    """Forward attention kernel.

    Each program handles one ``BLOCK_M``-row query block (grid axis 0) of one
    batch*head slice (grid axis 1). It writes the attention output to ``Out``
    and the per-row value ``m_i + log2(l_i)`` to ``M``.

    Args:
        Q, K, V, Out: Input / output tensors; the per-slice base offset is
            ``program_id(1) * stride_qh`` for all four.
        bias: Optional bias tensor (``None`` when unused).
        sm_scale: Softmax scale; multiplied by log2(e) and folded into Q.
        M: Buffer indexed as ``off_hz * N_CTX + row``.
        stride_*: Strides of the corresponding tensors.
        Z, H: Passed by the launcher; only ``H`` is used (bias head index).
        N_CTX: Sequence length.
        STAGE: 1 for non-causal, 3 for causal (see the comments below).
        need_padding: Whether a trailing partial key block must be processed.
        extra_tokens_n: Value subtracted from ``N_CTX`` to obtain the
            block-aligned key length.
        bias_type: ``"vector"`` or ``"matrix"`` when ``bias`` is given.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh

    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0),
    )
    # K is addressed transposed, as (BLOCK_DMODEL, N_CTX), so q @ k needs no
    # explicit transpose.
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0),
    )
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    if bias is not None:
        if bias_type == "vector":
            bias_ptr = bias + ((off_hz % H) * stride_bh) + offs_n
        elif bias_type == "matrix":
            bias_ptr = tl.make_block_ptr(
                base=bias + ((off_hz % H) * stride_bh),
                shape=(N_CTX, N_CTX),
                strides=(stride_bm, stride_bn),
                offsets=(start_m * BLOCK_M, 0),
                block_shape=(BLOCK_M, BLOCK_N),
                order=(1, 0),
            )
    else:
        bias_ptr = None
    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # scale sm_scale by log_2(e) and use
    # 2^x instead of exp in the loop because CSE and LICM
    # don't work as expected with `exp` in the loop
    qk_scale = sm_scale * 1.44269504
    # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
    q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
    q = (q * qk_scale).to(q.dtype)
    # stage 1: off-band
    # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
    # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
    if STAGE & 1:
        # We don't currently support causal masking and padding.
        tl.static_assert((STAGE != 3) or not need_padding)
        # equal to N_CTX if N_CTX is already a multiple of BLOCK_N
        seqlen_aligned = N_CTX - extra_tokens_n
        if N_CTX >= BLOCK_N:
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                False, seqlen_aligned,
                bias_ptr
            )
        tl.debug_barrier()
        if need_padding:
            if N_CTX < BLOCK_N:
                seqlen_aligned = 0
            acc, l_i, m_i = _attn_fwd_inner(
                acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
                start_m,
                BLOCK_M, BLOCK_DMODEL, BLOCK_N,
                4 - STAGE, offs_m, offs_n,
                seqlen_aligned, pre_load_v,
                True, N_CTX,
                bias_ptr
            )
    # stage 2: on-band
    if STAGE & 2:
        # Bias is only supported without causal masking. The on-band stage
        # does not position the bias pointer, so reject the combination at
        # compile time instead of silently computing wrong scores.
        tl.static_assert(bias_ptr is None, "bias is not supported with causal masking")
        # barrier makes it easier for compiler to schedule the
        # two loops independently
        tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
            N_CTX, pre_load_v,
            False, N_CTX,
            None
        )
    # epilogue
    # write back m
    acc = acc / l_i[:, None]
    m_ptrs = M + off_hz * N_CTX + offs_m
    # The last query block may extend past N_CTX. Mask those rows so the store
    # does not run into the next slice of M (or off the end of it).
    tl.store(m_ptrs, m_i + tl.math.log2(l_i), mask=offs_m < N_CTX)
    tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0, 1))
|
|
|
|
|
|
@triton.jit
def _attn_bwd_preprocess(O, DO,  #
                         NewDO, Delta,  #
                         BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,  #
                         ):
    """Backward pre-pass: copy DO into NewDO and compute per-row ``sum(O * DO)``.

    Each program handles ``BLOCK_M`` consecutive rows. Rows are addressed as
    ``row * D_HEAD + col``; loads and stores are unmasked.
    """
    rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = tl.arange(0, D_HEAD)
    # One (BLOCK_M, D_HEAD) tile of element offsets shared by O, DO and NewDO.
    tile = rows[:, None] * D_HEAD + cols[None, :]
    # load
    o = tl.load(O + tile).to(tl.float32)
    do = tl.load(DO + tile).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(NewDO + tile, do)
    tl.store(Delta + rows, delta)
|
|
|
|
|
|
@triton.jit
def _bwd_kernel_dk_dv(
    Q, K, V, sm_scale, Out, DO,
    DK, DV,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel computing dK and dV under a causal mask.

    Each program owns one ``BLOCK_M``-row tile of K/V (grid axis 0) for one
    batch*head slice (grid axis 1). It loops over query / dO tiles from row
    ``start_m * BLOCK_M`` up to ``N_CTX`` and accumulates dK and dV.

    Args:
        Q, K, V, DO: Inputs; Q and DO are read with Q's strides.
        sm_scale: Softmax scale.
        Out: Not read by this kernel.
        DK, DV: Output gradients, written with K's / V's strides.
        L: Per-row values subtracted from the scores before ``exp2``.
        D: Per-row delta values subtracted from ``dO @ V^T``.
        Z, H: Not read by this kernel.
        N_CTX: Sequence length.

    All loads and stores are unmasked. ``BLOCK_M`` and ``BLOCK_N`` presumably
    have to be equal, as the tile shapes are used interchangeably below —
    TODO confirm.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    # Q is consumed depending on block ID. Every block uses
    # previous block offset by BLOCK_M x D_HEAD.
    qvk_offset = off_hz * stride_qh
    qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, start_m * BLOCK_M),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qdo_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(0, 0),
        block_shape=(BLOCK_N, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load k and v: they will stay in SRAM throughout
    k = tl.load(K_block_ptr)
    k = (k * qk_scale).to(k.dtype)
    v = tl.load(V_block_ptr)
    dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # This lower loop bound is because of the causal mask. We create a lower triangular
    # result. The upper triangular is -inf (becomes 0 when we do e^x). As such, it can
    # be ignored in the GEMM.
    lo = start_m * BLOCK_M
    hi = N_CTX
    # loop over q, do
    for start_n in range(lo, hi, BLOCK_N):
        offs_m_curr = offs_n[:, None] + start_n
        # -- load q, do --
        q = tl.load(Q_block_ptr)
        do = tl.load(DO_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
        l_i = tl.load(l_ptrs + offs_m_curr)
        p = tl.math.exp2(qk - l_i)
        # -- compute dv ----
        dv += tl.dot(tl.trans(p.to(do.dtype)), do)
        # compute dp = dot(v, do)
        Di = tl.load(D_ptrs + offs_m_curr)
        dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
        dp += tl.dot(do, v)
        # compute ds = p * (dp - delta[:, None])
        ds = p * dp
        # compute dk
        dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
        # update pointers
        Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
        DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
    # initialize pointers to output
    DK_block_ptr = tl.make_block_ptr(
        base=DK + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_kn, stride_kk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    DV_block_ptr = tl.make_block_ptr(
        base=DV + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_vk, stride_vn),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
    # Cast to the output tensor's element type, as for DK, instead of a
    # hard-coded float16 that would round twice for other output dtypes.
    tl.store(DV_block_ptr, dv.to(DV.dtype.element_ty))
|
|
|
|
@triton.jit
def _bwd_kernel_dq(
    Q, K, V, sm_scale, Out, DO,
    DQ,
    L,
    D,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vk, stride_vn,
    Z, H, N_CTX,
    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Backward kernel computing dQ under a causal mask.

    Each program owns one ``BLOCK_M``-row tile of Q / dO (grid axis 0) for one
    batch*head slice (grid axis 1). It loops over K/V tiles covering columns
    ``[0, (start_m + 1) * BLOCK_M)``.

    Args:
        Q, K, V, DO: Inputs; DO is read with Q's strides.
        sm_scale: Softmax scale.
        Out: Not read by this kernel.
        DQ: Output gradient, written with Q's strides.
        L: Per-row values subtracted from the scores before ``exp2``.
        D: Per-row delta values subtracted from ``dO @ V^T``.
        Z, H: Not read by this kernel.
        N_CTX: Sequence length.

    All loads and stores are unmasked, so the launch grid must cover the
    sequence with whole ``BLOCK_M`` tiles.
    """
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    qvk_offset = off_hz * stride_qh
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # Initialize pointers to Q, K, V
    Q_block_ptr = tl.make_block_ptr(
        base=Q + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    K_block_ptr = tl.make_block_ptr(
        base=K + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(BLOCK_DMODEL, N_CTX),
        strides=(stride_vn, stride_vk),
        offsets=(0, 0),
        block_shape=(BLOCK_DMODEL, BLOCK_N),
        order=(0, 1)
    )
    DO_block_ptr = tl.make_block_ptr(
        base=DO + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # pointer to row-wise quantities in value-like data
    D_ptrs = D + off_hz * N_CTX
    l_ptrs = L + off_hz * N_CTX
    qk_scale = sm_scale * 1.44269504
    # load q and do: they will stay in SRAM throughout
    q = tl.load(Q_block_ptr)
    q = (q * qk_scale).to(q.dtype)
    do = tl.load(DO_block_ptr)
    Di = tl.load(D_ptrs + offs_m)
    l_i = tl.load(l_ptrs + offs_m)
    dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # loop over k, v
    lo = 0
    hi = (start_m + 1) * BLOCK_M
    for start_n in range(lo, hi, BLOCK_N):
        # -- load k, v --
        k = tl.load(K_block_ptr)
        v = tl.load(V_block_ptr)
        # -- compute qk ----
        qk = tl.dot(q, k)
        qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
        p = tl.math.exp2(qk - l_i[:, None])
        # compute dp = dot(v, do)
        dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
        dp += tl.dot(do, v)
        # compute ds = p * (dp - delta[:, None])
        ds = p * dp
        # compute dq. Unfortunately we cannot avoid transpose here as this loop
        # uses k both normal and transpose.
        dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
        # update pointers
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
    # initialize pointers to output
    DQ_block_ptr = tl.make_block_ptr(
        base=DQ + qvk_offset,
        shape=(N_CTX, BLOCK_DMODEL),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_DMODEL),
        order=(1, 0)
    )
    # Cast to the output tensor's element type instead of a hard-coded
    # float16 that would round twice for other output dtypes.
    tl.store(DQ_block_ptr, (dq * sm_scale).to(DQ.dtype.element_ty))
|
|
|
|
# Module-level 128-element tensor allocated on the CUDA device at import time.
# NOTE(review): nothing in the visible source reads `empty` — confirm whether
# it is still needed.
empty = torch.empty(128, device="cuda")
|
|
|
|
|
|
class _attention(torch.autograd.Function):
    """Autograd wrapper around the Triton attention kernels in this file."""

    @staticmethod
    def forward(ctx, q, k, v, causal, bias, sm_scale, split_kernel=False):
        """Launch the forward kernel and stash what backward needs.

        Args:
            ctx: Autograd context.
            q, k, v: 4-D tensors unpacked as ``(batch, nheads, seqlen, head_dim)``.
                All three must share ``seqlen`` and ``head_dim``, and
                ``head_dim`` must be one of 16, 32, 64 or 128.
            causal: Whether to apply causal masking.
            bias: Optional bias accepted by ``prepare_bias``, or ``None``.
            sm_scale: Softmax scale.
            split_kernel: Stored on ``ctx``; selects the backward code path.

        Returns:
            The attention output, shaped like ``q`` with ``v``'s dtype.
        """
        # shape constraints
        batch, nheads, seqlen, Lq = q.shape
        Lk, Lv = k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}
        # For now we assume K and V seqlen = Q seqlen
        assert seqlen == k.shape[-2] and seqlen == v.shape[-2]

        # We've derived these previously from tuning the kernel
        BLOCK_M = 256
        BLOCK_N = 128 if Lq == 128 else 64
        waves_per_eu = 2 if Lq == 128 else 3
        num_warps = 8 if Lq == 128 else 4
        pre_load_v = False if Lq == 128 else True
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)

        stage = 3 if causal else 1
        seqlen_k = k.shape[-2]
        if seqlen_k < BLOCK_N:
            need_padding = True
            extra_tokens_n = BLOCK_N - seqlen_k
            # This effectively means we cannot slice across Q.
            assert(grid[0] == 1)
        elif seqlen_k % BLOCK_N:
            need_padding = True
            extra_tokens_n = seqlen_k % BLOCK_N
        else:
            # We don't care if the M dim needs padding, as we
            # always boundary_check on Q and O
            need_padding = False
            extra_tokens_n = 0

        o = torch.empty_like(q, dtype=v.dtype)
        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)

        if bias is not None:
            bias, bias_type = prepare_bias(bias, batch, nheads, seqlen)
            bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
        else:
            bias, bias_type, bias_strides = None, None, (0, 0, 0, 0)

        _attn_fwd[grid](
            q, k, v, bias, sm_scale, M, o,
            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
            *bias_strides,
            q.shape[0], q.shape[1],
            N_CTX=q.shape[2],
            BLOCK_DMODEL=Lk,
            STAGE=stage,
            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
            waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
            need_padding=need_padding, extra_tokens_n=extra_tokens_n,
            bias_type=bias_type,
            num_stages=1, num_warps=num_warps
        )

        ctx.save_for_backward(q, k, v, o, M)
        ctx.grid = grid
        ctx.sm_scale = sm_scale
        ctx.BLOCK_DMODEL = Lk
        ctx.causal = causal
        ctx.split_kernel = split_kernel
        return o

    @staticmethod
    def backward(ctx, do):
        """Compute gradients for ``q``, ``k`` and ``v``.

        Args:
            ctx: Autograd context populated by ``forward``.
            do: Gradient of the output. It must be contiguous and share its
                strides with ``q``, ``k``, ``v`` and ``o``.

        Returns:
            One entry per ``forward`` input after ``ctx``:
            ``(dq, dk, dv, None, None, None, None)``. ``causal``, ``bias``,
            ``sm_scale`` and ``split_kernel`` receive no gradient.
        """
        # configuration is not supported
        assert(not (ctx.split_kernel and not ctx.causal))
        if torch.version.hip is not None:
            BLOCK = 64
        else:
            BLOCK = 128
        q, k, v, o, L = ctx.saved_tensors
        assert do.is_contiguous()
        assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
        do = do.contiguous()
        dq = torch.zeros_like(q)
        dk = torch.empty_like(k)
        dv = torch.empty_like(v)
        BATCH, N_HEAD, N_CTX = q.shape[:3]
        delta = torch.empty_like(L)
        do_scaled = torch.empty_like(do)
        # Figure out what BLOCK size fwd used and adjust num_blocks accordingly.
        # If the two are the same, we don't need this but the bwd pass block size
        # is smaller than the fwd so we need this scaling to ensure we loop over all
        # values and don't skip some blocks.
        # Alternatively we could compute a new grid but this keeps it consistent
        # with fwd and easier to reason about.
        block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
        _attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
            o, do,  #
            do_scaled, delta,  #
            BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL,  #
        )
        if not ctx.split_kernel:
            # NOTE(review): `_bwd_kernel` is not defined anywhere in the
            # visible source, so this path is expected to raise NameError
            # unless it is provided elsewhere — confirm.
            _bwd_kernel[(ctx.grid[1],)](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq, dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                block_scale * ctx.grid[0],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                CAUSAL=ctx.causal,
                num_stages=1,
            )
        else:
            _bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dk, dv,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=BLOCK, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
                num_stages=1,
            )
            # The dq kernel processes `dq_block_m` query rows per program, so
            # its grid is derived from that tile size. Reusing the forward
            # grid (built for the forward BLOCK_M) leaves rows uncomputed
            # whenever 2 * BLOCK is smaller than the forward tile.
            dq_block_m = 2 * BLOCK
            dq_grid = (triton.cdiv(N_CTX, dq_block_m), ctx.grid[1], 1)
            _bwd_kernel_dq[dq_grid](
                q, k, v, ctx.sm_scale,
                o, do_scaled,
                dq,
                L, delta,
                q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                q.shape[0], q.shape[1], q.shape[2],
                BLOCK_M=dq_block_m, BLOCK_N=BLOCK,
                BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
                num_stages=1,
            )
        # forward takes seven inputs after ctx, so seven gradients are returned.
        return dq, dk, dv, None, None, None, None
|
|
|
|
# Public entry point; arguments are those of `_attention.forward` after `ctx`:
# (q, k, v, causal, bias, sm_scale, split_kernel=False).
attention = _attention.apply
|
|
|
|
|
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 63, 64),
                          (4, 48, 987, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (4, 48, 3989, 64),
                          (4, 48, 1024, 128),
                          (4, 48, 1021, 128),
                          (4, 48, 2048, 128),
                          (4, 48, 4096, 128),
                          (4, 16, 8192, 64),
                          (4, 16, 8080, 64),
                          (1, 48, 16384, 64)
                          ])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('use_bias', [False, True])
@pytest.mark.parametrize('bias_type', ["vector", "matrix"])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, use_bias, bias_type, dtype=torch.float16):
    """Compare the Triton forward pass against a PyTorch reference.

    The reference computes ``softmax(q @ k^T * sm_scale [+ bias]) @ v`` with
    q and k promoted to half precision; outputs must agree within
    ``atol=rtol=1e-2``.
    """
    torch.manual_seed(20)

    # Random tensors are drawn in the order bias, q, k, v so the seeded values
    # are reproducible.
    if not use_bias:
        bias = None
    elif bias_type == "vector":
        bias = torch.randn((1, H, 1, N_CTX), dtype=torch.float32, device="cuda")
    elif bias_type == "matrix":
        bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda")

    qkv_shape = (Z, H, N_CTX, D_HEAD)
    q = torch.randn(qkv_shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.randn(qkv_shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.randn(qkv_shape, dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    if TORCH_HAS_FP8E5:
        # Only the QK GEMM operands are cast to fp8; v is left unchanged.
        q = q.to(torch_dtype)
        k = k.to(torch_dtype)
    sm_scale = D_HEAD ** -0.5
    dout = torch.randn_like(q, dtype=torch.float16)

    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    scores = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
    if causal:
        scores[:, :, M == 0] = float("-inf")
    if use_bias:
        ref_bias, _ = prepare_bias(bias, Z, H, N_CTX)
        scores += ref_bias
    probs = torch.softmax(scores.float(), dim=-1).half()
    ref_out = torch.matmul(probs, v)

    # triton implementation
    tri_out = attention(q, k, v, causal, bias, sm_scale)

    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
|
|
|
|
|
|
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
                         [(4, 48, 1024, 64),
                          (4, 48, 2048, 64),
                          (4, 48, 4096, 64),
                          (1, 16, 8192, 64),
                          ])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    """Compare Triton forward and backward (causal, split kernels) with PyTorch.

    The reference is ``softmax(causal_mask(q @ k^T * sm_scale)) @ v`` with
    gradients from autograd; the Triton path uses ``split_kernel=True`` and
    no bias.
    """
    torch.manual_seed(20)
    causal = True
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

    sm_scale = 0.5
    split_kernel = True
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).half()
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    # `attention` takes (q, k, v, causal, bias, sm_scale, split_kernel); bias
    # must be passed explicitly (None here) so that sm_scale and split_kernel
    # land in the right positions.
    tri_out = attention(q, k, v, causal, None, sm_scale, split_kernel)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare
    torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
    if torch.version.hip is None:
        torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
    # The current block size for MI200 series is 64x64. This results in
    # larger differences in float results due to rounding.
    else:
        torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
    torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)
|
|
|
|
|
|
# Optional dependency: the flash-attn package is only used as a benchmark
# baseline. FLASH_VER is referenced by the benchmark line labels, so it must
# be defined on both paths (None when the package is unavailable).
try:
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    HAS_FLASH = True
    # `flash_attn_qkvpacked_func` is the flash-attn v2 entry point.
    FLASH_VER = 2
except Exception:
    # Best effort: any import-time failure (missing package, ABI or driver
    # mismatch) simply disables the baseline.
    HAS_FLASH = False
    FLASH_VER = None
|
|
|
|
# vary seq length for fixed head and batch=4
|
|
configs = []
for mode in ['fwd']:
    for D_HEAD in [128]:
        # bwd is not benchmarked with a head dimension of 128
        if mode == 'bwd' and D_HEAD == 128:
            continue
        for causal in [False]:
            # bwd is only benchmarked with causal masking
            if mode == 'bwd' and causal == False:
                continue
            for use_bias in [False, True]:
                # One Benchmark per (mode, D_HEAD, causal, use_bias); each
                # sweeps the (BATCH, H, N_CTX) triples listed below.
                _providers = ['triton'] + (['flash'] if HAS_FLASH else [])
                _labels = ['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else [])
                _fixed_args = {
                    'D_HEAD': D_HEAD,
                    'dtype': torch.float16,
                    'mode': mode,
                    'causal': causal,
                    'use_bias': use_bias,
                }
                configs.append(triton.testing.Benchmark(
                    x_names=['BATCH', 'H', 'N_CTX'],
                    x_vals=[
                        (16, 16, 1024),
                        (8, 16, 2048),
                        (4, 16, 4096),
                        (2, 16, 8192),
                        (1, 16, 16384),
                        (2, 48, 1024),
                        (2, 48, 2048),
                        (2, 48, 4096),
                        (2, 48, 8192),
                        (2, 48, 16384),
                        (8, 16, 1989),
                        (4, 16, 4097),
                        (2, 16, 8122),
                        (1, 16, 16281),
                        (2, 48, 1021),
                        (2, 48, 2001),
                        (2, 48, 3996),
                        (2, 48, 8181),
                    ],
                    line_arg='provider',
                    line_vals=_providers,
                    line_names=_labels,
                    styles=[('red', '-'), ('blue', '-')],
                    ylabel='ms',
                    plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}-bias={use_bias}',
                    args=_fixed_args,
                ))
|
|
|
|
|
|
@triton.testing.perf_report(configs)
def bench_flash_attention(
    BATCH, H, N_CTX, D_HEAD, use_bias, causal, mode, provider, dtype=torch.float16, device="cuda"
):
    """Time one attention configuration and return a throughput figure.

    Args:
        BATCH, H, N_CTX, D_HEAD: Problem size.
        use_bias: Whether to benchmark with a (vector) bias.
        causal: Causal masking flag; forced to ``True`` for ``mode == "bwd"``.
        mode: ``"fwd"`` or ``"bwd"``.
        provider: ``"triton"`` or ``"flash"``.
        dtype: Element type of the generated inputs.
        device: Device used for the flash-attn inputs (the Triton inputs are
            always created on ``"cuda"``).

    Returns:
        ``total_flops / ms * 1e-9`` — presumably TFLOP/s, assuming
        ``do_bench`` reports milliseconds; TODO confirm.
    """
    assert mode in ["fwd", "bwd"]
    n_warmup, n_rep = 25, 100
    running_bwd = mode == 'bwd'

    bias_type = "vector"
    if not use_bias:
        bias = None
    elif bias_type == "vector":
        bias = torch.randn((1, H, 1, N_CTX), dtype=torch.float32, device="cuda")
    elif bias_type == "matrix":
        bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda")
    else:
        raise RuntimeError(
            f"Got unsupported bias type: {bias_type}. Supported types are vector and matrix."
        )

    # Bwd pass only supports causal=True right now
    split_kernel = running_bwd
    if running_bwd:
        causal = True

    if provider == "triton":
        qkv_shape = (BATCH, H, N_CTX, D_HEAD)
        q = torch.randn(qkv_shape, dtype=dtype, device="cuda", requires_grad=True)
        k = torch.randn(qkv_shape, dtype=dtype, device="cuda", requires_grad=True)
        v = torch.randn(qkv_shape, dtype=dtype, device="cuda", requires_grad=True)
        if mode == "fwd":
            q = q.to(torch_dtype)
            k = k.to(torch_dtype)
        sm_scale = 1.3

        def run_fwd():
            return attention(q, k, v, causal, bias, sm_scale, split_kernel)

        timed = run_fwd
        if running_bwd:
            # Run forward once, then time only the backward pass.
            o = run_fwd()
            do = torch.randn_like(o)

            def run_bwd():
                return o.backward(do, retain_graph=True)

            timed = run_bwd
        ms = triton.testing.do_bench(timed, warmup=n_warmup, rep=n_rep)

    if provider == "flash":
        qkv = torch.randn(
            (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
        )

        def run_flash_fwd():
            return flash_attn_func(qkv, causal=causal)

        timed = run_flash_fwd
        if mode == "bwd":
            o = run_flash_fwd()
            do = torch.randn_like(o)

            def run_flash_bwd():
                return o.backward(do, retain_graph=True)

            timed = run_flash_bwd
        ms = triton.testing.do_bench(timed, warmup=n_warmup, rep=n_rep)

    # Two matmuls (QK^T and PV) per attention.
    flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if causal:
        total_flops *= 0.5
    if mode == "bwd":
        total_flops *= 2.5  # 2.0(bwd) + 0.5(recompute)
    return total_flops / ms * 1e-9
|
|
|
|
|
|
# only works on post-Ampere GPUs right now
# NOTE: this call sits at module top level, so the benchmark runs whenever the
# file is executed or imported. Results are printed (print_data=True) and
# saved under the current directory (save_path=".").
bench_flash_attention.run(save_path=".", print_data=True)
|
|
|