[OPTIMIZER] Remove extra wgmma_wait_group in flash attention (#2399)

Co-authored-by: dongdongl <dongdongl@nvidia.com>
Author: Dongdong Li
Date: 2023-10-27 00:35:36 +08:00 (committed by GitHub)
Parent: cfae7e2a25
Commit: 0469d5fccd
3 changed files with 278 additions and 8 deletions


@@ -53,3 +53,66 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
    torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
    torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
try:
    from flash_attn.flash_attn_interface import flash_attn_func
    HAS_FLASH = True
except BaseException:
    HAS_FLASH = False
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
configs = [triton.testing.Benchmark(
    x_names=['N_CTX'],
    x_vals=[2**i for i in range(10, 14)],
    line_arg='provider',
    line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
    line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
    styles=[('red', '-'), ('blue', '-')],
    ylabel='ms',
    plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{causal}-{seq_par}',
    args={
        'H': N_HEADS,
        'BATCH': BATCH,
        'D_HEAD': D_HEAD,
        'dtype': torch.float16,
        'mode': mode,
        'causal': causal,
        'seq_par': seq_par}
) for mode in ['fwd', 'bwd'] for causal in [True, False] for seq_par in [True, False]]
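# Note: each Benchmark entry above becomes one plot; triton.testing.perf_report
# sweeps N_CTX over x_vals and calls the decorated function once per provider
# in line_vals for every (mode, causal, seq_par) combination generated by the
# comprehension.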
@triton.testing.perf_report(configs)
def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, causal, seq_par, provider, dtype=torch.float16, device="cuda"):
    assert mode in ['fwd', 'bwd']
    warmup = 25
    rep = 100
    sm_scale = 1.3
    q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
    if provider == "triton":
        fn = lambda: triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
    if provider == "flash":
        lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
        cu_seqlens = torch.zeros(
            (BATCH + 1,), device=device, dtype=torch.int32)
        cu_seqlens[1:] = lengths.cumsum(0)
        fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=causal)
        if mode == 'bwd':
            o = fn()
            do = torch.randn_like(o)
            fn = lambda: o.backward(do, retain_graph=True)
        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
        return ms
# only works on post-Ampere GPUs right now
# bench_flash_attention.run(save_path='.', print_data=True)
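
A minimal sketch of driving the benchmark added above, assuming it is appended to the same file on a machine with a post-Ampere GPU, triton installed, and optionally flash-attn for the "flash" provider; the __main__ guard and the CUDA availability check are illustrative additions, not part of the commit:

import torch

if __name__ == "__main__":
    # Runs every (mode, causal, seq_par) config, prints the timing tables,
    # and writes the plots/CSVs to the current directory.
    if torch.cuda.is_available():
        bench_flash_attention.run(save_path='.', print_data=True)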