mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Improved flash attention forward pass performance (#1075)
- Fixed typo in instruction reordering pass - Minor additional optimizations for shared memory allocator - Optimized flash attention tutorial forward pass kernel
This commit is contained in:
@@ -77,6 +77,9 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
return result;
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
@@ -58,6 +58,13 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
|
||||
auto dstTy = op.result().getType().cast<RankedTensorType>();
|
||||
Attribute srcLayout = srcTy.getEncoding();
|
||||
Attribute dstLayout = dstTy.getEncoding();
|
||||
|
||||
// MmaToDotShortcut doesn't use shared mem
|
||||
if (auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>())
|
||||
if (auto dotOperandLayout = dstLayout.dyn_cast<DotOperandEncodingAttr>())
|
||||
if (isMmaToDotShortcut(mmaLayout, dotOperandLayout))
|
||||
return {};
|
||||
|
||||
assert(srcLayout && dstLayout &&
|
||||
"Unexpect layout in getScratchConfigForCvtLayout()");
|
||||
auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout);
|
||||
|
||||
@@ -48,6 +48,12 @@ SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
|
||||
auto axis = op.axis();
|
||||
SmallVector<SmallVector<unsigned>> smemShapes(3);
|
||||
|
||||
auto argLayout = srcTy.getEncoding();
|
||||
auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
|
||||
triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
|
||||
return {{1, 1}, {1, 1}};
|
||||
|
||||
/// shared memory block0
|
||||
smemShapes[0] = convertType<unsigned>(getSrcShape());
|
||||
smemShapes[0][axis] = getInterWarpSize();
|
||||
@@ -148,4 +154,14 @@ std::string getValueOperandName(Value value, AsmState &state) {
|
||||
return opName;
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
triton::gpu::DotOperandEncodingAttr &dotOperandLayout) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
return mmaLayout.getVersionMajor() == 2 &&
|
||||
mmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "DotOpHelpers.h"
|
||||
#include "Utility.h"
|
||||
|
||||
using ::mlir::LLVM::DotOpFMAConversionHelper;
|
||||
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
|
||||
@@ -17,15 +18,6 @@ using ::mlir::triton::gpu::getSizePerThread;
|
||||
using ::mlir::triton::gpu::isaDistributedLayout;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
|
||||
DotOperandEncodingAttr &dotOperandLayout) {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
struct ConvertLayoutOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
|
||||
public:
|
||||
|
||||
@@ -8,9 +8,6 @@ using namespace mlir::triton;
|
||||
|
||||
using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
|
||||
bool isMmaToDotShortcut(MmaEncodingAttr &mmaLayout,
|
||||
DotOperandEncodingAttr &dotOperandLayout);
|
||||
|
||||
void populateConvertLayoutOpToLLVMPatterns(
|
||||
mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
|
||||
@@ -330,16 +330,6 @@ public:
|
||||
return ret;
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(
|
||||
MmaEncodingAttr &mmaLayout,
|
||||
triton::gpu::DotOperandEncodingAttr &dotOperandLayout) const {
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
// when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
|
||||
return mmaLayout.getWarpsPerCTA()[1] == 1 &&
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
void storeDistributedToShared(Value src, Value llSrc,
|
||||
ArrayRef<Value> dstStrides,
|
||||
ArrayRef<SmallVector<Value>> srcIndices,
|
||||
|
||||
@@ -88,7 +88,7 @@ public:
|
||||
if (!dstEncoding)
|
||||
return;
|
||||
int opIdx = dstEncoding.getOpIdx();
|
||||
if (opIdx != 1)
|
||||
if (opIdx != 0)
|
||||
return;
|
||||
if (op->getUsers().empty())
|
||||
return;
|
||||
|
||||
@@ -15,7 +15,7 @@ import triton.language as tl
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
Q, K, V, sm_scale,
|
||||
TMP, L, M, # NOTE: TMP is a scratchpad buffer to work around a compiler bug
|
||||
L, M,
|
||||
Out,
|
||||
stride_qz, stride_qh, stride_qm, stride_qk,
|
||||
stride_kz, stride_kh, stride_kn, stride_kk,
|
||||
@@ -32,58 +32,55 @@ def _fwd_kernel(
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
off_k = off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk
|
||||
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
|
||||
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
|
||||
# Initialize pointers to Q, K, V
|
||||
q_ptrs = Q + off_q
|
||||
k_ptrs = K + off_k
|
||||
v_ptrs = V + off_v
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
|
||||
# load q: it will stay in SRAM throughout
|
||||
q = tl.load(q_ptrs)
|
||||
# loop over k, v and update accumulator
|
||||
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
|
||||
# start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
# -- compute qk ----
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
k = tl.load(k_ptrs)
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
qk += tl.dot(q, k)
|
||||
qk *= sm_scale
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf"))
|
||||
# -- compute m_ij, p, l_ij
|
||||
m_ij = tl.max(qk, 1)
|
||||
p = tl.exp(qk - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
# -- update m_i and l_i
|
||||
m_i_new = tl.maximum(m_i, m_ij)
|
||||
alpha = tl.exp(m_i - m_i_new)
|
||||
beta = tl.exp(m_ij - m_i_new)
|
||||
l_i_new = alpha * l_i + beta * l_ij
|
||||
# -- update output accumulator --
|
||||
# scale p
|
||||
p_scale = beta / l_i_new
|
||||
p = p * p_scale[:, None]
|
||||
# scale acc
|
||||
acc_scale = l_i / l_i_new * alpha
|
||||
acc = acc * acc_scale[:, None]
|
||||
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
|
||||
# compute new m
|
||||
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
|
||||
# correct old l
|
||||
l_prev *= tl.exp(m_prev - m_curr)
|
||||
# attention weights
|
||||
p = tl.exp(qk - m_curr[:, None])
|
||||
l_curr = tl.sum(p, 1) + l_prev
|
||||
# rescale operands of matmuls
|
||||
l_rcp = 1. / l_curr
|
||||
p *= l_rcp
|
||||
acc *= (l_prev * l_rcp)[:, None]
|
||||
# update acc
|
||||
v = tl.load(v_ptrs + start_n * stride_vk)
|
||||
p = p.to(tl.float16)
|
||||
v = tl.load(v_ptrs)
|
||||
acc += tl.dot(p, v)
|
||||
# update m_i and l_i
|
||||
l_i = l_i_new
|
||||
m_i = m_i_new
|
||||
l_prev = l_curr
|
||||
m_prev = m_curr
|
||||
# update pointers
|
||||
k_ptrs += BLOCK_N * stride_kn
|
||||
v_ptrs += BLOCK_N * stride_vk
|
||||
# rematerialize offsets to save registers
|
||||
start_m = tl.program_id(0)
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
# write back l and m
|
||||
l_ptrs = L + off_hz * N_CTX + offs_m
|
||||
m_ptrs = M + off_hz * N_CTX + offs_m
|
||||
tl.store(l_ptrs, l_i)
|
||||
tl.store(m_ptrs, m_i)
|
||||
tl.store(l_ptrs, l_prev)
|
||||
tl.store(m_ptrs, m_prev)
|
||||
# initialize pointers to output
|
||||
offs_n = tl.arange(0, BLOCK_DMODEL)
|
||||
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
|
||||
@@ -209,14 +206,13 @@ class _attention(torch.autograd.Function):
|
||||
assert Lk in {16, 32, 64, 128}
|
||||
o = torch.empty_like(q)
|
||||
grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)
|
||||
tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
|
||||
num_warps = 4 if Lk <= 64 else 8
|
||||
|
||||
_fwd_kernel[grid](
|
||||
q, k, v, sm_scale,
|
||||
tmp, L, m,
|
||||
L, m,
|
||||
o,
|
||||
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
|
||||
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
|
||||
@@ -316,7 +312,7 @@ BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64
|
||||
# vary seq length for fixed head and batch=4
|
||||
configs = [triton.testing.Benchmark(
|
||||
x_names=['N_CTX'],
|
||||
x_vals=[2**i for i in range(10, 15)],
|
||||
x_vals=[2**i for i in range(10, 14)],
|
||||
line_arg='provider',
|
||||
line_vals=['triton'] + (['flash'] if HAS_FLASH else []),
|
||||
line_names=['Triton'] + (['Flash'] if HAS_FLASH else []),
|
||||
@@ -324,7 +320,7 @@ configs = [triton.testing.Benchmark(
|
||||
ylabel='ms',
|
||||
plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}',
|
||||
args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode}
|
||||
) for mode in ['fwd']]
|
||||
) for mode in ['fwd', 'bwd']]
|
||||
|
||||
|
||||
@triton.testing.perf_report(configs)
|
||||
@@ -357,4 +353,6 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f
|
||||
ms = triton.testing.do_bench(fn, percentiles=None, warmup=warmup, rep=rep)
|
||||
return ms
|
||||
|
||||
# bench_flash_attention.run(save_path='.', print_data=True)
|
||||
|
||||
# only works on post-Ampere GPUs right now
|
||||
bench_flash_attention.run(save_path='.', print_data=True)
|
||||
|
||||
@@ -6,4 +6,5 @@ add_mlir_library(TritonTestAnalysis
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
TritonAnalysis
|
||||
${dialect_libs}
|
||||
)
|
||||
Reference in New Issue
Block a user