mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[HOPPER] enable flash attention with tma (#2336)
This commit is contained in:
@@ -13,17 +13,23 @@ import triton.ops
|
||||
@pytest.mark.parametrize('causal', [True, False])
|
||||
@pytest.mark.parametrize('seq_par', [True, False])
|
||||
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
|
||||
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
|
||||
import os
|
||||
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
|
||||
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
|
||||
if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]:
|
||||
pytest.skip('Segmentation fault')
|
||||
if enable_tma in ["on", "true", "1"]:
|
||||
if dtype == torch.bfloat16:
|
||||
pytest.skip('bfloat16 tma not support currently')
|
||||
if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [
|
||||
"True-True-2-4-512-16",
|
||||
"True-True-2-4-512-32",
|
||||
"True-False-2-4-512-16",
|
||||
"True-False-2-4-512-32",
|
||||
]:
|
||||
pytest.skip('backward ref check failed')
|
||||
|
||||
capability = torch.cuda.get_device_capability()
|
||||
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
|
||||
if not interpreter and capability[0] < 8:
|
||||
pytest.skip("Flash attention only supported for compute capability < 80")
|
||||
pytest.skip("Flash attention only supported for compute capability >= 80")
|
||||
torch.manual_seed(20)
|
||||
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
|
||||
|
||||
@@ -24,6 +24,7 @@ def _fwd_kernel(
|
||||
stride_vz, stride_vh, stride_vn, stride_vk,
|
||||
stride_oz, stride_oh, stride_om, stride_on,
|
||||
Z, H, N_CTX,
|
||||
Z_H_N_CTX,
|
||||
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
@@ -31,27 +32,21 @@ def _fwd_kernel(
|
||||
start_m = tl.program_id(0)
|
||||
off_hz = tl.program_id(1)
|
||||
qvk_offset = off_hz * stride_qh
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
vk_offset = qvk_offset // stride_qm
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + qvk_offset,
|
||||
shape=(BLOCK_DMODEL, N_CTX),
|
||||
base=K,
|
||||
shape=(BLOCK_DMODEL, Z_H_N_CTX),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, 0),
|
||||
offsets=(0, vk_offset),
|
||||
block_shape=(BLOCK_DMODEL, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
base=V,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(0, 0),
|
||||
offsets=(vk_offset, 0),
|
||||
block_shape=(BLOCK_N, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
@@ -68,7 +63,11 @@ def _fwd_kernel(
|
||||
# don't work as expected with `exp` in the loop
|
||||
qk_scale = sm_scale * 1.44269504
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(Q_block_ptr)
|
||||
|
||||
offs_k = tl.arange(0, BLOCK_DMODEL)
|
||||
Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
q = tl.load(Q_ptrs)
|
||||
|
||||
q = (q * qk_scale).to(K.dtype.element_ty)
|
||||
lo = 0
|
||||
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
|
||||
@@ -100,13 +99,14 @@ def _fwd_kernel(
|
||||
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
|
||||
# write back O
|
||||
O_block_ptr = tl.make_block_ptr(
|
||||
base=Out + qvk_offset,
|
||||
shape=(N_CTX, BLOCK_DMODEL),
|
||||
base=Out,
|
||||
shape=(Z_H_N_CTX, BLOCK_DMODEL),
|
||||
strides=(stride_om, stride_on),
|
||||
offsets=(start_m * BLOCK_M, 0),
|
||||
offsets=(vk_offset + start_m * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, BLOCK_DMODEL),
|
||||
order=(1, 0)
|
||||
)
|
||||
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
|
||||
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
|
||||
|
||||
|
||||
@@ -312,6 +312,7 @@ class _attention(torch.autograd.Function):
|
||||
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
|
||||
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
|
||||
q.shape[0], q.shape[1], q.shape[2],
|
||||
q.shape[0] * q.shape[1] * q.shape[2],
|
||||
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
|
||||
IS_CAUSAL=causal,
|
||||
num_warps=num_warps,
|
||||
|
||||
Reference in New Issue
Block a user