Bugfix: Wrong boundary condition on qk GEMM
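
The padded-block mask in _attn_fwd_inner compared key positions against
total_tokens with <=, which keeps one out-of-range key position per row:
valid positions run from 0 through total_tokens - 1, so the comparison
must be strict (<).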

Vinayak Gokhale
2023-11-28 16:49:00 +00:00
committed by Vinayak Gokhale
parent f6969f4bb3
commit 1d6b919897

@@ -81,7 +81,7 @@ def _attn_fwd_inner(
             qk = tl.where(mask, qk, float("-inf"))
         if padded_block:
             boundary = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
-            mask = (start_n + offs_n[None,:]) <= boundary[:,None]
+            mask = (start_n + offs_n[None,:]) < boundary[:,None]
             qk = tl.where(mask, qk, float("-inf"))
         qk += tl.dot(q, k)
         m_ij = tl.maximum(m_i, tl.max(qk, 1))
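
To make the off-by-one concrete, here is a minimal NumPy sketch (outside
Triton, with made-up sizes; BLOCK_N in the kernel is a tuned constant) of
the tail-block mask before and after the fix:

    import numpy as np

    BLOCK_N = 8
    total_tokens = 19            # real key tokens in the sequence
    start_n = 16                 # offset of the current (padded) K/V block
    offs_n = np.arange(BLOCK_N)  # lane offsets within the block

    positions = start_n + offs_n          # absolute key positions 16..23
    keep_old = positions <= total_tokens  # old mask: also keeps position 19
    keep_new = positions < total_tokens   # fixed mask: strict comparison

    print(positions[keep_old])  # [16 17 18 19] -> 19 is past the last real token
    print(positions[keep_new])  # [16 17 18]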
@@ -182,16 +182,19 @@ def _attn_fwd(
     tl.static_assert((STAGE != 3) or not need_padding)
     # equal to N_CTX if N_CTX is already a multiple of block_M
     seqlen_aligned = N_CTX - extra_tokens_n
-    acc, l_i, m_i = _attn_fwd_inner(
-        acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
-        start_m,
-        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
-        4 - STAGE, offs_m, offs_n,
-        seqlen_aligned, pre_load_v,
-        False, seqlen_aligned
-    )
+    if N_CTX >= BLOCK_N:
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
+            start_m,
+            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
+            4 - STAGE, offs_m, offs_n,
+            seqlen_aligned, pre_load_v,
+            False, seqlen_aligned
+        )
     tl.debug_barrier()
     if need_padding:
+        if N_CTX < BLOCK_N:
+            seqlen_aligned = 0
         acc, l_i, m_i = _attn_fwd_inner(
             acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
             start_m,
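
The new N_CTX >= BLOCK_N guard matters because, for a sequence shorter than
one block, there are no full BLOCK_N tiles for the first call to walk, and
seqlen_aligned = N_CTX - extra_tokens_n is not a meaningful token count in
that case; zeroing it routes the entire sequence through the padded block.
A rough sketch (not the kernel itself) of which inner calls run:

    # Mirrors the control flow above; names are illustrative.
    def inner_calls(n_ctx: int, block_n: int) -> list:
        calls = []
        if n_ctx >= block_n:
            calls.append("aligned loop over full BLOCK_N tiles")
        if n_ctx % block_n != 0:  # need_padding, including n_ctx < block_n
            calls.append("padded tail block (masked)")
        return calls

    print(inner_calls(1024, 64))  # ['aligned loop over full BLOCK_N tiles']
    print(inner_calls(997, 64))   # 960 aligned tokens, then a 37-token masked tail
    print(inner_calls(19, 64))    # ['padded tail block (masked)'] only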
@@ -476,15 +479,28 @@ class _attention(torch.autograd.Function):
         waves_per_eu = 2 if Lq == 128 else 3
         num_warps = 8 if Lq == 128 else 4
         pre_load_v = False if Lq == 128 else True
-        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         stage = 3 if causal else 1
-        need_padding = True if seqlen % BLOCK_M else False
-        extra_tokens_n = seqlen % BLOCK_N
+        # Compute if we need padding and how much
+        seqlen_k = k.shape[-2]
+        if seqlen_k < BLOCK_N:
+            need_padding = True
+            extra_tokens_n = BLOCK_N - seqlen_k
+            # This effectively ensures we do not slice across Q.
+            assert(grid[0] == 1)
+        elif seqlen_k % BLOCK_N:
+            need_padding = True
+            extra_tokens_n = seqlen_k % BLOCK_N
+        else:
+            # We don't care if BLOCK_M isn't aligned, as we
+            # always boundary_check on Q and O
+            need_padding = False
+            extra_tokens_n = 0
         o = torch.empty_like(q, dtype=v.dtype)
         M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         _attn_fwd[grid](
             q, k, v, sm_scale, M, o,
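
Note the asymmetry in extra_tokens_n: when seqlen_k < BLOCK_N it holds the
number of padding lanes in the single block (BLOCK_N - seqlen_k), while for
a longer misaligned sequence it holds the number of real tokens in the tail
block (seqlen_k % BLOCK_N); this is why the kernel must reset seqlen_aligned
to 0 in the short-sequence case. A standalone sketch, with BLOCK_N fixed at
64 for illustration:

    def padding_info(seqlen_k: int, block_n: int = 64):
        if seqlen_k < block_n:
            return True, block_n - seqlen_k   # padding lanes in the one block
        if seqlen_k % block_n:
            return True, seqlen_k % block_n   # real tokens in the tail block
        return False, 0

    print(padding_info(19))    # (True, 45)
    print(padding_info(997))   # (True, 37)
    print(padding_info(1024))  # (False, 0)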
@@ -591,7 +607,8 @@ attention = _attention.apply
 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
-                         [(4, 48, 1024, 64),
+                         [(1, 40, 19, 128),
+                          (4, 48, 1024, 64),
                           (4, 48, 997, 64),
                           (4, 48, 2048, 64),
                           (4, 48, 4096, 64),
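
The new (1, 40, 19, 128) shape is the interesting case: with N_CTX = 19 below
BLOCK_N, the forward pass runs only the padded block, and the strict
comparison is what keeps the phantom key position at index total_tokens out
of the softmax. A minimal check against a PyTorch reference might look like
this (the attention(q, k, v, causal, sm_scale) signature is assumed from the
tutorial-style wrapper; tolerances are illustrative):

    import torch

    def reference(q, k, v, sm_scale):
        p = torch.softmax((q @ k.transpose(-1, -2)).float() * sm_scale, dim=-1)
        return (p @ v.float()).to(v.dtype)

    q = torch.randn(1, 40, 19, 128, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out = attention(q, k, v, False, 0.5)  # causal=False, sm_scale=0.5
    torch.testing.assert_close(out, reference(q, k, v, 0.5), atol=2e-2, rtol=0)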
@@ -771,4 +788,3 @@ def bench_flash_attention(
     # only works on post-Ampere GPUs right now
     bench_flash_attention.run(save_path=".", print_data=True)