[TUTORIAL] flash attention d128 improvement (#2074)

`ptxas` is able to automatically generate a call instruction to "call"
the loop body so that instructions are better scheduled.
This commit is contained in:
Keren Zhou
2023-08-11 20:31:48 -04:00
committed by GitHub
parent c309f7e57a
commit 5162871c6c

View File

@@ -87,10 +87,10 @@ def _fwd_kernel(
k = tl.load(K_block_ptr)
v = tl.load(V_block_ptr)
# -- compute qk ---
- qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float16)
if IS_CAUSAL:
qk = tl.where(P_SEQ + offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
- qk += tl.dot(q, k)
+ qk += tl.dot(q, k, out_dtype=tl.float16)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
@@ -243,11 +243,12 @@ class _attention(torch.autograd.Function):
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q)
BLOCK_M = 128
- BLOCK_N = 64
+ BLOCK_N = 64 if Lk <= 64 else 32
+ num_stages = 4 if Lk <= 64 else 3
- num_warps = 4
grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]
+ num_warps = 4 if Lk <= 64 else 8
_fwd_kernel[grid](
q, k, v, sm_scale,
L,
@@ -260,7 +261,7 @@ class _attention(torch.autograd.Function):
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,
- num_stages=4)
+ num_stages=num_stages)
ctx.save_for_backward(q, k, v, o, L)
ctx.grid = grid