[BACKEND] Improve decision of MMA dimension on H100 (#2373)

When there is a chain of MMA ops we want to pick the same shape to avoid
conversions. This improves the detection by also traversing through for loops.
This fixes a crash in the backward-attention tutorial.

We might want to change this logic and convert the format to allow more
efficient MMA at some point.
This commit is contained in:
Thomas Raoux
2023-09-22 15:21:56 -07:00
committed by GitHub
parent 1724604bd9
commit 840e7e7b53
4 changed files with 106 additions and 14 deletions

View File

@@ -141,6 +141,16 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);
// Implement backward and forward slice that will go through scf blocks when
// yield or scf results are in the slice.
// Note that like the existing forward and backward slice this may add operations to
// the slice that are not actually dependent on the root because when a region
// is added to the slice in the forward slice all the operations of the region
// are added. We could implement a more accurate slice method by tracking value
// usage across scf regions.
void getBackwardSliceSCFAware(Operation *, SetVector<Operation *> *slices);
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

View File

@@ -102,8 +102,7 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
class BlockedToMMA : public mlir::RewritePattern {
int computeCapability;
mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
mutable llvm::SmallVector<llvm::SetVector<Operation *>> dotOpSetVector;
mutable llvm::SmallVector<unsigned> mmaV3InstrNs;
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;
static bool bwdFilter(Operation *op) {
return op->getNumOperands() == 1 &&
@@ -150,36 +149,36 @@ public:
auto type = dotOp.getResult().getType().cast<RankedTensorType>();
if (type.getEncoding().isa<MmaEncodingAttr>())
return currN;
for (size_t i = 0; i < dotOpSetVector.size(); ++i) {
if (dotOpSetVector[i].count(dotOp.getOperation()) > 0)
return mmaV3InstrNs[i];
}
auto it = dotOpInstNs.find(dotOp.getOperation());
if (it != dotOpInstNs.end())
return it->second;
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
mlir::getBackwardSlice(dotOp.getOperation(), &slices);
mlir::getForwardSliceSCFAware(dotOp.getResult(), &slices);
mlir::getBackwardSliceSCFAware(dotOp.getOperation(), &slices);
unsigned N = currN;
llvm::SetVector<Operation *> dotOpSet;
SmallVector<Operation *> dotOps;
for (Operation *iter : slices) {
if (auto nextDotOp = dyn_cast<tt::DotOp>(iter)) {
auto type = nextDotOp.getResult().getType().cast<RankedTensorType>();
auto AType = nextDotOp.getOperand(0).getType().cast<RankedTensorType>();
auto shapePerCTA = ttg::getShapePerCTA(type);
auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType);
dotOpSet.insert(iter);
dotOps.push_back(iter);
if (instrShape[1] < N)
N = instrShape[1];
}
}
mmaV3InstrNs.push_back(N);
dotOpSetVector.push_back(dotOpSet);
for (Operation *dotOp : dotOps)
dotOpInstNs[dotOp] = N;
return N;
}
static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter,
int opIdx) {
auto cvtOp = dyn_cast_or_null<ttg::ConvertLayoutOp>(v.getDefiningOp());
auto arg = cvtOp.getSrc();
Value arg = v;
if (auto cvtOp = v.getDefiningOp<ttg::ConvertLayoutOp>())
arg = cvtOp.getSrc();
auto argType = arg.getType().cast<RankedTensorType>();
auto eltType = argType.getElementType();
assert(argType.getEncoding() && "unexpected tensor type");

View File

@@ -492,6 +492,45 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef<Value> multiDim,
return linear;
}
// Backward slice that additionally follows scf.for loop-carried values: when
// an scf.for op lands in the slice, tracing continues from the loop body's
// terminator (the yield), so ops producing the yielded values are included.
// Like the plain backward slice, this can over-approximate (see header note).
void getBackwardSliceSCFAware(Operation *op, SetVector<Operation *> *slices) {
  SmallVector<Operation *> worklist = {op};
  // Skip ops already collected so we do not re-traverse the same subgraph.
  auto notYetSliced = [slices](Operation *sliceOp) {
    return slices->count(sliceOp) == 0;
  };
  while (!worklist.empty()) {
    Operation *root = worklist.pop_back_val();
    SetVector<Operation *> partialSlice;
    mlir::getBackwardSlice(root, &partialSlice, notYetSliced);
    for (Operation *sliceOp : partialSlice)
      if (auto forOp = dyn_cast<scf::ForOp>(sliceOp))
        worklist.push_back(forOp.getBody()->getTerminator());
    slices->insert(partialSlice.begin(), partialSlice.end());
  }
}
// Forward slice that additionally follows scf.for loop-carried values: when
// the slice reaches an scf.yield enclosed in an scf.for, the loop's results
// are queued so that uses after the loop are included in the slice.
// Like the plain forward slice, this can over-approximate (see header note).
void getForwardSliceSCFAware(Value root, SetVector<Operation *> *slices) {
  SmallVector<Value> queue = {root};
  while (!queue.empty()) {
    Value currentValue = queue.back();
    queue.pop_back();
    SetVector<Operation *> temp;
    // Skip ops already collected so we do not re-traverse the same subgraph.
    auto filter = [slices](Operation *sliceOp) {
      return slices->count(sliceOp) == 0;
    };
    mlir::getForwardSlice(currentValue, &temp, filter);
    for (Operation *sliceOp : temp) {
      if (auto yieldOp = dyn_cast<scf::YieldOp>(sliceOp)) {
        // scf.yield also terminates scf.if/scf.while regions, in which case
        // there may be no enclosing scf.for; guard against a null parent
        // before dereferencing to avoid a crash.
        if (auto forOp = yieldOp->getParentOfType<scf::ForOp>())
          queue.append(forOp->getResults().begin(),
                       forOp->getResults().end());
      }
    }
    slices->insert(temp.begin(), temp.end());
  }
}
namespace {
/// Detect dead arguments in scf.for op by assuming all the values are dead and

View File

@@ -0,0 +1,44 @@
// RUN: ENABLE_MMA_V3=1 triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s
// Regression test: chained tt.dot ops -- including chains whose dependency
// crosses scf.for region boundaries via iter_args/yield -- must all receive
// the same MMAv3 encoding so no layout conversions are needed between them.
// The smallest instruction-shape N across the chain is chosen; the 128x16
// dot constrains N, giving instrShape = [16, 16, 16] for every dot below.
// CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
// CHECK: mma_chain_loop
tt.func public @mma_chain_loop(
%170: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
%171: tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
%179: tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>,
%164: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>,
%165: tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>,
%173: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>,
%153: tensor<128x64x!tt.ptr<f16, 1>, #blocked1>) {
%c0_i32 = arith.constant 0 : i32
%c8_i32 = arith.constant 8 : i32
%c1_i32 = arith.constant 1 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2>
// First loop: the second dot consumes the first dot's result via a
// convert_layout; both must still end up on the shared #[[MMA]] encoding.
// CHECK: scf.for
// CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
%115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 {
%172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked>
%178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
%180 = tt.dot %178, %179, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1>
scf.yield %180 : tensor<128x64xf16, #blocked1>
}
// Second loop: its iter_args seed is the first loop's result, so the chain
// spans two scf.for regions; the SCF-aware slice must still connect them.
// CHECK: scf.for
// CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA]]>
// CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA]]>
%149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 {
%166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2>
%172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
%174 = tt.dot %172, %173, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1>
scf.yield %174 : tensor<128x64xf16, #blocked1>
}
tt.store %153, %149 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1>
tt.return
}
}