ROCm/python/test/unit/operators/test_flash_attention.py
jon-chuang 5231d57c71 [TESTS] replace deprecated torch.testing.assert_allclose (#2250)
Prior to this PR, matmul on sm_89 (RTX 4070)
(`test/unit/operators/test_matmul.py::test_op`) would fail because its
atol/rtol were too strict.

To avoid choosing the strictness ourselves, and to get better defaults based
on dtype, use the non-deprecated torch testing util, torch.testing.assert_close.

See: https://github.com/pytorch/pytorch/issues/61844

Replace: https://github.com/openai/triton/pull/2242
2023-09-11 15:31:17 -04:00
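
As a minimal sketch of the switch described above (illustrative only, not taken from the repository): torch.testing.assert_allclose is deprecated, while torch.testing.assert_close derives its default atol/rtol from the dtype of the inputs (half-precision types get looser bounds than float32) and still accepts explicit tolerances, as the test below does.

    import torch

    a = torch.randn(32, 32, dtype=torch.float16)
    b = a.clone()

    # Deprecated call this PR replaces:
    # torch.testing.assert_allclose(a, b)

    # Preferred call: default tolerances are picked from the dtype...
    torch.testing.assert_close(a, b)

    # ...or can be pinned explicitly, as the flash-attention test below does.
    torch.testing.assert_close(a, b, atol=1e-2, rtol=0)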

import os

import pytest
import torch

import triton
import triton.ops


@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
                                                 (4, 48, 1024, 32),
                                                 (4, 48, 1024, 64),
                                                 (4, 48, 1024, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
    # Known segmentation fault when both MMA v3 and TMA are enabled; run with
    # ENABLE_TMA=0 and ENABLE_MMA_V3=0 to exercise this test.
    enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
    enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
    if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]:
        pytest.skip('Segmentation fault')
    capability = torch.cuda.get_device_capability()
    if capability[0] < 8:
        pytest.skip("Flash attention only supported for compute capability >= 80")
    torch.manual_seed(20)
    q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
    sm_scale = 0.5
    dout = torch.randn_like(q)
    # reference implementation
    M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        p[:, :, M == 0] = float("-inf")
    p = torch.softmax(p.float(), dim=-1).to(dtype)
    ref_out = torch.matmul(p, v)
    ref_out.backward(dout)
    ref_dv, v.grad = v.grad.clone(), None
    ref_dk, k.grad = k.grad.clone(), None
    ref_dq, q.grad = q.grad.clone(), None
    # triton implementation
    tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
    tri_out.backward(dout)
    tri_dv, v.grad = v.grad.clone(), None
    tri_dk, k.grad = k.grad.clone(), None
    tri_dq, q.grad = q.grad.clone(), None
    # compare, with a looser absolute tolerance for bfloat16
    atol = 1e-1 if dtype == torch.bfloat16 else 1e-2
    torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0)
    torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
    torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
    torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
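
For reference, one way to run just this file (the path is assumed from the header above; the tests skip on devices with compute capability below 8.0, and skip entirely when both ENABLE_MMA_V3 and ENABLE_TMA are enabled):

    pytest python/test/unit/operators/test_flash_attention.py -v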