[HOPPER] enable flash attention with tma (#2336)

This commit is contained in:
ben-zhang-609
2023-09-21 05:06:56 +08:00
committed by GitHub
parent 9cab885dff
commit bcaf14755a
3 changed files with 43 additions and 45 deletions

View File

@@ -13,17 +13,23 @@ import triton.ops
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
# with ENABLE_TMA=0 and ENABLE_MMA_V3=0
import os
enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower()
enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower()
if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]:
pytest.skip('Segmentation fault')
if enable_tma in ["on", "true", "1"]:
if dtype == torch.bfloat16:
pytest.skip('bfloat16 tma not support currently')
if '-'.join(map(str, [seq_par, causal, Z, H, N_CTX, D_HEAD])) in [
"True-True-2-4-512-16",
"True-True-2-4-512-32",
"True-False-2-4-512-16",
"True-False-2-4-512-32",
]:
pytest.skip('backward ref check failed')
capability = torch.cuda.get_device_capability()
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
if not interpreter and capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
pytest.skip("Flash attention only supported for compute capability >= 80")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

View File

@@ -24,6 +24,7 @@ def _fwd_kernel(
stride_vz, stride_vh, stride_vn, stride_vk,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
Z_H_N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
@@ -31,27 +32,21 @@ def _fwd_kernel(
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
qvk_offset = off_hz * stride_qh
Q_block_ptr = tl.make_block_ptr(
base=Q + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
vk_offset = qvk_offset // stride_qm
K_block_ptr = tl.make_block_ptr(
base=K + qvk_offset,
shape=(BLOCK_DMODEL, N_CTX),
base=K,
shape=(BLOCK_DMODEL, Z_H_N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
offsets=(0, vk_offset),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
base=V,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_vn, stride_vk),
offsets=(0, 0),
offsets=(vk_offset, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0)
)
@@ -68,7 +63,11 @@ def _fwd_kernel(
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q = tl.load(Q_block_ptr)
offs_k = tl.arange(0, BLOCK_DMODEL)
Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
q = tl.load(Q_ptrs)
q = (q * qk_scale).to(K.dtype.element_ty)
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
@@ -100,13 +99,14 @@ def _fwd_kernel(
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(
base=Out + qvk_offset,
shape=(N_CTX, BLOCK_DMODEL),
base=Out,
shape=(Z_H_N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
offsets=(vk_offset + start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0)
)
# O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
tl.store(O_block_ptr, acc.to(K.dtype.element_ty))
@@ -312,6 +312,7 @@ class _attention(torch.autograd.Function):
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
q.shape[0], q.shape[1], q.shape[2],
q.shape[0] * q.shape[1] * q.shape[2],
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
num_warps=num_warps,