"""
Fused Attention
===============
This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
Credits: OpenAI kernel team, AMD ML Frameworks Triton team
Features supported:
1) Fwd + bwd kernel with causal masking
2) Vector and matrix bias (currently fwd kernel only, no causal masking)
3) Any sequence lengths without padding (currently fwd kernel only, no causal masking)
4) fp8 (e5m2fnuz, QK GEMM in fwd kernel only)
Not currently supported:
1) Nested / ragged tensors ("varlen")
"""
import pytest
import random
import torch
import triton
import triton.language as tl
torch_dtype:tl.constexpr = torch.float16
TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
torch_dtype:tl.constexpr = torch.float8_e5m2fnuz
def prepare_bias(bias, batch, nheads, seqlen):
assert bias.is_cuda
assert bias.dim() == 4
if bias.shape[2:] == (1, seqlen):
bias_type = "vector"
elif bias.shape[2:] == (seqlen, seqlen):
bias_type = "matrix"
else:
        raise RuntimeError(
            "Last 2 dimensions of bias must be (1, seqlen) or (seqlen, seqlen)"
        )
return bias.expand(batch, nheads, seqlen, seqlen), bias_type
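# Shape sketch for prepare_bias (hypothetical sizes: batch=2, nheads=4, seqlen=128):
#   a (1, 4, 1, 128) bias is classified as "vector",
#   a (1, 4, 128, 128) bias is classified as "matrix",
# and either one is broadcast via expand() (no copy) to (2, 4, 128, 128).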
@triton.jit
def _attn_fwd_inner(
acc, l_i, m_i, q,
K_block_ptr, V_block_ptr,
start_m,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
N_CTX,
pre_load_v: tl.constexpr,
padded_block: tl.constexpr,
total_tokens: tl.constexpr,
bias_ptr
):
# range of values handled by this stage
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
    # N_CTX here is the seqlen rounded down to the nearest multiple of BLOCK_N,
    # so this branch computes the elements of that last, irregular block.
    # Inside the loop, we mask out the elements of BLOCK_N that do not exist.
elif padded_block:
lo, hi = N_CTX, N_CTX + BLOCK_N
lo = tl.multiple_of(lo, BLOCK_N)
K_block_ptr = tl.advance(K_block_ptr, (0, lo))
V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
if bias_ptr is not None:
if bias_ptr.type.element_ty.is_block():
bias_ptr = tl.advance(bias_ptr, (0, lo))
else:
bias_ptr += lo
# causal = False
else:
lo, hi = 0, N_CTX
# loop over k, v and update accumulator
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# For padded blocks, we will overrun the tensor size if
# we load all BLOCK_N. For others, the blocks are all within range.
if padded_block:
k = tl.load(K_block_ptr, boundary_check=(1,), padding_option="zero")
else:
k = tl.load(K_block_ptr)
if pre_load_v:
if padded_block:
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
else:
v = tl.load(V_block_ptr)
# -- compute qk ----
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = tl.where(mask, qk, float("-inf"))
qk += tl.dot(q, k)
if bias_ptr is not None:
if padded_block:
if bias_ptr.type.element_ty.is_block():
bias = tl.load(bias_ptr,boundary_check=(1,), padding_option="zero")
else:
size_n = start_n + offs_n
boundary_n = tl.full([BLOCK_N], total_tokens, dtype=tl.float32)
bias_padding = tl.full([BLOCK_N], 0, dtype=tl.float32)
bias = tl.load(bias_ptr, mask=size_n < boundary_n, other=bias_padding)
else:
bias = tl.load(bias_ptr)
# While bias is added after multiplying qk with sm_scale,
# our optimization to use 2^x instead of e^x results in an additional
# scale factor of log2(e) which we must also multiply the bias with.
qk += (bias * 1.44269504)
if padded_block:
boundary_m = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
size_n = start_n + offs_n[None,:]
mask = size_n < boundary_m[:,None]
qk = tl.where(mask, qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
# -- update output accumulator --
alpha = tl.math.exp2(m_i - m_ij)
acc = acc * alpha[:, None]
if not pre_load_v:
if padded_block:
v = tl.load(V_block_ptr, boundary_check=(0,), padding_option="zero")
else:
v = tl.load(V_block_ptr)
acc += tl.dot(p.to(v.dtype), v)
# -- update m_i and l_i
l_ij = tl.sum(p, 1)
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
if bias_ptr is not None:
if bias_ptr.type.element_ty.is_block():
bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N))
else:
bias_ptr += BLOCK_N
return acc, l_i, m_i
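# For reference, the rescaling in the loop above is the standard FlashAttention
# online-softmax recurrence, written here with base-2 exponentials to match the code:
#   m_new = max(m_old, rowmax(qk))
#   alpha = 2^(m_old - m_new)
#   p     = 2^(qk - m_new)
#   acc   = acc * alpha + p @ v
#   l_new = l_old * alpha + rowsum(p)
# The final normalization (acc / l) and the log-sum-exp write-back happen in _attn_fwd.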
@triton.jit
def _attn_fwd(
Q, K, V, bias, sm_scale, M, Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
stride_bz, stride_bh, stride_bm, stride_bn,
Z, H,
N_CTX,
BLOCK_DMODEL: tl.constexpr,
STAGE: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
pre_load_v: tl.constexpr,
need_padding: tl.constexpr,
extra_tokens_n: tl.constexpr,
bias_type: tl.constexpr
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0),
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1),
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0),
)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
if bias is not None:
if bias_type == "vector":
bias_ptr = bias + ((off_hz % H) * stride_bh) + offs_n
elif bias_type == "matrix":
bias_ptr = tl.make_block_ptr(
base=bias + ((off_hz % H) * stride_bh),
shape=(N_CTX, N_CTX),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
order=(1, 0),
)
else:
bias_ptr = None
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
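    # For reference: e^x == 2^(x * log2(e)) and log2(e) ~= 1.44269504, so folding
    # log2(e) into sm_scale here (and into the bias inside _attn_fwd_inner) lets the
    # inner loop use exp2 while still computing the usual softmax.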
    # load q: it stays resident for the whole kernel (in SRAM on NVIDIA GPUs, in VGPRs on AMD GPUs)
q = tl.load(Q_block_ptr, boundary_check=(0,), padding_option="zero")
q = (q * qk_scale).to(q.dtype)
# stage 1: off-band
# For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
# For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
if STAGE & 1:
        # Causal masking combined with padding is not currently supported.
tl.static_assert((STAGE != 3) or not need_padding)
# equal to N_CTX if N_CTX is already a multiple of block_M
seqlen_aligned = N_CTX - extra_tokens_n
if N_CTX >= BLOCK_N:
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
4 - STAGE, offs_m, offs_n,
seqlen_aligned, pre_load_v,
False, seqlen_aligned,
bias_ptr
)
tl.debug_barrier()
if need_padding:
if N_CTX < BLOCK_N:
seqlen_aligned = 0
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
4 - STAGE, offs_m, offs_n,
seqlen_aligned, pre_load_v,
True, N_CTX,
bias_ptr
)
# stage 2: on-band
if STAGE & 2:
        # The barrier makes it easier for the compiler to schedule the
        # two loops independently.
tl.debug_barrier()
        acc, l_i, m_i = _attn_fwd_inner(
            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
            start_m,
            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
            2, offs_m, offs_n,
            N_CTX, pre_load_v,
            # padded_block / total_tokens / bias_ptr: padding is not supported on the
            # causal (on-band) path, so pass the unpadded defaults here.
            False, N_CTX,
            bias_ptr
        )
# epilogue
# write back m
acc = acc / l_i[:, None]
m_ptrs = M + off_hz * N_CTX + offs_m
    # Check whether this is the last, possibly partial block along M and, if so,
    # mask the store so we do not write M past N_CTX rows.
    overflow_size = ((start_m + 1) * BLOCK_M) - N_CTX
    if overflow_size > 0:
        boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
        # This is a > check because a mask of 0 blocks the store.
        m_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
tl.store(m_ptrs, m_i + tl.math.log2(l_i), mask=m_ptrs_mask)
else:
tl.store(m_ptrs, m_i + tl.math.log2(l_i))
tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
@triton.jit
def _attn_bwd_preprocess(O, DO, #
NewDO, Delta, #
BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, #
):
off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
off_n = tl.arange(0, D_HEAD)
# load
o = tl.load(O + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)
delta = tl.sum(o * do, axis=1)
# write-back
tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)
tl.store(Delta + off_m, delta)
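# For reference, the preprocess computes, per output row i,
#   Delta_i = sum_j O[i, j] * dO[i, j]
# which is the per-row correction term consumed (as D / Di) by the dk/dv and dq kernels below.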
@triton.jit
def _bwd_kernel_dk_dv(
Q, K, V, sm_scale, Out, DO,
DK, DV,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
    # Q is consumed starting at an offset that depends on the program ID:
    # each program starts BLOCK_M x D_HEAD past the previous one.
qvk_offset = off_hz * stride_qh
qdo_offset = qvk_offset + start_m * BLOCK_M * stride_qm
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, start_m * BLOCK_M),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qdo_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
qk_scale = sm_scale * 1.44269504
# load k and v: they will stay in SRAM throughout
k = tl.load(K_block_ptr)
k = (k * qk_scale).to(k.dtype)
v = tl.load(V_block_ptr)
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
    # The lower loop bound comes from the causal mask: the result is lower
    # triangular and the upper triangle is -inf (0 after exponentiation),
    # so it can be skipped in the GEMM.
lo = start_m * BLOCK_M
hi = N_CTX
# loop over q, do
for start_n in range(lo, hi, BLOCK_N):
offs_m_curr = offs_n[:, None] + start_n
# -- load q, do --
q = tl.load(Q_block_ptr)
do = tl.load(DO_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m_curr >= offs_m[None, :], qk, float("-inf"))
l_i = tl.load(l_ptrs + offs_m_curr)
p = tl.math.exp2(qk - l_i)
# -- compute dv ----
dv += tl.dot(tl.trans(p.to(do.dtype)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_N, BLOCK_M], dtype=tl.float32) - Di
dp += tl.dot(do, v)
        # ds = p * (dp - Di); note that Di was already folded into dp via its initializer above
        ds = p * dp
# compute dk
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# update pointers
Q_block_ptr = tl.advance(Q_block_ptr, (BLOCK_N, 0))
DO_block_ptr = tl.advance(DO_block_ptr, (BLOCK_N, 0))
# initialize pointers to output
DK_block_ptr = tl.make_block_ptr(
base=DK + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_kn, stride_kk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
DV_block_ptr = tl.make_block_ptr(
base=DV + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DK_block_ptr, (dk * sm_scale).to(DK.dtype.element_ty))
tl.store(DV_block_ptr, dv.to(tl.float16))
@triton.jit
def _bwd_kernel_dq(
Q, K, V, sm_scale, Out, DO,
DQ,
L,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
# Initialize pointers to Q, K, V
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
DO_block_ptr = tl.make_block_ptr(
base=DO + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
l_ptrs = L + off_hz * N_CTX
qk_scale = sm_scale * 1.44269504
# load q and do: they will stay in SRAM throughout
q = tl.load(Q_block_ptr)
q = (q * qk_scale).to(q.dtype)
do = tl.load(DO_block_ptr)
Di = tl.load(D_ptrs + offs_m)
l_i = tl.load(l_ptrs + offs_m)
dq = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# loop over k, v
lo = 0
hi = (start_m + 1) * BLOCK_M
for start_n in range(lo, hi, BLOCK_N):
# -- load k, v --
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ----
qk = tl.dot(q, k)
qk = tl.where(offs_m[:, None] >= (offs_n[None, :] + start_n), qk, float("-inf"))
p = tl.math.exp2(qk - l_i[:, None])
# compute dp = dot(v, do)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, v)
        # ds = p * (dp - Di); note that Di was already folded into dp via its initializer above
        ds = p * dp
        # compute dq. Unfortunately we cannot avoid a transpose here, as this loop
        # uses k in both its normal and transposed forms.
        dq += tl.dot(ds.to(Q.dtype.element_ty), tl.trans(k))
# update pointers
K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (0, BLOCK_N))
# initialize pointers to output
DQ_block_ptr = tl.make_block_ptr(
base=DQ + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
tl.store(DQ_block_ptr, (dq * sm_scale).to(tl.float16))
empty = torch.empty(128, device="cuda")
class _attention(torch.autograd.Function):
@staticmethod
def forward(ctx, q, k, v, causal, bias, sm_scale, split_kernel=False):
# shape constraints
batch, nheads, seqlen, Lq = q.shape
Lk, Lv = k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
# For now we assume K and V seqlen = Q seqlen
assert seqlen == k.shape[-2] and seqlen == v.shape[-2]
# We've derived these previously from tuning the kernel
BLOCK_M = 256
BLOCK_N = 128 if Lq == 128 else 64
waves_per_eu = 2 if Lq == 128 else 3
num_warps = 8 if Lq == 128 else 4
pre_load_v = False if Lq == 128 else True
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
stage = 3 if causal else 1
seqlen_k = k.shape[-2]
if seqlen_k < BLOCK_N:
need_padding = True
extra_tokens_n = BLOCK_N - seqlen_k
# This effectively means we cannot slice across Q.
assert(grid[0] == 1)
elif seqlen_k % BLOCK_N:
need_padding = True
extra_tokens_n = seqlen_k % BLOCK_N
else:
# We don't care if the M dim needs padding, as we
# always boundary_check on Q and O
need_padding = False
extra_tokens_n = 0
o = torch.empty_like(q, dtype=v.dtype)
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
if bias is not None:
bias, bias_type = prepare_bias(bias, batch, nheads, seqlen)
bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2), bias.stride(3))
else:
bias, bias_type, bias_strides = None, None, (0,0,0,0)
_attn_fwd[grid](
q, k, v, bias, sm_scale, M, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
*bias_strides,
q.shape[0], q.shape[1],
N_CTX=q.shape[2],
BLOCK_DMODEL=Lk,
STAGE=stage,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
need_padding=need_padding, extra_tokens_n=extra_tokens_n,
bias_type=bias_type,
num_stages=1, num_warps=num_warps
)
ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = Lk
ctx.causal = causal
ctx.split_kernel = split_kernel
return o
@staticmethod
def backward(ctx, do):
        # The split-kernel backward path only supports causal = True.
        assert not (ctx.split_kernel and not ctx.causal)
if torch.version.hip is not None:
BLOCK = 64
else:
BLOCK = 128
q, k, v, o, L = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
do = do.contiguous()
dq = torch.zeros_like(q)
dk = torch.empty_like(k)
dv = torch.empty_like(v)
BATCH, N_HEAD, N_CTX = q.shape[:3]
delta = torch.empty_like(L)
do_scaled = torch.empty_like(do)
        # Figure out what BLOCK size the fwd pass used and scale the number of
        # blocks accordingly. The bwd block size is smaller than the fwd one, so
        # without this scaling we would loop over too few blocks and skip some
        # values. Alternatively we could compute a new grid, but this keeps it
        # consistent with fwd and easier to reason about.
        block_scale = (q.shape[2] // ctx.grid[0]) // BLOCK
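        # Worked example (hypothetical sizes): with seqlen 4096 and a fwd BLOCK_M of 256,
        # ctx.grid[0] == 16, so q.shape[2] // ctx.grid[0] == 256 and, with the bwd
        # BLOCK of 64 used on ROCm, block_scale == 4.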
_attn_bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
o, do, #
do_scaled, delta, #
BLOCK_M=block_scale * BLOCK, D_HEAD=ctx.BLOCK_DMODEL, #
)
if not ctx.split_kernel:
_bwd_kernel[(ctx.grid[1],)](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq, dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
block_scale * ctx.grid[0],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
CAUSAL=ctx.causal,
num_stages=1,
)
        else:
dq = torch.zeros_like(q)
_bwd_kernel_dk_dv[(block_scale * ctx.grid[0], ctx.grid[1])](
q, k, v, ctx.sm_scale,
o, do_scaled,
dk, dv,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4,
num_stages=1,
)
_bwd_kernel_dq[ctx.grid](
q, k, v, ctx.sm_scale,
o, do_scaled,
dq,
L, delta,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
q.shape[0], q.shape[1], q.shape[2],
BLOCK_M=2*BLOCK, BLOCK_N=BLOCK,
BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=4, waves_per_eu=1,
num_stages=1,
)
# print(h.asm["ttgir"])
return dq, dk, dv, None, None, None
attention = _attention.apply
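# Minimal usage sketch (hypothetical shapes; assumes an fp16-capable GPU is available):
#   q = torch.randn(2, 16, 1024, 64, dtype=torch.float16, device="cuda")
#   k = torch.randn_like(q)
#   v = torch.randn_like(q)
#   sm_scale = q.shape[-1] ** -0.5
#   out = attention(q, k, v, False, None, sm_scale)  # causal=False, no bias
# See test_op_fwd below for bias and fp8 usage.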
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
[(4, 48, 63, 64),
(4, 48, 987, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
(4, 48, 3989, 64),
(4, 48, 1024, 128),
(4, 48, 1021, 128),
(4, 48, 2048, 128),
(4, 48, 4096, 128),
(4, 16, 8192, 64),
(4, 16, 8080, 64),
(1, 48, 16384, 64)
])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('use_bias', [False, True])
@pytest.mark.parametrize('bias_type', ["vector", "matrix"])
def test_op_fwd(Z, H, N_CTX, D_HEAD, causal, use_bias, bias_type, dtype=torch.float16):
torch.manual_seed(20)
if use_bias:
if bias_type == "vector":
bias = torch.randn((1, H, 1, N_CTX), dtype=torch.float32, device="cuda")
elif bias_type == "matrix":
bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda")
else:
bias = None
q = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.randn((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = D_HEAD ** -0.5
dout = torch.randn_like(q, dtype=torch.float16)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q.half(), k.transpose(2, 3).half()) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
if use_bias:
ref_bias, _ = prepare_bias(bias, Z, H, N_CTX)
p += ref_bias
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
# triton implementation
tri_out = attention(q, k, v, causal, bias, sm_scale)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2)
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
[(4, 48, 1024, 64),
(4, 48, 2048, 64),
(4, 48, 4096, 64),
(1, 16, 8192, 64),
])
def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
causal = True
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
sm_scale = 0.5
split_kernel = True
dout = torch.randn_like(q)
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if causal:
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = attention(q, k, v, causal, None, sm_scale, split_kernel)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
if torch.version.hip is None:
torch.testing.assert_close(ref_dv, tri_dv, atol=1e-2, rtol=0)
# The current block size for MI200 series is 64x64. This results in
# larger differences in float results due to rounding.
else:
torch.testing.assert_close(ref_dv, tri_dv, atol=5e-2, rtol=0)
torch.testing.assert_close(ref_dk, tri_dk, atol=5e-2, rtol=1e-2)
torch.testing.assert_close(ref_dq, tri_dq, atol=5e-2, rtol=1e-2)
try:
    import flash_attn
    from flash_attn.flash_attn_interface import \
        flash_attn_qkvpacked_func as flash_attn_func
    FLASH_VER = flash_attn.__version__
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False
    FLASH_VER = None
# vary seq length for fixed head and batch=4
configs = []
for mode in ['fwd']:
for D_HEAD in [128]:
if mode == 'bwd' and D_HEAD == 128:
continue
for causal in [False]:
            if mode == 'bwd' and not causal:
continue
for use_bias in [False, True]:
configs.append(triton.testing.Benchmark(
x_names=['BATCH', 'H','N_CTX'],
x_vals=[(16, 16, 1024),
(8, 16, 2048),
(4, 16, 4096),
(2, 16, 8192),
(1, 16, 16384),
(2, 48, 1024),
(2, 48, 2048),
(2, 48, 4096),
(2, 48, 8192),
(2, 48, 16384),
(8, 16, 1989),
(4, 16, 4097),
(2, 16, 8122),
(1, 16, 16281),
(2, 48, 1021),
(2, 48, 2001),
(2, 48, 3996),
(2, 48, 8181),
],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-{mode}-d{D_HEAD}-causal={causal}-bias={use_bias}',
args={
'D_HEAD': D_HEAD,
'dtype': torch.float16,
'mode': mode,
'causal': causal,
'use_bias' : use_bias})
)
@triton.testing.perf_report(configs)
def bench_flash_attention(
BATCH, H, N_CTX, D_HEAD, use_bias, causal, mode, provider, dtype=torch.float16, device="cuda"
):
assert mode in ["fwd", "bwd"]
warmup = 25
rep = 100
split_kernel = False
bias_type = "vector"
if use_bias:
if bias_type == "vector":
bias = torch.randn((1, H, 1, N_CTX), dtype=torch.float32, device="cuda")
elif bias_type == "matrix":
bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda")
else:
raise RuntimeError(
f"Got unsupported bias type: {bias_type}. Supported types are vector and matrix."
)
    else:
        bias = None
# Bwd pass only supports causal=True right now
if mode == 'bwd':
causal = True
split_kernel = True
if provider == "triton":
q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if mode == "fwd":
q = q.to(torch_dtype)
k = k.to(torch_dtype)
sm_scale = 1.3
fn = lambda: attention(q, k, v, causal, bias, sm_scale, split_kernel)
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "flash":
qkv = torch.randn(
(BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True
)
fn = lambda: flash_attn_func(qkv, causal=causal)
if mode == "bwd":
o = fn()
do = torch.randn_like(o)
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
total_flops = 2 * flops_per_matmul
if causal:
total_flops *= 0.5
if mode == "bwd":
total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute)
return total_flops / ms * 1e-9
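# Note on units: do_bench returns milliseconds, so total_flops / ms * 1e-9 equals
# total_flops / (ms * 1e-3) / 1e12, i.e. the reported value is TFLOP/s.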
# only works on post-Ampere GPUs right now
bench_flash_attention.run(save_path=".", print_data=True)