#include "Utility.h"
|
|
#include "mlir/Pass/PassManager.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "mlir/Transforms/Passes.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
|
#include <memory>
|
|
|
|
using namespace mlir;
|
|
namespace {
|
|
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;

// Matches the chain convert(trans(convert(arg))):
//   x = convert_layout arg : #distributed -> #shared_x
//   y = trans x            : #shared_x    -> #shared_y
//   z = convert_layout y   : #shared_y    -> #dot_operand
// and rebuilds it so that x is materialized directly in a shared encoding
// derived from the dot-operand encoding of z, avoiding an extra
// shared-memory round trip.
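//
// A sketch of the rewrite in simplified, hypothetical TTGIR (encodings and
// exact op syntax are illustrative, not tied to a specific Triton version):
//
//   before:
//     %x = triton_gpu.convert_layout %arg : #blocked  -> #shared_x
//     %y = tt.trans %x                    : #shared_x -> #shared_y
//     %z = triton_gpu.convert_layout %y   : #shared_y -> #dot_operand
//   after:
//     %x2 = triton_gpu.convert_layout %arg : #blocked    -> #shared_new
//     %y2 = tt.trans %x2                   : #shared_new -> #shared_y2
//     %z  = triton_gpu.convert_layout %y2  : #shared_y2  -> #dot_operand
//
// where #shared_new is chosen from the dot-operand encoding of %z and the
// order of %arg, so the final conversion can lower cheaply.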
class ConvertTransConvert : public mlir::RewritePattern {

public:
  ConvertTransConvert(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto dstOp = cast<triton::gpu::ConvertLayoutOp>(op);
    auto tmpOp =
        dyn_cast_or_null<triton::TransOp>(dstOp.getSrc().getDefiningOp());
    if (!tmpOp)
      return mlir::failure();
    auto srcOp = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
        tmpOp.getSrc().getDefiningOp());
    if (!srcOp)
      return mlir::failure();
    auto arg = srcOp.getSrc();
    auto X = tmpOp.getSrc();
    // types
    auto argType = arg.getType().cast<RankedTensorType>();
    auto XType = X.getType().cast<RankedTensorType>();
    auto ZType = dstOp.getResult().getType().cast<RankedTensorType>();
    // encodings
    auto argEncoding = argType.getEncoding();
    auto XEncoding =
        XType.getEncoding().cast<triton::gpu::SharedEncodingAttr>();
    auto ZEncoding =
        ZType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!ZEncoding)
      return mlir::failure();
    // Derive a new shared encoding for X from the dot-operand encoding of Z
    // and the order of the distributed argument.
    auto newXOrder = triton::gpu::getOrder(argEncoding);
    auto newXEncoding = triton::gpu::SharedEncodingAttr::get(
        getContext(), ZEncoding, XType.getShape(), newXOrder,
        XType.getElementType());
    // Nothing to do if X already has the desired encoding.
    if (XEncoding == newXEncoding)
      return mlir::failure();
    auto newXType = RankedTensorType::get(
        XType.getShape(), XType.getElementType(), newXEncoding);

    auto newX = rewriter.create<triton::gpu::ConvertLayoutOp>(srcOp.getLoc(),
                                                              newXType, arg);
    auto newY = rewriter.create<triton::TransOp>(tmpOp.getLoc(), newX);
    rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(dstOp, ZType,
                                                              newY);
    return mlir::success();
  }
};

// Pushes a convert_layout to #dot_operand backward through a chain of
// single-operand ops, so that the conversion happens right after the load
// that feeds the chain and the rest of the chain runs in the dot-operand
// layout, avoiding an extra shared-memory round trip.
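//
// A sketch of the intended effect in simplified, hypothetical TTGIR (ops and
// encodings are illustrative only):
//
//   before:
//     %a = tt.load ...                  : #blocked
//     %b = arith.extf %a                : #blocked
//     %c = triton_gpu.convert_layout %b : #blocked -> #dot_operand
//   after:
//     %a = tt.load ...                  : #blocked
//     %b = triton_gpu.convert_layout %a : #blocked -> #dot_operand
//     %c = arith.extf %b                : #dot_operand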
class MoveOpAfterLayoutConversion : public mlir::RewritePattern {

public:
  MoveOpAfterLayoutConversion(mlir::MLIRContext *context)
      : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                             1, context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
    // Check the result type is ranked before touching its encoding.
    auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
    if (!retTy)
      return failure();
    auto retEncoding =
        retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!retEncoding)
      return failure();
    auto retEncodingParent =
        retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
    if (!retEncodingParent || retEncodingParent.isVolta())
      return failure();
    auto srcEncoding =
        srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
    if (!srcEncoding)
      return failure();
    // Don't move anything when the cvt operand is a block argument.
    Operation *argOp = cvt.getOperand().getDefiningOp();
    if (!argOp)
      return failure();
    // Walk backward from the cvt and record the chain of ops and layouts
    // that would have to be rematerialized in the dot-operand encoding.
    SetVector<Operation *> processed;
    SetVector<Attribute> layout;
    llvm::MapVector<Value, Attribute> toConvert;
    int numCvts = simulateBackwardRematerialization(cvt, processed, layout,
                                                    toConvert, retEncoding);
    // Bail out if rematerialization would introduce extra conversions, or if
    // the chain is just the cvt itself.
    if (numCvts > 1 || toConvert.size() == 1)
      return failure();
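    // Profitability and legality checks over the recorded chain; each check
    // below documents the case it rejects.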
    for (Operation *op : processed) {
      if (op->getNumOperands() != 1)
        continue;
      auto srcTy = op->getOperand(0).getType().cast<RankedTensorType>();
      auto dstTy = op->getResult(0).getType().cast<RankedTensorType>();
      // We don't want to push conversions backward past a downcast, since
      // that would convert the wider type and increase shared-memory traffic.
      if (srcTy.getElementType().getIntOrFloatBitWidth() >
          dstTy.getElementType().getIntOrFloatBitWidth())
        return failure();
      // We only push back when the first op in the chain is fed by a load
      // (isa_and_nonnull also rejects block arguments).
      if ((op == processed.back()) &&
          !isa_and_nonnull<triton::LoadOp>(op->getOperand(0).getDefiningOp()))
        return failure();
      // The operand effectively needs a transpose when the dot-operand index
      // and the fastest-varying dimension of the source disagree. We don't
      // want to use ldmatrix for 8-bit data that requires a transpose, since
      // NVIDIA GPUs can't do that efficiently.
      bool isTrans =
          (retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0);
      bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
      if (isTrans && isInt8)
        return failure();
    }
    IRMapping mapping;
    rematerializeConversionChain(toConvert, rewriter, processed, mapping);
    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
    return mlir::success();
  }
};

} // namespace

// Pull in the auto-generated pass base classes (including
// TritonGPUOptimizeDotOperandsBase) from the tablegen'd pass declarations.
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUOptimizeDotOperandsPass
    : public TritonGPUOptimizeDotOperandsBase<
          TritonGPUOptimizeDotOperandsPass> {
public:
  TritonGPUOptimizeDotOperandsPass() = default;
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();

    // Canonicalize first so the patterns below see a normalized module, and
    // don't silently ignore a failure of the nested pipeline.
    mlir::PassManager pm(m.getContext());
    pm.addPass(mlir::createCanonicalizerPass());
    if (failed(pm.run(m)))
      return signalPassFailure();

    mlir::RewritePatternSet patterns(context);
    patterns.add<ConvertTransConvert>(context);
    patterns.add<MoveOpAfterLayoutConversion>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
    if (fixupLoops(m).failed())
      signalPassFailure();
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUOptimizeDotOperandsPass() {
  return std::make_unique<TritonGPUOptimizeDotOperandsPass>();
}
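
// Typical use (a sketch, not the canonical Triton pipeline): schedule this
// pass after layouts have been assigned, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createTritonGPUOptimizeDotOperandsPass());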