[TUTORIALS] attention: support torch 2.1 (#2461)
@@ -588,6 +588,7 @@ try:
 except BaseException:
     HAS_FLASH = False
 
+TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')
 BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
 # vary seq length for fixed head and batch=4
 configs = [
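For context: torch.float8_e5m2 first shipped in PyTorch 2.1, so probing for the attribute with hasattr lets the tutorial run unchanged on older torch builds instead of raising AttributeError. A minimal sketch of the same feature-detection pattern, using an illustrative CPU tensor rather than the tutorial's CUDA setup:

import torch

# torch.float8_e5m2 only exists on PyTorch >= 2.1, so hasattr() doubles
# as a version check (this mirrors the commit's TORCH_HAS_FP8 flag).
TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')

x = torch.randn(4, 8, dtype=torch.float16)
if TORCH_HAS_FP8:
    # Down-cast to the 5-bit-exponent / 2-bit-mantissa fp8 format.
    x = x.to(torch.float8_e5m2)
print(x.dtype)  # torch.float8_e5m2 on torch >= 2.1, torch.float16 otherwise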
@@ -624,7 +625,7 @@ def bench_flash_attention(
     if provider == "triton":
         q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
         k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
-        if mode == "fwd":
+        if mode == "fwd" and TORCH_HAS_FP8:
             q = q.to(torch.float8_e5m2)
             k = k.to(torch.float8_e5m2)
         v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
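The net effect of the new guard, sketched below: q and k drop to fp8 only for the forward-pass benchmark and only when the running torch build supports it, while v keeps the benchmark dtype. The sketch runs on CPU for portability (the real benchmark allocates on "cuda" with requires_grad=True), and make_qkv is a hypothetical helper, not part of the tutorial:

import torch

TORCH_HAS_FP8 = hasattr(torch, 'float8_e5m2')

def make_qkv(mode, dtype=torch.float16, shape=(2, 4, 16, 8)):
    # Hypothetical helper mirroring the patched benchmark setup.
    q = torch.randn(shape, dtype=dtype)
    k = torch.randn(shape, dtype=dtype)
    if mode == "fwd" and TORCH_HAS_FP8:
        # fp8 inputs are exercised only on the forward pass; the backward
        # benchmark keeps q/k in the original dtype.
        q = q.to(torch.float8_e5m2)
        k = k.to(torch.float8_e5m2)
    v = torch.randn(shape, dtype=dtype)
    return q, k, v

q, k, v = make_qkv("fwd")
print(q.dtype, k.dtype, v.dtype)
# torch >= 2.1: torch.float8_e5m2 torch.float8_e5m2 torch.float16
# torch <  2.1: torch.float16 torch.float16 torch.float16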