[DOCS] fixed flash_attn causal argument in tutorial

Author: Phil Tillet
Date:   2023-07-11 09:28:20 -07:00
parent bbc1ad16d8
commit 041f1144e8

@@ -409,7 +409,7 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype
         cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32)
         cu_seqlens[1:] = lengths.cumsum(0)
         qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True)
-        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True)
+        fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal)
         if mode == 'bwd':
             o = fn()
             do = torch.randn_like(o)
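
Why the change matters: the old lambda pinned causal=True, so the flash-attn provider always benchmarked the causal kernel, even for configurations where the bench was invoked with causal=False. A minimal runnable sketch of the pitfall, using a plain-PyTorch toy attention as a stand-in for flash_attn_func (the function names and shapes below are illustrative, not taken from the tutorial):

import torch

# Toy stand-in for flash_attn_func so the sketch runs anywhere;
# the real benchmark calls into the flash-attn library instead.
def attention(q, k, v, causal: bool):
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if causal:
        n = scores.shape[-1]
        mask = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def bench_flash_attention(causal: bool):
    q = k = v = torch.randn(4, 128, 64)
    # Before the fix the closure hard-coded causal=True, so a
    # causal=False benchmark configuration still timed the masked path.
    fn = lambda: attention(q, k, v, causal=causal)  # thread the flag through
    return fn()

out_causal = bench_flash_attention(causal=True)
out_full = bench_flash_attention(causal=False)
assert not torch.allclose(out_causal, out_full)  # the flag actually matters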