mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Better tuning for H100 flash attention. (#2444)
Improves performance of fwd pass from 420 to 440 TF
This commit is contained in:
@@ -69,6 +69,30 @@ def _attn_fwd_inner(
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
    return acc, l_i, m_i


# We don't run auto-tuning every time to keep the tutorial fast. Uncommenting
# the code below and commenting out the equivalent parameters is convenient for
# re-tuning.
# @triton.autotune(
#     configs=[
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=8),
#         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=8),
#         triton.Config({'BLOCK_M': 256, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=3, num_warps=4),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=4),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=4),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=4, num_warps=4),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=3, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=7, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=7, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=6, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=5, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32}, num_stages=4, num_warps=8),
#         triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_stages=6, num_warps=4),
#     ],
#     key=['N_CTX'],
# )


@triton.jit
def _attn_fwd(
@@ -436,6 +460,10 @@ class _attention(torch.autograd.Function):
        BLOCK_N = 64 if Lk <= 64 else 32
        num_stages = 4 if Lk <= 64 else 3
        num_warps = 4
        # Tuning for H100
        if torch.cuda.get_device_capability()[0] == 9:
            num_warps = 8
            num_stages = 7 if Lk >= 64 else 3
        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
        M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
        _attn_fwd[grid](
Reference in New Issue
Block a user