[TUTORIALS] attention: support torch 2.1 (#2461)

Author: Sam Shleifer
Date: 2023-10-06 20:50:11 -04:00
Committed by: GitHub
Parent: be19cf3103
Commit: fb3c2f3b2b

@@ -588,6 +588,7 @@ try:
     HAS_FLASH = True
 except BaseException:
     HAS_FLASH = False
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [
@@ -624,7 +625,7 @@ def bench_flash_attention(
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-        if mode == "fwd":
+        if mode == "fwd" and TORCH_HAS_FP8:
             q = q.to(torch.float8_e5m2)
             k = k.to(torch.float8_e5m2)
         v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
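In effect, the patch makes the FP8 down-cast in the forward-pass benchmark conditional on the running PyTorch build exposing the torch.float8_e5m2 dtype, so the tutorial degrades gracefully to fp16 instead of raising AttributeError on builds without FP8 support. Below is a minimal standalone sketch of the same feature-detection pattern; maybe_to_fp8 is a hypothetical helper and the tensor shape is arbitrary, neither is part of the tutorial:

import torch

# Feature-detect once at import time, mirroring the patch: a cheap
# attribute check instead of a hard dependency on FP8-capable builds.
TORCH_HAS_FP8 = hasattr(torch, "float8_e5m2")

def maybe_to_fp8(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: cast to 8-bit float where the dtype exists,
    # otherwise return the tensor unchanged (e.g. keep fp16).
    return t.to(torch.float8_e5m2) if TORCH_HAS_FP8 else t

device = "cuda" if torch.cuda.is_available() else "cpu"
q = torch.randn((2, 4, 128, 64), dtype=torch.float16, device=device)
q = maybe_to_fp8(q)
print(q.dtype)  # torch.float8_e5m2 where available, torch.float16 elsewhere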