[TESTS] fix flash attention (#2086)
Co-authored-by: dongdongl <dongdongl@nvidia.com>
@@ -368,14 +368,15 @@ class _attention(torch.autograd.Function):
 attention = _attention.apply


-@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64),
-                                                 # (4, 48, 256, 64),
-                                                 # (4, 48, 512, 64),
-                                                 # (4, 48, 1024, 64),
-                                                 # (4, 48, 2048, 64),
-                                                 # (4, 48, 4096, 64),
-                                                 # (4, 48, 8192, 64), out of memory
-                                                 ])
+@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [
+    (4, 48, 128, 64),
+    (4, 48, 256, 64),
+    (4, 48, 512, 64),
+    (4, 48, 1024, 64),
+    (4, 48, 2048, 64),
+    (4, 48, 4096, 64),
+    # (4, 48, 8192, 64), out of memory
+])
 @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+")
 def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
     torch.manual_seed(20)
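The hunk cuts off at torch.manual_seed(20), so the body of test_op is not shown; the skipif decorator gates the test to devices reporting compute capability 9 or newer (e.g. Hopper on CUDA). For context, here is a minimal sketch of the pattern such a test typically follows: run the Triton kernel and compare against a naive PyTorch reference. The (q, k, v, sm_scale) signature of the attention wrapper, the helper names, and the tolerances below are assumptions, not the commit's actual code.

# Hedged sketch only: the real test body is truncated in this hunk.
# `attention` is the _attention.apply wrapper from the diff above; its
# (q, k, v, sm_scale) signature and the tolerances are assumptions.
import torch

def reference_attention(q, k, v, sm_scale):
    # Naive O(N_CTX^2) attention: softmax(Q K^T * scale) V, computed in
    # fp32 for numerical stability, then cast back to the input dtype.
    p = torch.matmul(q.float(), k.transpose(-2, -1).float()) * sm_scale
    p = torch.softmax(p, dim=-1)
    return torch.matmul(p, v.float()).to(q.dtype)

def check_flash_attention(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
    torch.manual_seed(20)
    shape = (Z, H, N_CTX, D_HEAD)
    q = torch.randn(shape, dtype=dtype, device="cuda")
    k = torch.randn(shape, dtype=dtype, device="cuda")
    v = torch.randn(shape, dtype=dtype, device="cuda")
    sm_scale = D_HEAD ** -0.5
    tri_out = attention(q, k, v, sm_scale)   # Triton flash-attention kernel
    ref_out = reference_attention(q, k, v, sm_scale)
    torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=0.0)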