mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Minor removing of unnecessary code and cleanup (#2443)
This commit is contained in:
@@ -129,9 +129,9 @@ bool supportMMA(Value value, int version);
|
||||
|
||||
bool isSingleValue(Value value);
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
|
||||
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
|
||||
|
||||
// Return true if the src and dst layout match.
|
||||
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
|
||||
@@ -437,7 +437,7 @@ static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) {
|
||||
dst.getWarpsPerCTA()[1] == 1;
|
||||
}
|
||||
|
||||
bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
bool isMmaToMmaShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
|
||||
return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding());
|
||||
}
|
||||
|
||||
@@ -453,7 +453,7 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
|
||||
srcTy.getElementType().isF16();
|
||||
}
|
||||
|
||||
bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
|
||||
if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
|
||||
return true;
|
||||
// dot_op<opIdx=0, parent=#mma> = #mma
|
||||
|
||||
@@ -31,58 +31,7 @@ using triton::gpu::SliceEncodingAttr;
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
// this is a heuristic to accommodate some pattern seen in fused attention
|
||||
// kernels.
|
||||
// TODO: replace this by something more generic, i.e. layout-aware CSE
|
||||
class DecomposeDotOperand : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
explicit DecomposeDotOperand(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
if (!llvm::isa<triton::gpu::ConvertLayoutOp>(op))
|
||||
return mlir::failure();
|
||||
auto convert = llvm::cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcType = convert.getOperand().getType().cast<RankedTensorType>();
|
||||
auto dstType = convert.getType().cast<RankedTensorType>();
|
||||
if (srcType.getEncoding().isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
dstType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
auto dstDotOperand =
|
||||
dstType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto dstParent = dstDotOperand.getParent();
|
||||
if (dstDotOperand.getOpIdx() == 1 ||
|
||||
!dstParent.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return mlir::failure();
|
||||
auto dstParentMma = dstParent.cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (dstParentMma.isVolta() || dstParentMma.getWarpsPerCTA()[1] > 1)
|
||||
return mlir::failure();
|
||||
SetVector<Operation *> bwdSlices;
|
||||
mlir::getBackwardSlice(convert.getResult(), &bwdSlices);
|
||||
if (llvm::find_if(bwdSlices, [](Operation *op) {
|
||||
return isa<triton::DotOp>(op);
|
||||
}) == bwdSlices.end())
|
||||
return mlir::failure();
|
||||
|
||||
auto tmpType = RankedTensorType::get(
|
||||
dstType.getShape(), dstType.getElementType(), dstParentMma);
|
||||
auto tmp = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
convert.getLoc(), tmpType, convert.getOperand());
|
||||
rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(op, dstType,
|
||||
tmp);
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
@@ -1045,7 +994,6 @@ public:
|
||||
hoistConvert(m);
|
||||
|
||||
mlir::RewritePatternSet decomposePatterns(context);
|
||||
decomposePatterns.add<DecomposeDotOperand>(context);
|
||||
decomposePatterns.add<ConvertDotConvert>(context);
|
||||
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
|
||||
.failed()) {
|
||||
|
||||
Reference in New Issue
Block a user