[TUTORIALS] Support FlashAttention-2 reference (#1984)

Uses FlashAttention-2 if available, otherwise acts as before (if
FlashAttention-1 is available, that is used, otherwise the
FlashAttention reference benchmark is not run).

I decided to keep the same name for the imported function, but feel free
to make me change that.
This commit is contained in:
janEbert
2023-07-24 22:54:01 +02:00
committed by GitHub
parent e6216047b8
commit 62a8afa403

View File

@@ -338,10 +338,15 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
try:
from flash_attn.flash_attn_interface import flash_attn_func
HAS_FLASH = True
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func as flash_attn_func
FLASH_VER = 2
except BaseException:
HAS_FLASH = False
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_VER = 1
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
# vary seq length for fixed head and batch=4
@@ -350,7 +355,7 @@ configs = [triton.testing.Benchmark(
x_vals=[2**i for i in range(10, 15)],
line_arg='provider',
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
line_names=['Triton'] + ([f'Flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('blue', '-')],
ylabel='ms',
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
@@ -375,11 +380,17 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
fn = lambda: o.backward(do, retain_graph=True)
ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
if provider == "flash":
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
if FLASH_VER == 1:
lengths = torch.full((BATCH,), fill_value=N_CTX, device=device)
cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
cu_seqlens[1:] = lengths.cumsum(0)
qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD)
fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
elif FLASH_VER == 2:
fn = lambda: flash_attn_func(qkv, causal=causal)
else:
raise ValueError(f'unknown {FLASH_VER = }')
if mode == 'bwd':
o = fn()
do = torch.randn_like(o)