mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZATION] Enable pipelining for bwd flash attention (#2590)
This allows pipelining when a load is used by multiple dots in a loop, and relaxes the condition for pipelining dot operands in the MMA v3 case. This improves performance for the bwd pass from 260TF to 275TF. However, this exposes a performance problem due to the wgmma pipelining, as ptxas will now fall back to serial wgmma. A follow-up PR will fix a bug in how we emit wgmma_wait during pipelining and will bring performance to 335TF.
This commit is contained in:
@@ -157,8 +157,29 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
|
||||
return Value();
|
||||
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
Operation *preUse = nullptr;
|
||||
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().cast<RankedTensorType>();
|
||||
if (auto sharedEnc =
|
||||
tensorType.getEncoding().dyn_cast<ttg::SharedEncodingAttr>()) {
|
||||
if (sharedEnc.getHasLeadingOffset()) {
|
||||
// MMA V3 case.
|
||||
auto newOrder = sharedEnc.getOrder();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
auto oldOrder = ttg::getOrder(ty.getEncoding());
|
||||
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
|
||||
// The operand of MMAv3 is in SharedEncoding and it's order should
|
||||
// not be changed after FuseTranspositions Pass. So we only pipeline
|
||||
// the load if the order of the loaded BlockedEncoding is the same
|
||||
// as the order of the SharedEncoding it is converted to.
|
||||
// TODO: remove this constraint once the LoadOp supports transpose
|
||||
// fusion
|
||||
hasMMAV3 = true;
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Advance to the first conversion as long as the use resides in shared
|
||||
// memory and it has a single use itself
|
||||
while (use) {
|
||||
@@ -167,7 +188,6 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
|
||||
auto tensorType = use->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
|
||||
break;
|
||||
preUse = use;
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
@@ -179,27 +199,6 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
} else if (preUse && isa<tt::DotOp>(use)) {
|
||||
// for MMAv3 whose dot take SharedEncoding as operands directly
|
||||
Operation *post = *loadOp.getResult().getUsers().begin();
|
||||
auto newOrder = post->getResult(0)
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<ttg::SharedEncodingAttr>()
|
||||
.getOrder();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
auto oldOrder = ttg::getOrder(ty.getEncoding());
|
||||
// The operand of MMAv3 is in SharedEncoding and it's order should not
|
||||
// be changed after FuseTranspositions Pass. So we only pipeline the
|
||||
// load if the order of the loaded BlockedEncoding is the same as the
|
||||
// order of the SharedEncoding it is converted to.
|
||||
// TODO: remove this constraint once the LoadOp supports transpose
|
||||
// fusion
|
||||
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
|
||||
hasMMAV3 = true;
|
||||
return preUse->getResult(0);
|
||||
}
|
||||
}
|
||||
return Value();
|
||||
}
|
||||
|
||||
@@ -111,6 +111,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
|
||||
if optimize_epilogue:
|
||||
pm.add_tritongpu_optimize_epilogue_pass()
|
||||
pm.add_tritongpu_optimize_dot_operands_pass()
|
||||
pm.add_cse_pass()
|
||||
ws_enabled = False
|
||||
# `num_warps` does not mean the total number of warps of a CTA when
|
||||
# warp specialization is enabled.
|
||||
|
||||
@@ -489,6 +489,8 @@ class _attention(torch.autograd.Function):
|
||||
BATCH, N_HEAD, N_CTX = q.shape[:3]
|
||||
PRE_BLOCK = 128
|
||||
NUM_WARPS, NUM_STAGES = 4, 1
|
||||
if torch.cuda.get_device_capability()[0] == 9:
|
||||
NUM_STAGES = 5
|
||||
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
|
||||
BLK_SLICE_FACTOR = 2
|
||||
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
|
||||
|
||||
@@ -319,3 +319,63 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
|
||||
// tt.return %res#0, %res#1, %res#2 : !tt.ptr<tensor<128x32xf16>, 1>, !tt.ptr<tensor<32x128xf16>, 1>, tensor<128x128xf32, #C>
|
||||
// }
|
||||
//}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: dot_chained_single_load
|
||||
tt.func @dot_chained_single_load(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> {
|
||||
%cst = arith.constant dense<0> : tensor<64x16xi32, #blocked>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
|
||||
%cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
|
||||
%c0_i64 = arith.constant 0 : i64
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
|
||||
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16, 1>, i64
|
||||
%1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16, 1>, i64
|
||||
%2 = tt.splat %1 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked1>
|
||||
%3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16, 1>, #blocked1>, tensor<128x1xi32, #blocked1>
|
||||
%4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%6 = tt.broadcast %3 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked1>
|
||||
%7 = tt.broadcast %5 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
|
||||
%8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16, 1>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||
%9 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||
%10 = tt.splat %0 : (!tt.ptr<f16, 1>) -> tensor<1x16x!tt.ptr<f16, 1>, #blocked>
|
||||
%11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16, 1>, #blocked>, tensor<1x16xi32, #blocked>
|
||||
%12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%13 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
|
||||
%14 = tt.broadcast %11 : (tensor<1x16x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x16x!tt.ptr<f16, 1>, #blocked>
|
||||
%15 = tt.broadcast %13 : (tensor<64x1xi32, #blocked>) -> tensor<64x16xi32, #blocked>
|
||||
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
|
||||
// CHECK: scf.for
|
||||
// CHECK: triton_gpu.async_wait {num = 1 : i32}
|
||||
// CHECK: tt.dot
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_gpu.insert_slice_async
|
||||
// CHECK: triton_gpu.async_commit_group
|
||||
// CHECK: scf.yield
|
||||
%17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>) : i32 {
|
||||
%18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked>
|
||||
%19 = triton_gpu.convert_layout %9 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared>
|
||||
%20 = triton_gpu.convert_layout %18 : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1>
|
||||
%21 = tt.dot %19, %20, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1>
|
||||
%22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
|
||||
%23 = tt.trans %20 : (tensor<64x16xf16, #shared1>) -> tensor<16x64xf16, #shared>
|
||||
%24 = triton_gpu.convert_layout %22 : (tensor<128x16xf16, #mma1>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
||||
%25 = tt.dot %24, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<16x64xf16, #shared> -> tensor<128x64xf32, #mma>
|
||||
%26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
|
||||
scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>
|
||||
}
|
||||
tt.return %17#0 : tensor<128x64xf32, #mma>
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user