Add autotuning for FA (#459)

Commit 1fec965c06 (parent 2e217c5a5c)
Author: Vinayak Gokhale
Date: 2024-01-12 17:15:12 -06:00
Committed by: GitHub

@@ -150,6 +150,46 @@ def _attn_fwd_inner(
bias_ptr += BLOCK_N
return acc, l_i, m_i
+def pre_hook(nargs):
+    BLOCK_N = nargs['BLOCK_N']
+    seqlen_k = nargs['N_CTX_K']
+    # These are the defaults; below we check whether N_CTX_K needs padding.
+    # We don't care whether N_CTX_Q needs padding: we always boundary_check
+    # on Q and O, so even if the last M block is partial, the tile is simply
+    # filled with 0s beyond the boundary.
+    need_padding = False
+    extra_tokens_n = 0
+    if seqlen_k < BLOCK_N:
+        # This effectively means we cannot slice across Q.
+        need_padding = True
+        extra_tokens_n = BLOCK_N - seqlen_k
+    elif seqlen_k % BLOCK_N:
+        need_padding = True
+        extra_tokens_n = seqlen_k % BLOCK_N
+    nargs['need_padding'] = need_padding
+    nargs['extra_tokens_n'] = extra_tokens_n
+@triton.autotune(
+    configs=[
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False},
+                      num_stages=1, num_warps=8, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': False},
+                      num_stages=1, num_warps=4, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True},
+                      num_stages=1, num_warps=8, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 1, 'pre_load_v': True},
+                      num_stages=1, num_warps=4, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 3, 'pre_load_v': False},
+                      num_stages=1, num_warps=4, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'pre_load_v': False},
+                      num_stages=1, num_warps=4, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 1, 'pre_load_v': False},
+                      num_stages=1, num_warps=8, pre_hook=pre_hook),
+        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'pre_load_v': True},
+                      num_stages=1, num_warps=4, pre_hook=pre_hook),
+    ],
+    key=['Z', 'H', 'N_CTX_Q', 'N_CTX_K', 'STAGE', 'BLOCK_DMODEL'],
+)
@triton.jit
def _attn_fwd(
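
For context: each triton.Config above registers pre_hook, which the autotuner runs on the kernel's argument dict before compiling a candidate, so need_padding and extra_tokens_n are recomputed for that config's BLOCK_N. The key list means tuning is redone whenever Z, H, N_CTX_Q, N_CTX_K, STAGE, or BLOCK_DMODEL changes. A minimal pure-Python sketch of the same padding arithmetic (the helper name and example shapes are ours, not from the commit):

def compute_padding(seqlen_k: int, BLOCK_N: int):
    # Mirrors pre_hook: decide whether the K sequence needs a padded pass.
    if seqlen_k < BLOCK_N:
        # K fits inside a single block; pad it up to BLOCK_N.
        return True, BLOCK_N - seqlen_k
    if seqlen_k % BLOCK_N:
        # Ragged tail: extra_tokens_n is the length of the unaligned remainder.
        return True, seqlen_k % BLOCK_N
    return False, 0

assert compute_padding(48, 64) == (True, 16)   # shorter than one block
assert compute_padding(200, 64) == (True, 8)   # 3 full blocks + an 8-token tail
assert compute_padding(256, 64) == (False, 0)  # already BLOCK_N-aligned
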
@@ -160,7 +200,8 @@ def _attn_fwd(
stride_oz, stride_oh, stride_om, stride_on,
stride_bz, stride_bh, stride_bm, stride_bn,
Z, H,
-    N_CTX,
+    N_CTX_Q,
+    N_CTX_K,
BLOCK_DMODEL: tl.constexpr,
STAGE: tl.constexpr,
BLOCK_M: tl.constexpr,
@@ -177,7 +218,7 @@ def _attn_fwd(
# block pointers
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
-        shape=(N_CTX_Q, BLOCK_DMODEL),
+        shape=(N_CTX_Q, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -185,7 +226,7 @@ def _attn_fwd(
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(N_CTX_K, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
@@ -193,7 +234,7 @@ def _attn_fwd(
)
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
-        shape=(BLOCK_DMODEL, N_CTX),
+        shape=(BLOCK_DMODEL, N_CTX_K),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
@@ -201,7 +242,7 @@ def _attn_fwd(
)
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(N_CTX_Q, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
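
The Q and O block pointers are why only the K length feeds the padding logic: both are accessed with boundary_check, so a partial last M block is handled for free. A standalone sketch of that pattern (illustrative, not part of the commit; the kernel and its arguments are ours):

import triton
import triton.language as tl

@triton.jit
def _copy_rows(X, Y, n_rows, stride_x, stride_y,
               BLOCK_M: tl.constexpr, D: tl.constexpr):
    pid = tl.program_id(0)
    x_ptr = tl.make_block_ptr(base=X, shape=(n_rows, D), strides=(stride_x, 1),
                              offsets=(pid * BLOCK_M, 0),
                              block_shape=(BLOCK_M, D), order=(1, 0))
    y_ptr = tl.make_block_ptr(base=Y, shape=(n_rows, D), strides=(stride_y, 1),
                              offsets=(pid * BLOCK_M, 0),
                              block_shape=(BLOCK_M, D), order=(1, 0))
    # Rows past n_rows read back as zeros instead of faulting...
    x = tl.load(x_ptr, boundary_check=(0,), padding_option="zero")
    # ...and the store masks off the same out-of-bounds rows.
    tl.store(y_ptr, x, boundary_check=(0,))
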
@@ -216,7 +257,7 @@ def _attn_fwd(
elif bias_type == "matrix":
bias_ptr = tl.make_block_ptr(
base=bias + ((off_hz % H) * stride_bh),
-            shape=(N_CTX, N_CTX),
+            shape=(N_CTX_Q, N_CTX_K),
strides=(stride_bm, stride_bn),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_N),
@@ -241,9 +282,9 @@ def _attn_fwd(
if STAGE & 1:
# We don't currently support causal masking and padding.
tl.static_assert((STAGE != 3) or not need_padding)
-        # equal to N_CTX if N_CTX is already a multiple of block_M
-        seqlen_aligned = N_CTX - extra_tokens_n
-        if N_CTX >= BLOCK_N:
+        # Equal to N_CTX_K if N_CTX_K is already a multiple of BLOCK_N.
+        seqlen_aligned = N_CTX_K - extra_tokens_n
+        if N_CTX_K >= BLOCK_N:
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
start_m,
@@ -255,7 +296,7 @@ def _attn_fwd(
)
tl.debug_barrier()
if need_padding:
-            if N_CTX < BLOCK_N:
+            if N_CTX_K < BLOCK_N:
seqlen_aligned = 0
acc, l_i, m_i = _attn_fwd_inner(
acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
@@ -263,7 +304,7 @@ def _attn_fwd(
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
4 - STAGE, offs_m, offs_n,
seqlen_aligned, pre_load_v,
-                True, N_CTX,
+                True, N_CTX_K,
bias_ptr
)
# stage 2: on-band
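
So the K dimension is swept in up to two passes: _attn_fwd_inner first covers the BLOCK_N-aligned prefix without masking, then a second masked call handles the ragged tail (or the whole sequence when N_CTX_K < BLOCK_N). A plain-Python sketch of the tiling this produces (the helper and tile labels are ours, not from the commit):

def k_tiles(seqlen_k, BLOCK_N):
    # Replicates the seqlen_aligned / extra_tokens_n bookkeeping above.
    extra = (BLOCK_N - seqlen_k) if seqlen_k < BLOCK_N else seqlen_k % BLOCK_N
    aligned = 0 if seqlen_k < BLOCK_N else seqlen_k - extra
    tiles = [('full', s, s + BLOCK_N) for s in range(0, aligned, BLOCK_N)]
    if extra:
        tiles.append(('masked', aligned, seqlen_k))
    return tiles

assert k_tiles(256, 64) == [('full', 0, 64), ('full', 64, 128),
                            ('full', 128, 192), ('full', 192, 256)]
assert k_tiles(200, 64)[-1] == ('masked', 192, 200)  # 8-token tail
assert k_tiles(48, 64) == [('masked', 0, 48)]        # single padded block
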
@@ -276,14 +317,14 @@ def _attn_fwd(
start_m,
BLOCK_M, BLOCK_DMODEL, BLOCK_N,
2, offs_m, offs_n,
-            N_CTX, pre_load_v,
+            N_CTX_K, pre_load_v,
)
# epilogue
# write back m
acc = acc / l_i[:, None]
-    m_ptrs = M + off_hz * N_CTX + offs_m
+    m_ptrs = M + off_hz * N_CTX_Q + offs_m
# Check for last block_M
-    overflow_size = (start_m * BLOCK_M) - N_CTX
+    # How far this M block's last row extends past the end of Q.
+    overflow_size = ((start_m + 1) * BLOCK_M) - N_CTX_Q
if overflow_size > 0:
boundary = tl.full((BLOCK_M,), overflow_size, dtype=tl.float32)
# This is a > check because mask being 0 blocks the store.
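
To make the epilogue masking concrete, a worked example with hypothetical sizes: with N_CTX_Q = 200 and BLOCK_M = 128, the block at start_m = 1 covers rows 128..255, so 56 rows fall past the end of Q and must be masked out of the store.

BLOCK_M, N_CTX_Q, start_m = 128, 200, 1       # block 1 covers rows 128..255
overflow_size = ((start_m + 1) * BLOCK_M) - N_CTX_Q
valid_rows = [start_m * BLOCK_M + i < N_CTX_Q for i in range(BLOCK_M)]
assert overflow_size == 56                    # 56 rows past the end of Q
assert sum(valid_rows) == BLOCK_M - overflow_size  # only 72 rows are stored
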
@@ -536,32 +577,15 @@ class _attention(torch.autograd.Function):
# For now we assume K and V seqlen = Q seqlen
assert seqlen == k.shape[-2] and seqlen == v.shape[-2]
-        # We've derived these previously from tuning the kernel
-        BLOCK_M = 256
-        BLOCK_N = 128 if Lq == 128 else 64
-        waves_per_eu = 2 if Lq == 128 else 3
-        num_warps = 8 if Lq == 128 else 4
-        pre_load_v = False if Lq == 128 else True
-        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
+        grid = lambda META: (
+            triton.cdiv(q.shape[2], META['BLOCK_M']), q.shape[0] * q.shape[1], 1
+        )
stage = 3 if causal else 1
-        seqlen_k = k.shape[-2]
-        if seqlen_k < BLOCK_N:
-            need_padding = True
-            extra_tokens_n = BLOCK_N - seqlen_k
-            # This effectively means we cannot slice across Q.
-            assert(grid[0] == 1)
-        elif seqlen_k % BLOCK_N:
-            need_padding = True
-            extra_tokens_n = seqlen_k % BLOCK_N
-        else:
-            # We don't care if the M dim needs padding, as we
-            # always boundary_check on Q and O
-            need_padding = False
-            extra_tokens_n = 0
o = torch.empty_like(q, dtype=v.dtype)
-        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]),
+                        device=q.device, dtype=torch.float32)
if bias is not None:
bias, bias_type = prepare_bias(bias, batch, nheads, seqlen)
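
Host-side, the grid must now be a callable because BLOCK_M is no longer fixed: Triton evaluates the lambda against the meta-parameters of whichever config wins tuning. A small sketch of the resulting M-block count (the sizes are hypothetical):

def cdiv(a, b):
    # Same ceiling division as triton.cdiv.
    return (a + b - 1) // b

# e.g. 200 queries: 1 M block if the tuner picks BLOCK_M=256, 2 if it picks 128.
assert cdiv(200, 256) == 1
assert cdiv(200, 128) == 2
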
@@ -577,14 +601,12 @@ class _attention(torch.autograd.Function):
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
*bias_strides,
q.shape[0], q.shape[1],
-            N_CTX=q.shape[2],
+            N_CTX_Q=q.shape[-2],
+            N_CTX_K=k.shape[-2],
BLOCK_DMODEL=Lk,
STAGE=stage,
-            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
-            waves_per_eu=waves_per_eu, pre_load_v=pre_load_v,
-            need_padding=need_padding, extra_tokens_n=extra_tokens_n,
+            need_padding=False, extra_tokens_n=0,
bias_type=bias_type,
-            num_stages=1, num_warps=num_warps
)
ctx.save_for_backward(q, k, v, o, M)
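
Note that need_padding=False and extra_tokens_n=0 at the call site are only placeholders: every triton.Config above carries pre_hook, which the autotuner runs on the argument dict before compiling, overwriting both values for that config's BLOCK_N. A sketch of that flow (the launch helper is a hypothetical stand-in for Triton's internals, not its real API):

def launch(nargs):
    # Hypothetical stand-in for the autotuner: run the config's pre_hook
    # on the arguments right before the kernel is compiled and launched.
    nargs = dict(nargs)
    pre_hook(nargs)
    return nargs['need_padding'], nargs['extra_tokens_n']

# The call-site placeholders are replaced according to the real K length:
assert launch({'need_padding': False, 'extra_tokens_n': 0,
               'N_CTX_K': 200, 'BLOCK_N': 64}) == (True, 8)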