mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Add autotuning for FA (#459)
This commit is contained in:
@@ -150,6 +150,46 @@ def _attn_fwd_inner(
|
||||
bias_ptr += BLOCK_N
|
||||
return acc, l_i, m_i
|
||||
|
||||
def pre_hook(nargs):
|
||||
BLOCK_N = nargs['BLOCK_N']
|
||||
seqlen_k = nargs['N_CTX_K']
|
||||
# This is the default. Below we check if N_CTX_K needs padding.
|
||||
# We don't care if N_CTX_Q needs padding, as we
|
||||
# always boundary_check on Q and O so even if it is the last M block
|
||||
# that needs padding, we would just fill the tile with 0s beyond the boundary.
|
||||
need_padding = False
|
||||
extra_tokens_n = 0
|
||||
if seqlen_k < BLOCK_N:
|
||||
need_padding = True
|
||||
extra_tokens_n = BLOCK_N - seqlen_k
|
||||
# This effectively means we cannot slice across Q.
|
||||
elif seqlen_k % BLOCK_N:
|
||||
need_padding = True
|
||||
extra_tokens_n = seqlen_k % BLOCK_N
|
||||
nargs['need_padding'] = need_padding
|
||||
nargs['extra_tokens_n'] = extra_tokens_n
|
||||
|
||||
@triton.autotune(
|
||||
configs=[
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=8,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=8,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True}, num_stages=1, num_warps=4,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 3, 'pre_load_v': False}, num_stages=1, num_warps=4,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False}, num_stages=1, num_warps=4,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'pre_load_v': False}, num_stages=1, num_warps=8,
|
||||
pre_hook=pre_hook),
|
||||
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True}, num_stages=1, num_warps=4,
|
||||
pre_hook=pre_hook),
|
||||
],
|
||||
key=['Z', 'H', 'N_CTX_Q', 'N_CTX_K', 'STAGE', 'BLOCK_DMODEL'],
|
||||
)
|
||||
|
||||
@triton.jit
|
||||
def _attn_fwd(
|
||||
@@ -160,7 +200,8 @@ def _attn_fwd(
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
stride_bz, stride_bh, stride_bm, stride_bn,
|
||||
Z, H,
|
||||
N_CTX,
|
||||
N_CTX_Q,
|
||||
N_CTX_K,
|
||||
BLOCK_DMODEL: tl.constexpr,
|
||||
STAGE: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
@@ -177,7 +218,7 @@ def _attn_fwd(
|
||||
# block pointers
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
shape=(N_CTX_Q, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
@@ -185,7 +226,7 @@ def _attn_fwd(
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
shape=(N_CTX_K, BLOCK_DMODEL),
|
||||
strides=(stride_vk, stride_vn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
@@ -193,7 +234,7 @@ def _attn_fwd(
|
||||
)
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
shape=(BLOCK_DMODEL, N_CTX_K),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
@@ -201,7 +242,7 @@ def _attn_fwd(
|
||||
)
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
shape=(N_CTX_Q, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
@@ -216,7 +257,7 @@ def _attn_fwd(
|
||||
elif bias_type == "matrix":
|
||||
bias_ptr = tl.make_block_ptr(
|
||||
base=bias + ((off_hz % H) * stride_bh),
|
||||
shape=(N_CTX, N_CTX),
|
||||
shape=(N_CTX_K, N_CTX_K),
|
||||
strides=(stride_bm, stride_bn),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_N),
|
||||
@@ -241,9 +282,9 @@ def _attn_fwd(
|
||||
if STAGE & 1:
|
||||
# We don't currently support causal masking and padding.
|
||||
tl.static_assert((STAGE != 3) or not need_padding)
|
||||
# equal to N_CTX if N_CTX is already a multiple of block_M
|
||||
seqlen_aligned = N_CTX - extra_tokens_n
|
||||
if N_CTX >= BLOCK_N:
|
||||
# equal to N_CTX_K if N_CTX_K is already a multiple of block_N
|
||||
seqlen_aligned = N_CTX_K - extra_tokens_n
|
||||
if N_CTX_K >= BLOCK_N:
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
start_m,
|
||||
@@ -255,7 +296,7 @@ def _attn_fwd(
|
||||
)
|
||||
tl.debug_barrier()
|
||||
if need_padding:
|
||||
if N_CTX < BLOCK_N:
|
||||
if N_CTX_K < BLOCK_N:
|
||||
seqlen_aligned = 0
|
||||
acc, l_i, m_i = _attn_fwd_inner(
|
||||
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
|
||||
@@ -263,7 +304,7 @@ def _attn_fwd(
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
||||
4 - STAGE, offs_m, offs_n,
|
||||
seqlen_aligned, pre_load_v,
|
||||
True, N_CTX,
|
||||
True, N_CTX_K,
|
||||
bias_ptr
|
||||
)
|
||||
# stage 2: on-band
|
||||
@@ -276,14 +317,14 @@ def _attn_fwd(
|
||||
start_m,
|
||||
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
|
||||
2, offs_m, offs_n,
|
||||
N_CTX, pre_load_v,
|
||||
N_CTX_K, pre_load_v,
|
||||
)
|
||||
# epilogue
|
||||
# write back m
|
||||
acc = acc / l_i[:, None]
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX_Q + offs_m
|
||||
# Check for last block_M
|
||||
overflow_size = (start_m * BLOCK_M) - N_CTX
|
||||
overflow_size = (start_m * BLOCK_M) - N_CTX_Q
|
||||
if overflow_size > 0:
|
||||
boundary = tl.full((BLOCK_M,), overflow_size, dtype=tl.float32)
|
||||
# This is a > check because mask being 0 blocks the store.
|
||||
@@ -536,32 +577,15 @@ class _attention(torch.autograd.Function):
|
||||
# For now we assume K and V seqlen = Q seqlen
|
||||
assert seqlen == k.shape[-2] and seqlen == v.shape[-2]
|
||||
|
||||
# We've derived these previously from tuning the kernel
|
||||
BLOCK_M = 256
|
||||
BLOCK_N = 128 if Lq == 128 else 64
|
||||
waves_per_eu = 2 if Lq == 128 else 3
|
||||
num_warps = 8 if Lq == 128 else 4
|
||||
pre_load_v = False if Lq == 128 else True
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
|
||||
grid = lambda META: (
|
||||
triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[0] * q.shape[1], 1
|
||||
)
|
||||
|
||||
stage = 3 if causal else 1
|
||||
seqlen_k = k.shape[-2]
|
||||
if seqlen_k < BLOCK_N:
|
||||
need_padding = True
|
||||
extra_tokens_n = BLOCK_N - seqlen_k
|
||||
# This effectively means we cannot slice across Q.
|
||||
assert(grid[0] == 1)
|
||||
elif seqlen_k % BLOCK_N:
|
||||
need_padding = True
|
||||
extra_tokens_n = seqlen_k % BLOCK_N
|
||||
else:
|
||||
# We don't care if the M dim needs padding, as we
|
||||
# always boundary_check on Q and O
|
||||
need_padding = False
|
||||
extra_tokens_n = 0
|
||||
|
||||
o = torch.empty_like(q, dtype=v.dtype)
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
|
||||
device=q.device, dtype=torch.float32)
|
||||
|
||||
if bias is not None:
|
||||
bias, bias_type = prepare_bias(bias, batch, nheads, seqlen)
|
||||
@@ -577,14 +601,12 @@ class _attention(torch.autograd.Function):
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
*bias_strides,
|
||||
q.shape[0], q.shape[1],
|
||||
N_CTX=q.shape[2],
|
||||
N_CTX_Q=q.shape[-2],
|
||||
N_CTX_K=k.shape[-2],
|
||||
BLOCK_DMODEL=Lk,
|
||||
STAGE=stage,
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
|
||||
waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
|
||||
need_padding=need_padding, extra_tokens_n=extra_tokens_n,
|
||||
need_padding=False, extra_tokens_n=0,
|
||||
bias_type=bias_type,
|
||||
num_stages=1, num_warps=num_warps
|
||||
)
|
||||
|
||||
ctx.save_for_backward(q, k, v, o, M)
|
||||
|
||||
Reference in New Issue
Block a user