[OPS] enable flash_attention_v2 TMA (#2544)
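Enables the TMA path for the flash-attention-v2 backward pass. Block pointers for q, k, v, do, dq, dk, and dv are now built over the flattened (Z * H * N_CTX) row dimension instead of a per-head N_CTX view, the per-(batch, head) base-pointer arithmetic moves into tl.advance offsets inside _bwd_kernel_one_col_block (which now also receives off_h and off_z), and the temporary ENABLE_TMA environment-variable override is dropped from the unit test. The bfloat16 backward entry in the performance reference table is updated to match.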
@@ -165,7 +165,7 @@ flash_attention_data = {
         (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266,
         (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098,
         (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159,
-        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136,
+        (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
         (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088,
     }
 }
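The changed entry above raises the expected bfloat16 backward value from 0.136 to 0.159 for the (4, 48, 4096, 64) configuration. Judging from the test signature later in this diff, the keys appear to follow (Z, H, N_CTX, D_HEAD, causal, seq_par, mode, dtype) with the stored number as the performance reference; a minimal lookup sketch under that assumption, with the table trimmed to the updated entry:

# Assumed key layout: (Z, H, N_CTX, D_HEAD, causal, seq_par, mode, dtype).
flash_attention_data = {
    (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.159,
}
key = (4, 48, 4096, 64, False, False, 'backward', 'bfloat16')
print(flash_attention_data[key])  # 0.159 after this change (was 0.136)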
@@ -43,8 +43,6 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     ref_dq, q.grad = q.grad.clone(), None
     # # triton implementation
     tri_out = triton.ops.attention(q, k, v, causal, sm_scale, seq_par)
-    # temporary env var control begin
-    os.putenv("ENABLE_TMA", "0")
     tri_out.backward(dout)
     tri_dv, v.grad = v.grad.clone(), None
     tri_dk, k.grad = k.grad.clone(), None
@@ -55,5 +53,3 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
     torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0)
     torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0)
-    # temporary env var control end
-    os.putenv("ENABLE_TMA", enable_tma)
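With TMA enabled by default, the test no longer forces it off around the backward pass, so the save/override/restore dance is gone. The removed lines followed roughly this pattern (a sketch: run_backward is a hypothetical stand-in for tri_out.backward(dout), and enable_tma is assumed to hold the value read from the environment earlier in the test):

import os

def run_backward():
    pass  # hypothetical stand-in for tri_out.backward(dout)

enable_tma = os.environ.get("ENABLE_TMA", "0")
os.putenv("ENABLE_TMA", "0")  # temporarily force the non-TMA backward path
run_backward()
os.putenv("ENABLE_TMA", enable_tma)  # restore the caller's setting
# Note: os.putenv changes the process environment without updating os.environ.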
@@ -140,7 +140,7 @@ def _bwd_kernel_one_col_block(
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vn, stride_vk,
     Z, H, N_CTX,
-    off_hz, start_n, num_block,
+    off_h, off_z, off_hz, start_n, num_block,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
@@ -152,13 +152,21 @@ def _bwd_kernel_one_col_block(
     else:
         lo = 0
 
-    Q_block_ptr = tl.advance(Q_block_ptr, (lo, 0))
-    K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M, 0))
-    V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M, 0))
-    DO_block_ptr = tl.advance(DO_block_ptr, (lo, 0))
-    DQ_block_ptr = tl.advance(DQ_block_ptr, (lo, 0))
-    DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M, 0))
-    DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M, 0))
+    Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
+    DQ_offset = (off_z * stride_qz + off_h * stride_qh)
+    K_offset = (off_z * stride_kz + off_h * stride_kh) // stride_kn
+    V_offset = (off_z * stride_vz + off_h * stride_vh) // stride_vn
+    if SEQUENCE_PARALLEL:
+        DQ_offset += stride_dqa.to(tl.int64) * start_n
+    DQ_offset = DQ_offset // stride_qm
+
+    Q_block_ptr = tl.advance(Q_block_ptr, (lo + Q_offset, 0))
+    K_block_ptr = tl.advance(K_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    V_block_ptr = tl.advance(V_block_ptr, (start_n * BLOCK_M + V_offset, 0))
+    DO_block_ptr = tl.advance(DO_block_ptr, (lo + Q_offset, 0))
+    DQ_block_ptr = tl.advance(DQ_block_ptr, (lo + DQ_offset, 0))
+    DK_block_ptr = tl.advance(DK_block_ptr, (start_n * BLOCK_M + K_offset, 0))
+    DV_block_ptr = tl.advance(DV_block_ptr, (start_n * BLOCK_M + V_offset, 0))
 
     # initialize row/col offsets
     offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
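The floor divisions above turn a (batch, head) element offset into a row offset, so tl.advance can position each block pointer inside the flattened (Z * H * N_CTX, BLOCK_DMODEL) view without rebasing the underlying pointers. This is exact whenever the head offset is a whole number of rows, as with a contiguous (Z, H, N_CTX, D_HEAD) layout; a small sketch with the (4, 48, 4096, 64) configuration from the table above (off_z and off_h are arbitrary picks):

# Contiguous (Z, H, N_CTX, D_HEAD) strides, measured in elements:
Z, H, N_CTX, D_HEAD = 4, 48, 4096, 64
stride_qz, stride_qh, stride_qm = H * N_CTX * D_HEAD, N_CTX * D_HEAD, D_HEAD

off_z, off_h = 2, 7  # arbitrary batch/head indices
Q_offset = (off_z * stride_qz + off_h * stride_qh) // stride_qm
# Each (z, h) slice begins a whole number of rows into the flat view:
assert Q_offset == (off_z * H + off_h) * N_CTX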
@@ -232,6 +240,8 @@ def _bwd_kernel(
     stride_kz, stride_kh, stride_kn, stride_kk,
     stride_vz, stride_vh, stride_vn, stride_vk,
     Z, H, N_CTX,
+    Z_H_N_CTX,
+    SQ_Z_H_N_CTX,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
     SEQUENCE_PARALLEL: tl.constexpr,
@@ -243,21 +253,10 @@ def _bwd_kernel(
     off_hz = tl.program_id(0)
     off_z = off_hz // H
     off_h = off_hz % H
-    # offset pointers for batch/head
-    Q += off_z * stride_qz + off_h * stride_qh
-    K += off_z * stride_kz + off_h * stride_kh
-    V += off_z * stride_vz + off_h * stride_vh
-    DO += off_z * stride_qz + off_h * stride_qh
-    DQ += off_z * stride_qz + off_h * stride_qh
-    DK += off_z * stride_kz + off_h * stride_kh
-    DV += off_z * stride_vz + off_h * stride_vh
-
-    if SEQUENCE_PARALLEL:
-        DQ += stride_dqa.to(tl.int64) * tl.program_id(1)
 
     Q_block_ptr = tl.make_block_ptr(
         base=Q,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_qm, stride_qk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -265,7 +264,7 @@ def _bwd_kernel(
     )
     K_block_ptr = tl.make_block_ptr(
         base=K,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_kn, stride_kk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -273,7 +272,7 @@ def _bwd_kernel(
     )
     V_block_ptr = tl.make_block_ptr(
         base=V,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_vn, stride_vk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -281,23 +280,34 @@ def _bwd_kernel(
     )
     DO_block_ptr = tl.make_block_ptr(
         base=DO,
         shape=(N_CTX, BLOCK_DMODEL),
         strides=(stride_qm, stride_qk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
         order=(1, 0)
     )
-    DQ_block_ptr = tl.make_block_ptr(
-        base=DQ,
-        shape=(N_CTX, BLOCK_DMODEL),
-        strides=(stride_qm, stride_qk),
-        offsets=(0, 0),
-        block_shape=(BLOCK_M, BLOCK_DMODEL),
-        order=(1, 0)
-    )
+    if SEQUENCE_PARALLEL:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(SQ_Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+    else:
+        DQ_block_ptr = tl.make_block_ptr(
+            base=DQ,
+            shape=(Z_H_N_CTX, BLOCK_DMODEL),
+            strides=(stride_qm, stride_qk),
+            offsets=(0, 0),
+            block_shape=(BLOCK_M, BLOCK_DMODEL),
+            order=(1, 0)
+        )
+
     DK_block_ptr = tl.make_block_ptr(
         base=DK,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_kn, stride_kk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -305,7 +315,7 @@ def _bwd_kernel(
     )
     DV_block_ptr = tl.make_block_ptr(
         base=DV,
-        shape=(N_CTX, BLOCK_DMODEL),
+        shape=(Z_H_N_CTX, BLOCK_DMODEL),
         strides=(stride_vn, stride_vk),
         offsets=(0, 0),
         block_shape=(BLOCK_M, BLOCK_DMODEL),
@@ -326,7 +336,7 @@ def _bwd_kernel(
             stride_kz, stride_kh, stride_kn, stride_kk,
             stride_vz, stride_vh, stride_vn, stride_vk,
             Z, H, N_CTX,
-            off_hz, start_n, num_block_n,
+            off_h, off_z, off_hz, start_n, num_block_n,
             BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
             BLOCK_N=BLOCK_N,
             SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -346,7 +356,7 @@ def _bwd_kernel(
             stride_kz, stride_kh, stride_kn, stride_kk,
             stride_vz, stride_vh, stride_vn, stride_vk,
             Z, H, N_CTX,
-            off_hz, start_n, num_block_n,
+            off_h, off_z, off_hz, start_n, num_block_n,
             BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL,
             BLOCK_N=BLOCK_N,
             SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
@@ -429,6 +439,8 @@ class _attention(torch.autograd.Function):
            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
            q.shape[0], q.shape[1], q.shape[2],
+           q.shape[0] * q.shape[1] * q.shape[2],
+           cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2],
            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
            SEQUENCE_PARALLEL=sequence_parallel,
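The two added launch arguments give the kernel the flattened row counts it now needs for the block-pointer shapes: Z_H_N_CTX for the ordinary views and SQ_Z_H_N_CTX for the sequence-parallel dq, which keeps one dq replica per k/v column block (presumably reduced after the kernel finishes). A worked sketch with the (4, 48, 4096, 64) test configuration, assuming BLOCK = 128 as the tile size (the value of BLOCK is not visible in this diff):

def cdiv(a, b):
    return (a + b - 1) // b

Z, H, N_CTX = 4, 48, 4096
BLOCK = 128           # assumed tile size; not shown in this diff
seq_len_kv = N_CTX    # assuming key/value sequence length equals N_CTX

Z_H_N_CTX = Z * H * N_CTX                           # 786432 rows
SQ_Z_H_N_CTX = cdiv(seq_len_kv, BLOCK) * Z_H_N_CTX  # 32 replicas -> 25165824 rows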