[BACKEND] Minor removal of unnecessary code and cleanup (#2443)

Thomas Raoux
2023-10-04 12:14:08 -07:00
committed by GitHub
parent 71a8544ce7
commit 5a0170a27c
5 changed files with 5 additions and 57 deletions


@@ -129,9 +129,9 @@ bool supportMMA(Value value, int version);
bool isSingleValue(Value value);
-bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
-bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
+bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
+bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
// Return true if the src and dst layout match.
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
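A note on the signature change: RankedTensorType, like every mlir::Type, is a small value-semantics handle over uniqued, immutable storage, so passing it by value copies a single pointer and, unlike the non-const reference it replaces, also accepts temporaries. A minimal sketch of the idiom, with a hypothetical helper that is not taken from this diff:

    #include "mlir/IR/BuiltinTypes.h"

    // Hypothetical helper: taking RankedTensorType by value copies only a
    // pointer-sized handle; `RankedTensorType &` adds nothing and rejects
    // rvalues such as the result of a cast.
    static bool sameEncoding(mlir::RankedTensorType srcTy,
                             mlir::RankedTensorType dstTy) {
      return srcTy.getEncoding() == dstTy.getEncoding();
    }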


@@ -437,7 +437,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
dst.getWarpsPerCTA()[1] == 1;
}
-bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
}
@@ -453,7 +453,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
srcTy.getElementType().isF16();
}
-bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
+bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
return true;
// dot_op<opIdx=0, parent=#mma> = #mma
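These shortcut predicates report when a layout conversion needs no data movement because source and destination place the same values in the same registers. A hedged sketch of a caller during ConvertLayoutOp lowering; the surrounding pattern, `op`, `rewriter`, and the `adaptor.getSrc()` accessor are assumptions, not part of this diff:

    // Inside a hypothetical ConvertLayoutOp lowering pattern.
    if (isMmaToDotShortcut(srcTy, dstTy)) {
      // Layouts match register-for-register: forward the operand
      // instead of materializing a conversion through shared memory.
      rewriter.replaceOp(op, adaptor.getSrc());
      return mlir::success();
    }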


@@ -31,58 +31,7 @@ using triton::gpu::SliceEncodingAttr;
//
// -----------------------------------------------------------------------------
-// convert(blocked, dot_operand) ->
-//   convert(blocked, mma) + convert(mma, dot_operand)
-// if this value is itself the result of a dot operation
-// this is a heuristic to accommodate some pattern seen in fused attention
-// kernels.
-// TODO: replace this by something more generic, i.e. layout-aware CSE
-class DecomposeDotOperand : public mlir::RewritePattern {
-public:
-  explicit DecomposeDotOperand(mlir::MLIRContext *context)
-      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
-                             1, context) {}
-
-  mlir::LogicalResult
-  matchAndRewrite(mlir::Operation *op,
-                  mlir::PatternRewriter &rewriter) const override {
-    if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
-      return mlir::failure();
-    auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
-    auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
-    auto dstType = convert.getType().cast<RankedTensorType>();
-    if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
-        dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
-      auto dstDotOperand =
-          dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
-      auto dstParent = dstDotOperand.getParent();
-      if (dstDotOperand.getOpIdx() == 1 ||
-          !dstParent.isa<triton::gpu::MmaEncodingAttr>())
-        return mlir::failure();
-      auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
-      if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
-        return mlir::failure();
-      SetVector<Operation *> bwdSlices;
-      mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
-      if (llvm::find_if(bwdSlices, [](Operation *op) {
-            return isa<triton::DotOp>(op);
-          }) == bwdSlices.end())
-        return mlir::failure();
-
-      auto tmpType = RankedTensorType::get(
-          dstType.getShape(), dstType.getElementType(), dstParentMma);
-      auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
-          convert.getLoc(), tmpType, convert.getOperand());
-      rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, dstType,
-                                                                tmp);
-      return mlir::success();
-    }
-    return mlir::failure();
-  }
-};
//
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
class ConvertDotConvert : public mlir::RewritePattern {
public:
ConvertDotConvert(mlir::MLIRContext *context)
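For reference, the removed DecomposeDotOperand pattern above rewrote a single blocked-to-dot-operand conversion into two conversions staged through the parent MMA layout, so that when the value originally came out of a dot in #mma layout, the blocked-to-#mma hop could cancel with the earlier conversion. A schematic before/after in Triton GPU IR, with shapes and layout parameters elided and the syntax only approximate:

    // Before: one conversion straight to the dot-operand layout.
    %d = triton_gpu.convert_layout %x : (tensor<.., #blocked>) -> tensor<.., #dot_op<opIdx=0, parent=#mma>>

    // After the pattern fired: stage through the parent #mma layout.
    %tmp = triton_gpu.convert_layout %x : (tensor<.., #blocked>) -> tensor<.., #mma>
    %d = triton_gpu.convert_layout %tmp : (tensor<.., #mma>) -> tensor<.., #dot_op<opIdx=0, parent=#mma>>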
@@ -1045,7 +994,6 @@ public:
hoistConvert(m);
mlir::RewritePatternSet decomposePatterns(context);
-decomposePatterns.add<DecomposeDotOperand>(context);
decomposePatterns.add<ConvertDotConvert>(context);
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {
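With DecomposeDotOperand gone, ConvertDotConvert is the only pattern left in this set. The driver call is the standard MLIR greedy fixed-point rewrite; a minimal sketch of the surviving idiom, where `m` is the module under rewrite and the failure branch, truncated by this hunk, is assumed to end in signalPassFailure():

    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Register the surviving pattern; std::move freezes the set for the
    // driver, which applies patterns until no more match.
    mlir::RewritePatternSet decomposePatterns(context);
    decomposePatterns.add<ConvertDotConvert>(context);
    if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
            .failed())
      signalPassFailure(); // assumed: typical pass behavior on failure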