[TESTS] fix flash attention (#2086)

Co-authored-by: dongdongl <dongdongl@nvidia.com>

Author: Dongdong Li
Date: 2023-09-20 14:23:46 +08:00
Committed by: GitHub
Parent: 363182928c
Commit: e5eda098b3
4 changed files with 282 additions and 20 deletions

@@ -368,14 +368,15 @@ class _attention(torch.autograd.Function):
attention = _attention.apply
-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64),
-                                                 # (4, 48, 256, 64),
-                                                 # (4, 48, 512, 64),
-                                                 # (4, 48, 1024, 64),
-                                                 # (4, 48, 2048, 64),
-                                                 # (4, 48, 4096, 64),
-                                                 # (4, 48, 8192, 64), out of memory
-                                                 ])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
+    (4, 48, 128, 64),
+    (4, 48, 256, 64),
+    (4, 48, 512, 64),
+    (4, 48, 1024, 64),
+    (4, 48, 2048, 64),
+    (4, 48, 4096, 64),
+    # (4, 48, 8192, 64), out of memory
+])
@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
torch.manual_seed(20)
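
For reference, the skipif above gates the whole test on Hopper-class GPUs: torch.cuda.get_device_capability() returns a (major, minor) tuple, and the test runs only when the major version is at least 9, while each parametrize tuple expands into one test case. A minimal standalone sketch of the same pattern (a hypothetical test, not taken from this commit):

import pytest
import torch

# Hypothetical capability gate mirroring the skipif used in the flash attention test:
# run only on devices with compute capability 9.x (e.g. H100); skip everywhere else.
requires_sm90 = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9,
    reason="requires arch 9+",
)

@requires_sm90
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64), (4, 48, 256, 64)])
def test_qkv_shapes(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    # Each tuple above becomes one test case; build an input the way the real test would.
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5)
    assert q.shape == (Z, H, N_CTX, D_HEAD) and q.dtype == dtype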