[OPTIMIZATION] Enable pipelining for bwd flash attention (#2590)

This allows pipelining when a load is used by multiple dots in a loop.

Relax the condition to pipeline dot operands for the MMA v3 case. This
improves performance for the bwd pass from 260 TF to 275 TF. However, this
exposes a performance problem due to the wmma pipelining, as ptxas will
now fall back to serial wgmma. A follow-up PR will fix a bug in how we
emit wgmma_wait during pipelining and will bring performance to 335 TF.
This commit is contained in:
Thomas Raoux
2023-11-03 11:46:51 -07:00
committed by GitHub
parent df08301e76
commit 6ac9d51ff0
4 changed files with 86 additions and 24 deletions

View File

@@ -157,8 +157,29 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
return Value();
Operation *use = *loadOp.getResult().getUsers().begin();
Operation *preUse = nullptr;
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
auto tensorType =
convertLayout.getResult().getType().cast<RankedTensorType>();
if (auto sharedEnc =
tensorType.getEncoding().dyn_cast<ttg::SharedEncodingAttr>()) {
if (sharedEnc.getHasLeadingOffset()) {
// MMA V3 case.
auto newOrder = sharedEnc.getOrder();
auto ty = loadOp.getType().cast<RankedTensorType>();
auto oldOrder = ttg::getOrder(ty.getEncoding());
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
// The operand of MMAv3 is in SharedEncoding and it's order should
// not be changed after FuseTranspositions Pass. So we only pipeline
// the load if the order of the loaded BlockedEncoding is the same
// as the order of the SharedEncoding it is converted to.
// TODO: remove this constraint once the LoadOp supports transpose
// fusion
hasMMAV3 = true;
return convertLayout.getResult();
}
}
}
}
// Advance to the first conversion as long as the use resides in shared
// memory and it has a single use itself
while (use) {
@@ -167,7 +188,6 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
auto tensorType = use->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
break;
preUse = use;
use = *use->getResult(0).getUsers().begin();
}
@@ -179,27 +199,6 @@ static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
return convertLayout.getResult();
}
}
} else if (preUse && isa<tt::DotOp>(use)) {
// for MMAv3 whose dot take SharedEncoding as operands directly
Operation *post = *loadOp.getResult().getUsers().begin();
auto newOrder = post->getResult(0)
.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<ttg::SharedEncodingAttr>()
.getOrder();
auto ty = loadOp.getType().cast<RankedTensorType>();
auto oldOrder = ttg::getOrder(ty.getEncoding());
// The operand of MMAv3 is in SharedEncoding and it's order should not
// be changed after FuseTranspositions Pass. So we only pipeline the
// load if the order of the loaded BlockedEncoding is the same as the
// order of the SharedEncoding it is converted to.
// TODO: remove this constraint once the LoadOp supports transpose
// fusion
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
hasMMAV3 = true;
return preUse->getResult(0);
}
}
return Value();
}

View File

@@ -111,6 +111,7 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, e
if optimize_epilogue:
pm.add_tritongpu_optimize_epilogue_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_cse_pass()
ws_enabled = False
# `num_warps` does not mean the total number of warps of a CTA when
# warp specialization is enabled.

View File

@@ -489,6 +489,8 @@ class _attention(torch.autograd.Function):
BATCH, N_HEAD, N_CTX = q.shape[:3]
PRE_BLOCK = 128
NUM_WARPS, NUM_STAGES = 4, 1
if torch.cuda.get_device_capability()[0] == 9:
NUM_STAGES = 5
BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
BLK_SLICE_FACTOR = 2
RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)

View File

@@ -319,3 +319,63 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
// tt.return %res#0, %res#1, %res#2 : !tt.ptr<tensor<128x32xf16>, 1>, !tt.ptr<tensor<32x128xf16>, 1>, tensor<128x128xf32, #C>
// }
//}
// -----
// Layout/encoding aliases used by the pipelining test below. The two shared
// encodings have hasLeadingOffset = true, i.e. they are MMAv3 (Hopper wgmma)
// operand layouts; #shared and #shared1 differ only in their order.
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
// Regression test for this commit: a single in-loop load (%18) feeds two
// chained dots, so the relaxed pipelining condition must still pipeline it.
// Compute capability 90 selects the Hopper (MMAv3) code path.
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK-LABEL: dot_chained_single_load
tt.func @dot_chained_single_load(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) -> tensor<128x64xf32, #mma> {
%cst = arith.constant dense<0> : tensor<64x16xi32, #blocked>
%c0_i32 = arith.constant 0 : i32
%cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
%cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
%c0_i64 = arith.constant 0 : i64
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
%c1_i32 = arith.constant 1 : i32
%c8_i32 = arith.constant 8 : i32
// Build the pointer tensor for the loop-invariant 128x64 operand and load it
// once, outside the loop (%9).
%0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16, 1>, i64
%1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16, 1>, i64
%2 = tt.splat %1 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked1>
%3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16, 1>, #blocked1>, tensor<128x1xi32, #blocked1>
%4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
%5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
%6 = tt.broadcast %3 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked1>
%7 = tt.broadcast %5 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
%8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16, 1>, #blocked1>, tensor<128x64xi32, #blocked1>
%9 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
// Pointer tensor for the 64x16 operand that is re-loaded (and advanced via
// %26) on every loop iteration — this is the load that must be pipelined.
%10 = tt.splat %0 : (!tt.ptr<f16, 1>) -> tensor<1x16x!tt.ptr<f16, 1>, #blocked>
%11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16, 1>, #blocked>, tensor<1x16xi32, #blocked>
%12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
%13 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
%14 = tt.broadcast %11 : (tensor<1x16x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x16x!tt.ptr<f16, 1>, #blocked>
%15 = tt.broadcast %13 : (tensor<64x1xi32, #blocked>) -> tensor<64x16xi32, #blocked>
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
// FileCheck expectations: inside the pipelined loop the async-copy machinery
// (insert_slice_async / async_commit_group / async_wait) is emitted, and the
// second dot is lowered to an asynchronous MMAv3 dot (dot_async).
// CHECK: scf.for
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: tt.dot
// CHECK: triton_nvidia_gpu.dot_async
// CHECK: triton_gpu.insert_slice_async
// CHECK: triton_gpu.async_commit_group
// CHECK: scf.yield
%17:2 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>) : i32 {
// Single in-loop load; its shared-encoding conversion (%20) is consumed by
// the first dot directly and by the second dot through the transpose (%23).
%18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked>
%19 = triton_gpu.convert_layout %9 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared>
%20 = triton_gpu.convert_layout %18 : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1>
%21 = tt.dot %19, %20, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1>
%22 = arith.truncf %21 : tensor<128x16xf32, #mma1> to tensor<128x16xf16, #mma1>
%23 = tt.trans %20 : (tensor<64x16xf16, #shared1>) -> tensor<16x64xf16, #shared>
%24 = triton_gpu.convert_layout %22 : (tensor<128x16xf16, #mma1>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
// Second dot chains on the first dot's (truncated) result and on the
// transposed view of the same loaded tile.
%25 = tt.dot %24, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<16x64xf16, #shared> -> tensor<128x64xf32, #mma>
%26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
scf.yield %25, %26 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>
}
tt.return %17#0 : tensor<128x64xf32, #mma>
}
}