[OPTIMIZER][TUTORIAL] flash attention v2 (#1952)
@@ -809,11 +809,11 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
                                  nextIV, newForOp.getUpperBound());

   pipelineIterIdx = newForOp.getRegionIterArgs()[ivIndex + 1];
-  Value insertSliceIndex = builder.create<arith::RemSIOp>(
+  Value insertSliceIndex = builder.create<arith::RemUIOp>(
       nextIV.getLoc(), pipelineIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
   loopIterIdx = newForOp.getRegionIterArgs()[ivIndex + 2];
-  Value extractSliceIndex = builder.create<arith::RemSIOp>(
+  Value extractSliceIndex = builder.create<arith::RemUIOp>(
       nextIV.getLoc(), loopIterIdx,
       builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
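A note on the hunk above: both pipeline counters start at zero and only increment, so the remainder operands are never negative and the unsigned op is safe; `remui` also lowers to cheaper code than `remsi`, which must handle operand signs. A minimal Python sketch of the two semantics (illustrative only; `rem_signed` and `rem_unsigned` are hypothetical helpers, not Triton or MLIR APIs):

```python
import math

def rem_signed(a: int, b: int) -> int:
    """C/LLVM-style srem: truncated division, result takes the sign of a."""
    return int(math.fmod(a, b))

def rem_unsigned(a: int, b: int) -> int:
    """LLVM-style urem for non-negative operands: plain Python modulo."""
    assert a >= 0 and b > 0
    return a % b

num_stages = 3
for it in range(2 * num_stages):
    # The pipeliner's counters only count up from zero, so both agree here,
    # which is why the cheaper remui can replace remsi.
    assert rem_signed(it, num_stages) == rem_unsigned(it, num_stages)

# A (hypothetical) negative counter under srem would yield a negative
# slice index, which no buffer stage corresponds to:
assert rem_signed(-1, num_stages) == -1
```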
@@ -2,8 +2,13 @@
 Fused Attention
 ===============

-This is a Triton implementation of the Flash Attention algorithm
-(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf)
+This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
+
+Extra Credits:
+- Original flash attention paper (https://arxiv.org/abs/2205.14135)
+- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
+- Adam P. Goucher for simplified vector math
+
 """

 import pytest
@@ -90,8 +95,7 @@ def _fwd_kernel(
         l_i = tl.load(l_ptrs)
         acc += tl.load(O_block_ptr).to(tl.float32)
     lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
     # credits to: Adam P. Goucher (https://github.com/apgoucher):
-    # scale sm_scale by 1/log_2(e) and use
+    # scale sm_scale by log_2(e) and use
     # 2^x instead of exp in the loop because CSE and LICM
     # don't work as expected with `exp` in the loop
     qk_scale = sm_scale * 1.44269504
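The corrected comment now matches the code: `sm_scale` is multiplied by log2(e) ≈ 1.44269504, so `exp(x)` can be computed as `exp2(x * log2(e))` inside the loop. A quick NumPy check of the identity (values illustrative):

```python
import numpy as np

log2_e = 1.44269504            # the constant used for qk_scale above
sm_scale = 0.125               # illustrative scale, not from the patch
x = np.linspace(-4.0, 4.0, 9)  # illustrative logits

# exp(sm_scale * x) == 2 ** (sm_scale * log2(e) * x)
np.testing.assert_allclose(np.exp(sm_scale * x),
                           np.exp2(sm_scale * log2_e * x), rtol=1e-7)
```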
@@ -111,20 +115,15 @@ def _fwd_kernel(
         if MODE == 1 or MODE == 3:
             qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
         # -- compute m_ij, p, l_ij
-        m_ij = tl.max(qk, 1)
+        m_ij = tl.maximum(m_i, tl.max(qk, 1))
         p = tl.math.exp2(qk - m_ij[:, None])
         l_ij = tl.sum(p, 1)
         # -- update m_i and l_i
-        m_i_new = tl.maximum(m_i, m_ij)
-        alpha = tl.math.exp2(m_i - m_i_new)
-        beta = tl.math.exp2(m_ij - m_i_new)
-        l_i_new = l_i + beta * l_ij
-        # scale p
-        p_scale = beta / l_i_new
-        p = p * p_scale[:, None]
+        alpha = tl.math.exp2(m_i - m_ij)
+        l_i *= alpha
+        l_i_new = l_i + l_ij
         # scale acc
-        acc_scale = l_i / l_i_new
+        acc_scale = l_i * 0 + alpha
         acc = acc * acc_scale[:, None]
         # update acc
         v = tl.load(V_block_ptr)
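This hunk is the core of the v2 rewrite: the running maximum is folded directly into `m_ij`, `p` is left unnormalized, and a single factor `alpha` rescales both `l_i` and `acc`, with the division by `l_i` deferred to one step after the loop (added in the next hunk). `acc_scale = l_i * 0 + alpha` is numerically just `alpha` broadcast to `l_i`'s shape; it reads like a codegen workaround rather than math. A NumPy sketch of the update under assumed shapes, using base-2 exponentials as the kernel does, checked against a direct softmax:

```python
import numpy as np

rng = np.random.default_rng(0)
n_blocks, block_n, d = 4, 8, 16  # assumed sizes, illustrative only
qk_blocks = [rng.standard_normal((3, block_n)) for _ in range(n_blocks)]  # pre-scaled logits
v_blocks = [rng.standard_normal((block_n, d)) for _ in range(n_blocks)]

m_i = np.full(3, -np.inf)   # running max
l_i = np.zeros(3)           # running (unnormalized) denominator
acc = np.zeros((3, d))      # unnormalized accumulator
for qk, v in zip(qk_blocks, v_blocks):
    m_ij = np.maximum(m_i, qk.max(axis=1))   # running max folded in, as above
    p = np.exp2(qk - m_ij[:, None])          # p stays unnormalized
    alpha = np.exp2(m_i - m_ij)              # one rescale factor
    l_i = l_i * alpha + p.sum(axis=1)
    acc = acc * alpha[:, None] + p @ v
    m_i = m_ij
out = acc / l_i[:, None]                     # the deferred normalization

# Reference: base-2 softmax over the concatenated blocks.
qk_full = np.concatenate(qk_blocks, axis=1)
p_full = np.exp2(qk_full - qk_full.max(axis=1, keepdims=True))
ref = (p_full / p_full.sum(axis=1, keepdims=True)) @ np.concatenate(v_blocks, axis=0)
np.testing.assert_allclose(out, ref, rtol=1e-9)
```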
@@ -132,11 +131,12 @@ def _fwd_kernel(
         acc += tl.dot(p, v)
         # update m_i and l_i
         l_i = l_i_new
-        m_i = m_i_new
+        m_i = m_ij
         # update pointers
         K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
         V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
     # write back l and m
+    acc = acc / l_i[:, None]
     l_ptrs = L + off_hz * N_CTX + offs_m
     m_ptrs = M + off_hz * N_CTX + offs_m
     tl.store(l_ptrs, l_i)
@@ -266,13 +266,14 @@ class _attention(torch.autograd.Function):

     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
-        BLOCK = 128
         # shape constraints
         Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
         assert Lq == Lk and Lk == Lv
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
-        grid = (triton.cdiv(q.shape[2], 128), q.shape[0] * q.shape[1], 1)
+        BLOCK_M = 128
+        BLOCK_N = 64
+        grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
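The single `BLOCK = 128` becomes the asymmetric tile pair `BLOCK_M = 128`, `BLOCK_N = 64`, and the grid now derives from `BLOCK_M`: one program per tile of `BLOCK_M` query rows per (batch, head) pair. A quick illustration with assumed shapes (the values below are hypothetical, not from the patch):

```python
import math

# q is [BATCH, N_HEADS, N_CTX, D_HEAD]; these sizes are assumptions.
BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
BLOCK_M = 128  # query rows per program, as set above

# Mirrors grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)
grid = (math.ceil(N_CTX / BLOCK_M), BATCH * N_HEADS, 1)
assert grid == (32, 192, 1)  # 32 query tiles x 192 (batch, head) pairs
```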
@@ -291,10 +292,10 @@ class _attention(torch.autograd.Function):
             v.stride(0), v.stride(1), v.stride(2), v.stride(3),
             o.stride(0), o.stride(1), o.stride(2), o.stride(3),
             q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=128, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk,
+            BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk,
             MODE=mode,
             num_warps=num_warps,
-            num_stages=2)
+            num_stages=4)

         ctx.save_for_backward(q, k, v, o, L, m)
         ctx.grid = grid
@@ -396,7 +397,7 @@ configs = [triton.testing.Benchmark(
     ylabel='ms',
     plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
     args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode, 'causal': causal}
-) for mode in ['fwd', 'bwd'] for causal in [False, True]]
+) for mode in ['fwd'] for causal in [False]]


 @triton.testing.perf_report(configs)
@@ -37,8 +37,8 @@
 // CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
 // CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
 // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
-// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
@@ -110,8 +110,8 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
 // CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
 // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
 // CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
-// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
 // CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 2 : i32}
@@ -179,8 +179,8 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
 // CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
 // CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
 // CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
-// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
-// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remui %[[PIPELINE_IDX]], %[[CONSTANT_3]]
+// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remui %[[LOOP_IDX]], %[[CONSTANT_3]]
 // CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
 // CHECK: triton_gpu.async_wait {num = 1 : i32}
 // CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]