mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Co-authored-by: Thomas Raoux <thomas.raoux@openai.com> Co-authored-by: Keren Zhou <kerenzhou@openai.com> Co-authored-by: Phil Tillet <phil@openai.com>
799 lines
34 KiB
C++
799 lines
34 KiB
C++
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
|
|
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
|
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
|
#include "mlir/Pass/Pass.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
#include "triton/Analysis/Utility.h"
|
|
#include "triton/Dialect/Triton/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
|
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
|
#include "triton/Target/PTX/TmaMetadata.h"
|
|
#include "llvm/ADT/APSInt.h"
|
|
#include <numeric>
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::triton;
|
|
|
|
#define GEN_PASS_CLASSES
|
|
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
|
|
|
|
namespace {
|
|
|
|
// pass named attrs (e.g., tt.contiguity) from Triton to Triton
|
|
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
|
|
for (const NamedAttribute attr : dictAttrs.getValue())
|
|
if (!op->hasAttr(attr.getName()))
|
|
op->setAttr(attr.getName(), attr.getValue());
|
|
}
|
|
|
|
template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
|
|
using OpConversionPattern<Op>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(Op op, typename Op::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
SmallVector<Type> retTypes;
|
|
if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(),
|
|
retTypes)))
|
|
return failure();
|
|
rewriter.replaceOpWithNewOp<Op>(op, retTypes, adaptor.getOperands(),
|
|
op->getAttrs());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
template <class SrcOp, class DstOp>
|
|
class ArithCmpPattern : public OpConversionPattern<SrcOp> {
|
|
public:
|
|
using OpConversionPattern<SrcOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
|
addNamedAttrs(
|
|
rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
|
|
adaptor.getLhs(), adaptor.getRhs()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
|
|
public:
|
|
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type retType = getTypeConverter()->convertType(op.getType());
|
|
auto retShapedType = retType.cast<ShapedType>();
|
|
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
|
if (dyn_cast<RankedTensorType>(retShapedType)) {
|
|
assert(value);
|
|
if (value.getElementType().isInteger(1) && value.isSplat())
|
|
// Workaround until https://reviews.llvm.org/D133743 is included.
|
|
value =
|
|
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
|
|
else
|
|
// This is a hack. We just want to add encoding
|
|
value = value.reshape(retShapedType);
|
|
}
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
|
op, retShapedType, value),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
TritonGPUConversionTarget &target) {
|
|
// --------------
|
|
// Add legality and rewrite pattern rules for operations
|
|
// from the Arith dialect. The basic premise is that
|
|
// Arith operations require both inputs to have the same
|
|
// non-null encoding
|
|
// --------------
|
|
MLIRContext *context = patterns.getContext();
|
|
// TODO: there's probably a better way to avoid adding all ops one-by-one
|
|
patterns.add<
|
|
ArithConstantPattern, GenericOpPattern<arith::AddIOp>,
|
|
GenericOpPattern<arith::SubIOp>, GenericOpPattern<arith::MulIOp>,
|
|
GenericOpPattern<arith::DivUIOp>, GenericOpPattern<arith::DivSIOp>,
|
|
GenericOpPattern<arith::CeilDivUIOp>,
|
|
GenericOpPattern<arith::CeilDivSIOp>,
|
|
GenericOpPattern<arith::FloorDivSIOp>, GenericOpPattern<arith::RemUIOp>,
|
|
GenericOpPattern<arith::RemSIOp>, GenericOpPattern<arith::AndIOp>,
|
|
GenericOpPattern<arith::OrIOp>, GenericOpPattern<arith::XOrIOp>,
|
|
GenericOpPattern<arith::ShLIOp>, GenericOpPattern<arith::ShRUIOp>,
|
|
GenericOpPattern<arith::ShRSIOp>, // NegFOp
|
|
// Floating point
|
|
GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
|
|
// MaxMin
|
|
GenericOpPattern<arith::MaximumFOp>, GenericOpPattern<arith::MaxSIOp>,
|
|
GenericOpPattern<arith::MaxUIOp>, GenericOpPattern<arith::MinimumFOp>,
|
|
GenericOpPattern<arith::MinSIOp>, GenericOpPattern<arith::MinUIOp>,
|
|
// Floating point
|
|
GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
|
|
GenericOpPattern<arith::RemFOp>,
|
|
// Cmp
|
|
ArithCmpPattern<arith::CmpIOp, triton::gpu::CmpIOp>,
|
|
ArithCmpPattern<arith::CmpFOp, triton::gpu::CmpFOp>,
|
|
// Cast Ops
|
|
GenericOpPattern<arith::TruncIOp>, GenericOpPattern<arith::TruncFOp>,
|
|
GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
|
|
GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
|
|
GenericOpPattern<arith::FPToSIOp>, GenericOpPattern<arith::FPToUIOp>,
|
|
GenericOpPattern<arith::UIToFPOp>>(typeConverter, context);
|
|
}
|
|
|
|
// this shouldn't exist if mlir's SelectOp checked encodings properly
|
|
class StdSelectPattern : public OpConversionPattern<arith::SelectOp> {
|
|
public:
|
|
using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Type retType = this->getTypeConverter()->convertType(op.getType());
|
|
|
|
Value cond = adaptor.getCondition();
|
|
if (llvm::isa<RankedTensorType>(retType) &&
|
|
!llvm::isa<TensorType>(cond.getType())) {
|
|
// triton_gpu.select doesn't support scalar condition values, so add a
|
|
// splat
|
|
auto retTypeTensor = llvm::cast<RankedTensorType>(retType);
|
|
auto retShape = retTypeTensor.getShape();
|
|
auto retEncoding = retTypeTensor.getEncoding();
|
|
Type condTy =
|
|
RankedTensorType::get(retShape, cond.getType(), retEncoding);
|
|
cond = rewriter.create<triton::SplatOp>(op.getLoc(), condTy, cond);
|
|
}
|
|
|
|
addNamedAttrs(
|
|
rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
|
|
op, retType, cond, adaptor.getTrueValue(), adaptor.getFalseValue()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
TritonGPUConversionTarget &target) {
|
|
MLIRContext *context = patterns.getContext();
|
|
// Rewrite rule
|
|
patterns.add<StdSelectPattern>(typeConverter, context);
|
|
}
|
|
|
|
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns,
|
|
TritonGPUConversionTarget &target) {
|
|
MLIRContext *context = patterns.getContext();
|
|
// Rewrite rule
|
|
patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::CosOp>,
|
|
GenericOpPattern<math::SinOp>, GenericOpPattern<math::LogOp>,
|
|
GenericOpPattern<math::AbsFOp>, GenericOpPattern<math::AbsIOp>,
|
|
GenericOpPattern<math::SqrtOp>>(typeConverter, context);
|
|
}
|
|
|
|
//
|
|
// Triton patterns
|
|
//
|
|
struct TritonExpandDimsPattern
|
|
: public OpConversionPattern<triton::ExpandDimsOp> {
|
|
using OpConversionPattern<triton::ExpandDimsOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// Type retType = op.getType());
|
|
RankedTensorType argType =
|
|
adaptor.getSrc().getType().cast<RankedTensorType>();
|
|
Attribute _argEncoding = argType.getEncoding();
|
|
if (!_argEncoding)
|
|
return failure();
|
|
auto argEncoding = _argEncoding.cast<triton::gpu::BlockedEncodingAttr>();
|
|
// return shape
|
|
auto retShape = argType.getShape().vec();
|
|
retShape.insert(retShape.begin() + op.getAxis(), 1);
|
|
// return encoding
|
|
auto retSizePerThread = argEncoding.getSizePerThread().vec();
|
|
retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
|
|
auto retThreadsPerWarp = argEncoding.getThreadsPerWarp().vec();
|
|
retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
|
|
auto retWarpsPerCTA = argEncoding.getWarpsPerCTA().vec();
|
|
retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
|
|
SmallVector<unsigned, 4> retOrder(retShape.size());
|
|
std::iota(retOrder.begin(), retOrder.end(), 0);
|
|
|
|
auto argCTALayout = argEncoding.getCTALayout();
|
|
auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis());
|
|
auto retCTASplitNum =
|
|
insertOne(argCTALayout.getCTASplitNum(), op.getAxis());
|
|
auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis());
|
|
auto retCTALayout = triton::gpu::CTALayoutAttr::get(
|
|
getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder);
|
|
|
|
triton::gpu::BlockedEncodingAttr retEncoding =
|
|
triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
|
|
retThreadsPerWarp, retWarpsPerCTA,
|
|
retOrder, retCTALayout);
|
|
// convert operand to slice of return type
|
|
Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
|
|
getContext(), op.getAxis(), retEncoding);
|
|
RankedTensorType newArgType = RankedTensorType::get(
|
|
argType.getShape(), argType.getElementType(), newArgEncoding);
|
|
// construct new op
|
|
auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
|
op.getLoc(), newArgType, adaptor.getSrc());
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
|
|
op, newSrc, adaptor.getAxis()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
template <typename T>
|
|
SmallVector<T> insertOne(ArrayRef<T> vec, unsigned axis) const {
|
|
SmallVector<T> res(vec.begin(), vec.end());
|
|
res.insert(res.begin() + axis, 1);
|
|
return res;
|
|
}
|
|
|
|
// Example: order = [ 0, 2, 1, 3], dim = 2
|
|
// resOrder = [2, 0, 3, 1, 4]
|
|
SmallVector<unsigned> insertOrder(ArrayRef<unsigned> order,
|
|
unsigned axis) const {
|
|
SmallVector<unsigned> resOrder(order.begin(), order.end());
|
|
for (unsigned i = 0; i < resOrder.size(); ++i)
|
|
if (resOrder[i] >= axis)
|
|
++resOrder[i];
|
|
resOrder.insert(resOrder.begin(), axis);
|
|
return resOrder;
|
|
}
|
|
};
|
|
|
|
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
|
using OpConversionPattern<triton::DotOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
RankedTensorType origType = op.getType().cast<RankedTensorType>();
|
|
auto origShape = origType.getShape();
|
|
auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
|
|
int numWarps = typeConverter->getNumWarps();
|
|
int threadsPerWarp = typeConverter->getThreadsPerWarp();
|
|
int numCTAs = typeConverter->getNumCTAs();
|
|
|
|
SmallVector<unsigned> retSizePerThread = {1, 1};
|
|
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 4)
|
|
retSizePerThread = {2, 2};
|
|
if (origShape[0] * origShape[1] / (numWarps * threadsPerWarp) >= 16)
|
|
retSizePerThread = {4, 4};
|
|
SmallVector<unsigned> retOrder = {1, 0};
|
|
Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
|
|
getContext(), origShape, retSizePerThread, retOrder, numWarps,
|
|
threadsPerWarp, numCTAs);
|
|
RankedTensorType retType =
|
|
RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
|
|
// a & b must be of smem layout
|
|
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
|
|
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
|
|
Type aEltType = aType.getElementType();
|
|
Type bEltType = bType.getElementType();
|
|
Attribute aEncoding = aType.getEncoding();
|
|
Attribute bEncoding = bType.getEncoding();
|
|
if (!aEncoding || !bEncoding)
|
|
return failure();
|
|
Value a = adaptor.getA();
|
|
Value b = adaptor.getB();
|
|
Value c = adaptor.getC();
|
|
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
|
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
|
getContext(), 0, dEncoding, aEltType);
|
|
auto dstType =
|
|
RankedTensorType::get(aType.getShape(), aEltType, encoding);
|
|
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
|
}
|
|
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
|
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
|
getContext(), 1, dEncoding, bEltType);
|
|
auto dstType =
|
|
RankedTensorType::get(bType.getShape(), bEltType, encoding);
|
|
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
|
}
|
|
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
|
|
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
|
|
op, retType, a, b, c, adaptor.getAllowTF32(),
|
|
adaptor.getMaxNumImpreciseAcc()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
|
|
|
|
using OpConversionPattern<triton::CatOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// The cat op satisfy two conditions:
|
|
// 1. output.numel = lhs.numel + rhs.numel
|
|
// 2. output.total_elems_per_thread =
|
|
// next_power_of_2(lhs.total_elems_per_thread + rhs.total_elems_per_thread)
|
|
// For now, this behaves like generic, but this
|
|
// will evolve when we add support for `can_reorder=False`.
|
|
auto retType = this->getTypeConverter()
|
|
->convertType(op.getType())
|
|
.cast<RankedTensorType>();
|
|
auto retEncoding =
|
|
retType.getEncoding().cast<triton::gpu::BlockedEncodingAttr>();
|
|
auto lhsType = adaptor.getLhs().getType().cast<RankedTensorType>();
|
|
auto rhsType = adaptor.getRhs().getType().cast<RankedTensorType>();
|
|
auto lhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(lhsType);
|
|
auto rhsTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(rhsType);
|
|
auto retTotalElemsPerThread = triton::gpu::getTotalElemsPerThread(retType);
|
|
auto retShape = retType.getShape();
|
|
auto retOrder = retEncoding.getOrder();
|
|
auto retSizePerThread = retEncoding.getSizePerThread();
|
|
auto retThreadsPerWarp = retEncoding.getThreadsPerWarp();
|
|
auto retWarpsPerCTA = retEncoding.getWarpsPerCTA();
|
|
// Get new retSizePerThread if ret elems per thread is not enough.
|
|
// We have to round it up to the next power of 2 due to triton's tensor size
|
|
// constraint.
|
|
auto newRetTotalElemsPerThread =
|
|
nextPowOf2(lhsTotalElemsPerThread + rhsTotalElemsPerThread);
|
|
auto newRetSizePerThread = retSizePerThread.vec();
|
|
newRetSizePerThread[retOrder[0]] *=
|
|
newRetTotalElemsPerThread / retTotalElemsPerThread;
|
|
triton::gpu::BlockedEncodingAttr newRetEncoding =
|
|
triton::gpu::BlockedEncodingAttr::get(
|
|
getContext(), newRetSizePerThread, retThreadsPerWarp,
|
|
retWarpsPerCTA, retOrder, retEncoding.getCTALayout());
|
|
auto newRetType = RankedTensorType::get(retShape, retType.getElementType(),
|
|
newRetEncoding);
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
|
|
op, newRetType, adaptor.getOperands()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
|
|
|
|
using OpConversionPattern<triton::TransOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
Value src = adaptor.getSrc();
|
|
auto srcType = src.getType().cast<RankedTensorType>();
|
|
Attribute srcEncoding = srcType.getEncoding();
|
|
if (!srcEncoding)
|
|
return failure();
|
|
if (!srcEncoding.isa<triton::gpu::SharedEncodingAttr>()) {
|
|
// TODO: end-to-end correctness is broken if
|
|
// the input is blocked and the output is shared
|
|
// with different order. Maybe a backend issue in BlockedToShared?
|
|
SmallVector<unsigned> order = {1, 0};
|
|
if (auto srcBlockedEncoding =
|
|
srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>())
|
|
llvm::copy(srcBlockedEncoding.getOrder(), order.begin());
|
|
// TODO(Qingyi): need to check whether the CTALayout of srcEncoding should
|
|
// be used here. For tests where numCTAs = 1, this is not a problem since
|
|
// all CTALayouts are the same.
|
|
auto CTALayout = triton::gpu::getCTALayout(srcEncoding);
|
|
srcEncoding = triton::gpu::SharedEncodingAttr::get(getContext(), 1, 1, 1,
|
|
order, CTALayout);
|
|
srcType = RankedTensorType::get(srcType.getShape(),
|
|
srcType.getElementType(), srcEncoding);
|
|
src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
|
|
src);
|
|
}
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::TransOp>(op, src),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TritonBroadcastPattern
|
|
: public OpConversionPattern<triton::BroadcastOp> {
|
|
using OpConversionPattern<triton::BroadcastOp>::OpConversionPattern;
|
|
|
|
// This creates a tensor with the new shape but the argument's layout
|
|
LogicalResult
|
|
matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto srcType = adaptor.getSrc().getType().cast<RankedTensorType>();
|
|
auto srcEncoding = srcType.getEncoding();
|
|
if (!srcEncoding)
|
|
return failure();
|
|
auto opType = op.getType().cast<RankedTensorType>();
|
|
Type retType = RankedTensorType::get(opType.getShape(),
|
|
opType.getElementType(), srcEncoding);
|
|
// Type retType = this->getTypeConverter()->convertType(op.getType());
|
|
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
|
|
op, retType, adaptor.getOperands()),
|
|
adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
|
|
using OpConversionPattern<triton::ReduceOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newReduce = rewriter.create<triton::ReduceOp>(
|
|
op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
|
|
addNamedAttrs(newReduce, adaptor.getAttributes());
|
|
|
|
auto &newCombineOp = newReduce.getCombineOp();
|
|
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
|
|
newCombineOp.end());
|
|
rewriter.replaceOp(op, newReduce.getResult());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
struct TritonScanPattern : public OpConversionPattern<triton::ScanOp> {
|
|
using OpConversionPattern<triton::ScanOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newScan = rewriter.create<triton::ScanOp>(
|
|
op.getLoc(), adaptor.getOperands(), adaptor.getAxis());
|
|
addNamedAttrs(newScan, adaptor.getAttributes());
|
|
|
|
auto &newCombineOp = newScan.getCombineOp();
|
|
rewriter.cloneRegionBefore(op.getCombineOp(), newCombineOp,
|
|
newCombineOp.end());
|
|
rewriter.replaceOp(op, newScan.getResult());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
|
|
public:
|
|
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto converter = getTypeConverter();
|
|
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
|
|
op, op.getName(), op.getFunctionType());
|
|
addNamedAttrs(newOp, adaptor.getAttributes());
|
|
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
|
|
newOp.getBody().end());
|
|
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
|
|
public:
|
|
using OpConversionPattern<triton::CallOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newOp = rewriter.replaceOpWithNewOp<triton::CallOp>(
|
|
op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
|
|
addNamedAttrs(newOp, adaptor.getAttributes());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
|
|
public:
|
|
using OpConversionPattern<ReturnOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns, unsigned numCTAs) {
|
|
MLIRContext *context = patterns.getContext();
|
|
patterns.insert< // TODO: view should have custom pattern that views the
|
|
// layout
|
|
GenericOpPattern<triton::AdvanceOp>,
|
|
GenericOpPattern<triton::MakeTensorPtrOp>,
|
|
GenericOpPattern<triton::ViewOp>, GenericOpPattern<triton::BitcastOp>,
|
|
GenericOpPattern<triton::FpToFpOp>, GenericOpPattern<triton::IntToPtrOp>,
|
|
GenericOpPattern<triton::PtrToIntOp>, GenericOpPattern<triton::SplatOp>,
|
|
TritonBroadcastPattern, GenericOpPattern<triton::AddPtrOp>,
|
|
TritonCatPattern, GenericOpPattern<triton::ElementwiseInlineAsmOp>,
|
|
TritonReducePattern, GenericOpPattern<triton::ReduceReturnOp>,
|
|
TritonScanPattern, GenericOpPattern<triton::ScanReturnOp>,
|
|
GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
|
|
TritonTransPattern, TritonDotPattern, GenericOpPattern<triton::LoadOp>,
|
|
GenericOpPattern<triton::StoreOp>,
|
|
GenericOpPattern<triton::ExternElementwiseOp>,
|
|
GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
|
|
GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
|
|
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
|
|
context);
|
|
}
|
|
|
|
//
|
|
// SCF patterns
|
|
//
|
|
// This is borrowed from ConvertForOpTypes in
|
|
// SCF/Transforms/StructuralTypeConversions.cpp
|
|
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
|
|
using OpConversionPattern<scf::ForOp>::OpConversionPattern;
|
|
// Ref: ConvertForOpTypes
|
|
LogicalResult
|
|
matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto newOp =
|
|
cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
|
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
|
|
newOp.getRegion().end());
|
|
|
|
// Now, update all the types.
|
|
|
|
// Convert the types of block arguments within the given region. This
|
|
// replaces each block with a new block containing the updated signature.
|
|
// The entry block may have a special conversion if `entryConversion` is
|
|
// provided. On success, the new entry block to the region is returned for
|
|
// convenience. Otherwise, failure is returned.
|
|
if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
|
|
*getTypeConverter()))) {
|
|
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
|
}
|
|
// Change the clone to use the updated operands. We could have cloned with
|
|
// a IRMapping, but this seems a bit more direct.
|
|
newOp->setOperands(adaptor.getOperands());
|
|
// Update the result types to the new converted types.
|
|
SmallVector<Type> newResultTypes;
|
|
for (Type type : op.getResultTypes()) {
|
|
Type newType = typeConverter->convertType(type);
|
|
if (!newType)
|
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
|
newResultTypes.push_back(newType);
|
|
}
|
|
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
|
std::get<0>(t).setType(std::get<1>(t));
|
|
|
|
rewriter.replaceOp(op, newOp.getResults());
|
|
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This is borrowed from ConvertFIfOpTypes in
|
|
// SCF/Transforms/StructuralTypeConversions.cpp
|
|
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
|
|
public:
|
|
using OpConversionPattern<scf::IfOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
// TODO: Generalize this to any type conversion, not just 1:1.
|
|
//
|
|
// We need to implement something more sophisticated here that tracks which
|
|
// types convert to which other types and does the appropriate
|
|
// materialization logic.
|
|
// For example, it's possible that one result type converts to 0 types and
|
|
// another to 2 types, so newResultTypes would at least be the right size to
|
|
// not crash in the llvm::zip call below, but then we would set the the
|
|
// wrong type on the SSA values! These edge cases are also why we cannot
|
|
// safely use the TypeConverter::convertTypes helper here.
|
|
SmallVector<Type> newResultTypes;
|
|
for (auto type : op.getResultTypes()) {
|
|
Type newType = typeConverter->convertType(type);
|
|
if (!newType)
|
|
return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
|
|
newResultTypes.push_back(newType);
|
|
}
|
|
|
|
// See comments in the ForOp pattern for why we clone without regions and
|
|
// then inline.
|
|
scf::IfOp newOp =
|
|
cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
|
|
rewriter.inlineRegionBefore(op.getThenRegion(), newOp.getThenRegion(),
|
|
newOp.getThenRegion().end());
|
|
rewriter.inlineRegionBefore(op.getElseRegion(), newOp.getElseRegion(),
|
|
newOp.getElseRegion().end());
|
|
|
|
// Update the operands and types.
|
|
newOp->setOperands(adaptor.getOperands());
|
|
for (auto t : llvm::zip(newOp.getResults(), newResultTypes))
|
|
std::get<0>(t).setType(std::get<1>(t));
|
|
rewriter.replaceOp(op, newOp.getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
// This is borrowed from ConvertFIfOpTypes in
|
|
// SCF/Transforms/StructuralTypeConversions.cpp
|
|
class SCFWhilePattern : public OpConversionPattern<scf::WhileOp> {
|
|
public:
|
|
using OpConversionPattern<scf::WhileOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *converter = getTypeConverter();
|
|
assert(converter);
|
|
SmallVector<Type> newResultTypes;
|
|
if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes)))
|
|
return failure();
|
|
|
|
auto newOp = rewriter.create<scf::WhileOp>(op.getLoc(), newResultTypes,
|
|
adaptor.getOperands());
|
|
for (auto i : {0u, 1u}) {
|
|
auto &dstRegion = newOp.getRegion(i);
|
|
rewriter.inlineRegionBefore(op.getRegion(i), dstRegion, dstRegion.end());
|
|
if (failed(rewriter.convertRegionTypes(&dstRegion, *converter)))
|
|
return rewriter.notifyMatchFailure(op, "could not convert body types");
|
|
}
|
|
rewriter.replaceOp(op, newOp.getResults());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class SCFConditionPattern : public OpConversionPattern<scf::ConditionOp> {
|
|
public:
|
|
using OpConversionPattern<scf::ConditionOp>::OpConversionPattern;
|
|
LogicalResult
|
|
matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
rewriter.updateRootInPlace(
|
|
op, [&]() { op->setOperands(adaptor.getOperands()); });
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) {
|
|
MLIRContext *context = patterns.getContext();
|
|
patterns.add<GenericOpPattern<scf::YieldOp>, SCFForPattern, SCFIfPattern,
|
|
SCFWhilePattern, SCFConditionPattern>(typeConverter, context);
|
|
}
|
|
|
|
// CF
|
|
|
|
class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
|
|
public:
|
|
using OpConversionPattern<cf::BranchOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto converter = getTypeConverter();
|
|
auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
|
|
op, op.getSuccessor(), adaptor.getOperands());
|
|
if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(),
|
|
*converter)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
|
|
public:
|
|
using OpConversionPattern<cf::CondBranchOp>::OpConversionPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto converter = getTypeConverter();
|
|
auto newOp = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
|
|
op, adaptor.getCondition(), op.getTrueDest(),
|
|
adaptor.getTrueDestOperands(), op.getFalseDest(),
|
|
adaptor.getFalseDestOperands());
|
|
addNamedAttrs(newOp, adaptor.getAttributes());
|
|
|
|
if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(),
|
|
*converter)))
|
|
return failure();
|
|
if (failed(rewriter.convertRegionTypes(newOp.getFalseDest()->getParent(),
|
|
*converter)))
|
|
return failure();
|
|
return success();
|
|
}
|
|
};
|
|
|
|
void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
|
|
RewritePatternSet &patterns) {
|
|
MLIRContext *context = patterns.getContext();
|
|
patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter, context);
|
|
}
|
|
//
|
|
|
|
class ConvertTritonToTritonGPU
|
|
: public ConvertTritonToTritonGPUBase<ConvertTritonToTritonGPU> {
|
|
public:
|
|
ConvertTritonToTritonGPU() = default;
|
|
// constructor with some parameters set explicitly.
|
|
ConvertTritonToTritonGPU(int numWarps, int threadsPerWarp, int numCTAs,
|
|
int computeCapability) {
|
|
this->numWarps = numWarps;
|
|
this->threadsPerWarp = threadsPerWarp;
|
|
this->numCTAs = numCTAs;
|
|
this->computeCapability = computeCapability;
|
|
}
|
|
|
|
void runOnOperation() override {
|
|
MLIRContext *context = &getContext();
|
|
ModuleOp mod = getOperation();
|
|
// type converter
|
|
TritonGPUTypeConverter typeConverter(context, numWarps, threadsPerWarp,
|
|
numCTAs);
|
|
TritonGPUConversionTarget target(*context, typeConverter);
|
|
// rewrite patterns
|
|
RewritePatternSet patterns(context);
|
|
// add rules
|
|
populateStdPatternsAndLegality(typeConverter, patterns, target);
|
|
populateArithPatternsAndLegality(typeConverter, patterns, target);
|
|
populateMathPatternsAndLegality(typeConverter, patterns, target);
|
|
populateTritonPatterns(typeConverter, patterns, numCTAs);
|
|
// TODO: can we use
|
|
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
|
|
populateSCFPatterns(typeConverter, patterns);
|
|
populateCFPatterns(typeConverter, patterns);
|
|
|
|
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
|
|
return signalPassFailure();
|
|
|
|
auto inti = llvm::APSInt(32, false);
|
|
auto i32_ty = IntegerType::get(mod->getContext(), 32);
|
|
|
|
mod->setAttr(
|
|
AttrNumWarpsName,
|
|
IntegerAttr::get(i32_ty, llvm::APInt(32, numWarps.getValue())));
|
|
mod->setAttr(
|
|
AttrNumThreadsPerWarp,
|
|
IntegerAttr::get(i32_ty, llvm::APInt(32, threadsPerWarp.getValue())));
|
|
|
|
mod->setAttr(AttrNumCTAsName,
|
|
IntegerAttr::get(i32_ty, llvm::APInt(32, numCTAs.getValue())));
|
|
|
|
mod->setAttr(AttrComputeCapabilityName,
|
|
IntegerAttr::get(
|
|
i32_ty, llvm::APInt(32, computeCapability.getValue())));
|
|
|
|
// update layouts
|
|
// broadcast src => multicast, dst => broadcasted
|
|
// if (failed(target.refineLayouts(mod, numWarps)))
|
|
// return signalPassFailure();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
std::unique_ptr<OperationPass<ModuleOp>>
|
|
mlir::triton::createConvertTritonToTritonGPUPass(int numWarps,
|
|
int threadsPerWarp,
|
|
int numCTAs,
|
|
int computeCapability) {
|
|
return std::make_unique<::ConvertTritonToTritonGPU>(
|
|
numWarps, threadsPerWarp, numCTAs, computeCapability);
|
|
}
|
|
|
|
std::unique_ptr<OperationPass<ModuleOp>>
|
|
mlir::triton::createConvertTritonToTritonGPUPass() {
|
|
return std::make_unique<::ConvertTritonToTritonGPU>();
|
|
}
|