ROCm/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Philippe Tillet 52c146f66b [OPTIMIZER][BACKEND] significantly cleaner handling of mixed-precision kernels (#1949)
We currently have a very janky approach to optimizing mixed-precision
matmul workloads: some layout combinations (e.g., NT matmul) are
explicitly pattern-matched to take a more optimized codepath. A previous
attempt at unifying all the codepaths to codegen cp.async failed due to
bugs in SharedToDotOperandMMAv2.cpp.

This PR fixes said bugs, adds assertions for SharedToDotOperandMMAv2
modes that aren't well supported, and greatly simplifies our handling of
element-wise operations between loads and conversions to DotOperand.
2023-07-28 10:29:42 -07:00

237 lines · 8.3 KiB · C++

#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include <memory>

using namespace mlir;

namespace {
using triton::DotOp;
using triton::gpu::BlockedEncodingAttr;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;
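
// Map a CUDA compute capability (SM major.minor as an integer, e.g. 80 for
// sm_80) to the MMA instruction generation used for codegen: 0 = no tensor
// cores (pre-Volta), 1 = MMAv1 (sm_70..sm_75), 2 = MMAv2 (sm_80 and, per the
// FIXME below, currently also sm_90).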
int computeCapabilityToMMAVersion(int computeCapability) {
  if (computeCapability < 70) {
    return 0;
  } else if (computeCapability < 80) {
    return 1;
  } else if (computeCapability < 90) {
    return 2;
  } else if (computeCapability < 100) {
    // FIXME: temporarily return MMAv2 here to pass unit tests
    return 2;
  } else {
    assert(false && "computeCapability >= 100 not supported");
    return 3;
  }
}
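
// Output tile shape (M, N) computed per warp for the given MMA generation.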
SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  if (version == 1)
    return {16, 16};
  else if (version == 2)
    return {16, 8};
  else {
    assert(false && "version not supported");
    return {0, 0};
  }
}
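
// Split numWarps warps across the M and N dimensions of the dot's output.
// The loop doubles the warp count along whichever dimension still has more
// (16x8) MMA tiles left to cover; if another tt.dot shares the region, all
// warps are stacked along M instead.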
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
  auto filter = [&dotOp](Operation *op) {
    return op->getParentRegion() == dotOp->getParentRegion();
  };
  auto slices = mlir::getSlice(dotOp, {filter});
  for (Operation *op : slices)
    if (isa<triton::DotOp>(op) && (op != dotOp))
      return {(unsigned)numWarps, 1};

  SmallVector<unsigned, 2> ret = {1, 1};
  SmallVector<int64_t, 2> shapePerWarp = {16, 8};
  bool changed = false;
  do {
    changed = false;
    if (ret[0] * ret[1] >= numWarps)
      break;
    if (shape[0] / shapePerWarp[0] / ret[0] >=
        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
      if (ret[0] < shape[0] / shapePerWarp[0]) {
        ret[0] *= 2;
      } else
        ret[1] *= 2;
    } else {
      ret[1] *= 2;
    }
  } while (true);
  return ret;
}
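
// Rewrite a tt.dot whose result carries a blocked encoding into a dot with
// an MMA encoding: the accumulator and both operands are wrapped in
// convert_layout ops, and the result is converted back to the original
// layout.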
class BlockedToMMA : public mlir::RewritePattern {
  int computeCapability;
  mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding
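
  // True for single-operand, shape-preserving ops (fp/bit casts, layout
  // conversions, and arith elementwise ops) that the backward bitwidth
  // search below may safely look through.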
  static bool bwdFilter(Operation *op) {
    return op->getNumOperands() == 1 &&
           (isa<triton::FpToFpOp, triton::BitcastOp,
                triton::gpu::ConvertLayoutOp>(op) ||
            op->getDialect()->getTypeID() ==
                mlir::TypeID::get<arith::ArithDialect>());
  }

  // finds the first different bitwidth in the chain of shape-preserving
  // unary ops that x depends on
  static int computeOrigBitWidth(Value x) {
    int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
    int origBitWidth = finalBitWidth;
    SetVector<Operation *> slice;
    mlir::getBackwardSlice(x, &slice, bwdFilter);
    Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
    if (firstOp)
      if (Value arg = firstOp->getOperand(0))
        if (RankedTensorType argTy = arg.getType().dyn_cast<RankedTensorType>())
          origBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
    return origBitWidth;
  }

public:
  BlockedToMMA(mlir::MLIRContext *context, int computeCapability)
      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
        computeCapability(computeCapability) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    if (computeCapability < 70)
      return failure();
    auto dotOp = cast<triton::DotOp>(op);
    auto ctx = op->getContext();
    // TODO: Check data-types and SM compatibility
    auto oldRetType = dotOp.getResult().getType().cast<RankedTensorType>();
    if (!oldRetType.getEncoding() ||
        oldRetType.getEncoding().isa<triton::gpu::MmaEncodingAttr>())
      return failure();
    // for FMA, should retain the blocked layout.
    int versionMajor = computeCapabilityToMMAVersion(computeCapability);
    if (!supportMMA(dotOp, versionMajor))
      return failure();
    // get MMA encoding for the given number of warps
    auto retShape = oldRetType.getShape();
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
    // operands
    Value a = dotOp.getA();
    Value b = dotOp.getB();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    triton::gpu::MmaEncodingAttr mmaEnc;
    if (versionMajor == 1) {
      SetVector<Operation *> aBwdSlices, bBwdSlices;
      auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
      getBackwardSlice(a, &aBwdSlices, {isCvt});
      getBackwardSlice(b, &bBwdSlices, {isCvt});
      // get the source of the first conversion found in slices
      auto getCvtArgOrder = [](Operation *op) {
        return cast<ConvertLayoutOp>(op)
            .getOperand()
            .getType()
            .cast<RankedTensorType>()
            .getEncoding()
            .cast<BlockedEncodingAttr>()
            .getOrder();
      };
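      // An operand is treated as row-major when the fastest-varying axis of
      // the blocked layout feeding its conversion is dimension 1; MMAv1
      // bakes the operand orders into the encoding.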
      bool isARow = true;
      bool isBRow = true;
      Operation *aOp = a.getDefiningOp();
      Operation *bOp = b.getDefiningOp();
      if (!aBwdSlices.empty())
        aOp = aBwdSlices[0];
      if (!bBwdSlices.empty())
        bOp = bBwdSlices[0];
      if (aOp)
        isARow = getCvtArgOrder(aOp)[0] == 1;
      if (bOp)
        isBRow = getCvtArgOrder(bOp)[0] == 1;
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, numWarps, oldAType.getShape(),
          oldBType.getShape(), retShape, isARow, isBRow, mmaV1Counter++);
    } else if (versionMajor == 2) {
      auto warpsPerTile = warpsPerTileV2(dotOp, retShape, numWarps);
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
          warpsPerTile);
    } else {
      llvm_unreachable("Mma layout only supports versionMajor in {1, 2}");
    }
    auto newRetType =
        RankedTensorType::get(retShape, oldRetType.getElementType(), mmaEnc);

    // convert accumulator
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        oldAcc.getLoc(), newRetType, oldAcc);

    // convert operands
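    // The narrowest bitwidth found upstream of either operand is recorded in
    // the DotOperand encodings below, so that later lowering can recognize
    // mixed-precision dots and size its shared-memory accesses accordingly.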
    int minBitwidth = std::min(computeOrigBitWidth(a), computeOrigBitWidth(b));
    Type minType = IntegerType::get(ctx, minBitwidth);
    // convert A operand
    auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
        oldAType.getContext(), 0, newRetType.getEncoding(),
        minBitwidth > 0 ? minType : oldAType.getElementType());
    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(), newAEncoding);
    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    // convert B operand
    auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
        oldBType.getContext(), 1, newRetType.getEncoding(),
        minBitwidth > 0 ? minType : oldBType.getElementType());
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(), newBEncoding);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);

    // convert dot instruction
    auto newDot = rewriter.create<triton::DotOp>(
        dotOp.getLoc(), newRetType, a, b, newAcc, dotOp.getAllowTF32());
    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
        op, oldRetType, newDot.getResult());
    return success();
  }
};
} // namespace
};
} // namespace
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
class TritonGPUAccelerateMatmulPass
    : public TritonGPUAccelerateMatmulBase<TritonGPUAccelerateMatmulPass> {
public:
  TritonGPUAccelerateMatmulPass() = default;
  TritonGPUAccelerateMatmulPass(int computeCapability) {
    this->computeCapability = computeCapability;
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
    mlir::RewritePatternSet patterns(context);
    patterns.add<::BlockedToMMA>(context, computeCapability);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
      signalPassFailure();
    }
  }
};
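
// Factory used when assembling the TritonGPU pass pipeline; hypothetical
// usage from a pass manager:
//   pm.addPass(mlir::createTritonGPUAccelerateMatmulPass(computeCapability));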
std::unique_ptr<Pass>
mlir::createTritonGPUAccelerateMatmulPass(int computeCapability) {
  return std::make_unique<TritonGPUAccelerateMatmulPass>(computeCapability);
}