Bugfix: Wrong boundary condition on qk GEMM
Committed by: Vinayak Gokhale
Parent: f6969f4bb3
Commit: 1d6b919897
@@ -81,7 +81,7 @@ def _attn_fwd_inner(
         qk = tl.where(mask, qk, float("-inf"))
         if padded_block:
             boundary = tl.full([BLOCK_M], total_tokens, dtype=tl.float32)
-            mask = (start_n + offs_n[None,:]) <= boundary[:,None]
+            mask = (start_n + offs_n[None,:]) < boundary[:,None]
             qk = tl.where(mask, qk, float("-inf"))
         qk += tl.dot(q, k)
         m_ij = tl.maximum(m_i, tl.max(qk, 1))
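Why `<` and not `<=`: total_tokens is a count, so the valid key indices in a padded block run from 0 through total_tokens - 1. With `<=`, the column at index total_tokens (the first padding column) was left unmasked and leaked into the qk GEMM and the softmax that follows. A minimal NumPy sketch of the corrected mask; the names mirror the kernel, but the concrete numbers are illustrative:

    import numpy as np

    # One block of 4 query rows x 4 key columns starting at key offset 16,
    # with only 18 valid tokens in the sequence.
    BLOCK_M, BLOCK_N = 4, 4
    start_n, total_tokens = 16, 18
    offs_n = np.arange(BLOCK_N)
    qk = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)

    boundary = np.full(BLOCK_M, total_tokens, dtype=np.float32)
    mask_old = (start_n + offs_n[None, :]) <= boundary[:, None]  # keeps index 18: the bug
    mask_new = (start_n + offs_n[None, :]) < boundary[:, None]   # keeps 16 and 17 only

    print(np.where(mask_old, qk, -np.inf))  # padding column 18 wrongly survives
    print(np.where(mask_new, qk, -np.inf))  # columns 18 and 19 masked to -inf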
@@ -182,16 +182,19 @@ def _attn_fwd(
     tl.static_assert((STAGE != 3) or not need_padding)
     # equal to N_CTX if N_CTX is already a multiple of block_M
     seqlen_aligned = N_CTX - extra_tokens_n
-    acc, l_i, m_i = _attn_fwd_inner(
-        acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
-        start_m,
-        BLOCK_M, BLOCK_DMODEL, BLOCK_N,
-        4 - STAGE, offs_m, offs_n,
-        seqlen_aligned, pre_load_v,
-        False, seqlen_aligned
-    )
+    if N_CTX >= BLOCK_N:
+        acc, l_i, m_i = _attn_fwd_inner(
+            acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
+            start_m,
+            BLOCK_M, BLOCK_DMODEL, BLOCK_N,
+            4 - STAGE, offs_m, offs_n,
+            seqlen_aligned, pre_load_v,
+            False, seqlen_aligned
+        )
     tl.debug_barrier()
     if need_padding:
+        if N_CTX < BLOCK_N:
+            seqlen_aligned = 0
         acc, l_i, m_i = _attn_fwd_inner(
             acc, l_i, m_i, q, K_block_ptr, V_block_ptr,
             start_m,
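The first _attn_fwd_inner call sweeps the fully populated K/V blocks and is now guarded: when N_CTX < BLOCK_N there are no such blocks, and seqlen_aligned = N_CTX - extra_tokens_n can even be negative (the host now sets extra_tokens_n = BLOCK_N - seqlen_k in that case, see the next hunk), so it is reset to 0 before the padded pass. A plain-Python sketch of the two-pass structure, on the assumption that it captures the control flow; process_block is a hypothetical stand-in for _attn_fwd_inner, not the kernel's API:

    # Two-pass split over the key/value sequence (plain Python, not Triton).
    def attend_blocks(N_CTX, BLOCK_N, extra_tokens_n, need_padding, process_block):
        seqlen_aligned = N_CTX - extra_tokens_n
        if N_CTX >= BLOCK_N:
            # Pass 1: full blocks only; no masking inside this loop.
            for start_n in range(0, seqlen_aligned, BLOCK_N):
                process_block(start_n, padded_block=False)
        if need_padding:
            if N_CTX < BLOCK_N:
                seqlen_aligned = 0  # the lone partial block starts at offset 0
            # Pass 2: one trailing block with out-of-range columns masked.
            process_block(seqlen_aligned, padded_block=True)

    # E.g. N_CTX=997, BLOCK_N=64: pass 1 covers tokens 0-959,
    # pass 2 handles the padded block starting at 960.
    attend_blocks(997, 64, 37, True, lambda s, padded_block: print(s, padded_block))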
@@ -476,15 +479,28 @@ class _attention(torch.autograd.Function):
         waves_per_eu = 2 if Lq == 128 else 3
         num_warps = 8 if Lq == 128 else 4
         pre_load_v = False if Lq == 128 else True

+        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         stage = 3 if causal else 1
-        need_padding = True if seqlen % BLOCK_M else False
-        extra_tokens_n = seqlen % BLOCK_N
+
+        # Compute if we need padding and how much
+        seqlen_k = k.shape[-2]
+        if seqlen_k < BLOCK_N:
+            need_padding = True
+            extra_tokens_n = BLOCK_N - seqlen_k
+            # This effectively ensures we do not slice across Q.
+            assert(grid[0] == 1)
+        elif seqlen_k % BLOCK_N:
+            need_padding = True
+            extra_tokens_n = seqlen_k % BLOCK_N
+        else:
+            # We don't care if BLOCK_M isn't aligned, as we
+            # always boundary_check on Q and O
+            need_padding = False
+            extra_tokens_n = 0

         o = torch.empty_like(q, dtype=v.dtype)
         M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
-
-        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)

         _attn_fwd[grid](
             q, k, v, sm_scale, M, o,
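The launcher previously set need_padding from `seqlen % BLOCK_M` but sized the padding from `seqlen % BLOCK_N`; it now derives both from the actual key length k.shape[-2], with a dedicated case for sequences shorter than one K/V block, and computes grid before that block so the assert can inspect it. A standalone sketch of the decision; the BLOCK_N value of 64 is an assumption for illustration:

    # Mirrors the new host-side padding decision; BLOCK_N = 64 is an assumption.
    def compute_padding(seqlen_k: int, BLOCK_N: int = 64):
        if seqlen_k < BLOCK_N:
            # Sub-block sequence: pad the single K/V block up to BLOCK_N.
            return True, BLOCK_N - seqlen_k
        elif seqlen_k % BLOCK_N:
            # Ragged tail: only the last block is partial.
            return True, seqlen_k % BLOCK_N
        else:
            # Perfectly aligned: no masking needed along N.
            return False, 0

    assert compute_padding(19) == (True, 45)    # 19 < 64
    assert compute_padding(997) == (True, 37)   # 997 % 64 == 37
    assert compute_padding(1024) == (False, 0)  # exact multiple of 64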
@@ -591,7 +607,8 @@ attention = _attention.apply


 @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD',
-                         [(4, 48, 1024, 64),
+                         [(1, 40, 19, 128),
+                          (4, 48, 1024, 64),
                           (4, 48, 997, 64),
                           (4, 48, 2048, 64),
                           (4, 48, 4096, 64),
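The new shape (Z=1, H=40, N_CTX=19, D_HEAD=128) exercises the seqlen_k < BLOCK_N path added above: 19 tokens fit in a single padded block, and since 19 is at most one BLOCK_M tile, grid[0] == 1 and the new assert holds. The pre-existing (4, 48, 997, 64) case continues to cover the ragged-tail path. To run only the new case (the test file name here is assumed, not taken from the diff):

    pytest flash-attention.py -k "1-40-19-128"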
@@ -771,4 +788,3 @@ def bench_flash_attention(

-# only works on post-Ampere GPUs right now
 bench_flash_attention.run(save_path=".", print_data=True)
