/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include <queue>
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
using namespace mlir;
namespace ttg = ::mlir::triton::gpu;
namespace ttng = ::mlir::triton::nvidia_gpu;
// TODO: use ConvertLayoutOp
using CastOp = ::mlir::UnrealizedConversionCastOp;
unsigned getNumUsers(Value value) {
return std::distance(value.user_begin(), value.user_end());
}
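// Returns a copy of `type` with its tensor encoding replaced by `newLayout`.
// Tensor-pointer types are looked through: the pointee tensor is rebuilt with
// the new layout and the pointer is rebuilt with the same address space.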
Type replaceLayout(const Type &type, const Attribute &newLayout) {
Type curType = type;
auto ptrTy = curType.dyn_cast<triton::PointerType>();
if (ptrTy)
curType = ptrTy.getPointeeType();
if (auto tensorTy = curType.dyn_cast<RankedTensorType>())
curType = RankedTensorType::get(tensorTy.getShape(),
tensorTy.getElementType(), newLayout);
if (ptrTy)
curType = triton::PointerType::get(curType, ptrTy.getAddressSpace());
return curType;
}
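// Returns a copy of `layout` with its CTA layout replaced by `newCTALayout`.
// Blocked encodings are rebuilt directly; slice encodings recurse into their
// parent encoding.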
Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef<int64_t> shape,
const ttg::CTALayoutAttr &newCTALayout) {
if (auto blockedLayout = layout.dyn_cast<ttg::BlockedEncodingAttr>()) {
return ttg::BlockedEncodingAttr::get(
layout.getContext(), shape, blockedLayout.getSizePerThread(),
blockedLayout.getOrder(), ttg::getNumWarpsPerCTA(layout), 32,
newCTALayout);
} else if (auto sliceLayout = layout.dyn_cast<ttg::SliceEncodingAttr>()) {
return ttg::SliceEncodingAttr::get(
layout.getContext(), sliceLayout.getDim(),
replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout));
} else {
// Other layouts are generated by passes after PlanCTAPass
assert(0 && "replaceCTALayout not implemented");
}
}
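// CTAPlanner decides how the CGA (the cluster of CTAs) is tiled for a
// function and rewrites all tensor layouts accordingly. A sketch of the flow
// implemented in run() below:
//   1. Seeding: pick a CTA tiling from DotOps (processDot), ReduceOps
//      (processReduce), and, failing that, store-like ops
//      (processStoreLikeOps). Each rewritten op is fenced with
//      UnrealizedConversionCastOps marked "backward" (push the new layout
//      toward producers) or "forward" (push it toward users).
//   2. Propagation: pop casts from the worklist queue and rewrite the
//      neighboring ops until the queue is empty. When a forward cast meets a
//      backward cast, the pair cancels or collapses into a
//      ttg::ConvertLayoutOp (see eliminateAdjacentCasts).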
class CTAPlanner {
public:
CTAPlanner(ttng::ClusterInfo *clusterInfo_);
~CTAPlanner();
void run(triton::FuncOp &funcOp);
private:
CastOp markBackward(CastOp cast) const;
CastOp markForward(CastOp cast) const;
bool isBackward(CastOp cast) const;
bool isForward(CastOp cast) const;
void setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA);
bool processDot(triton::FuncOp &funcOp);
bool processReduce(triton::FuncOp &funcOp);
void processStoreLikeOps(triton::FuncOp &funcOp);
bool propagate(CastOp cast);
bool propagateBackward(CastOp cast);
bool propagateForward(CastOp cast);
void eraseCastOp(CastOp cast);
void eraseCastOpFromQueue(CastOp cast);
void eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts);
void insertCasts(Operation *op, llvm::ArrayRef<Attribute> newOperandLayouts,
llvm::ArrayRef<Attribute> newResultLayouts);
void eliminateAdjacentCasts(CastOp cast0, CastOp cast1);
bool isLoadStoreOp(Operation *op) const;
bool processLoadStore(Operation *op, Attribute layout);
bool isElementwiseOp(Operation *op) const;
bool processElementwise(Operation *op, Attribute layout);
bool processConstant(arith::ConstantOp constant, Attribute layout);
bool processSplat(triton::SplatOp splat, Attribute layout);
bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout);
bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
Attribute layout);
bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout);
bool processExpandDimsBackward(triton::ExpandDimsOp expandDims,
Attribute newResultLayout);
bool processExpandDimsForward(triton::ExpandDimsOp expandDims,
Attribute newSrcLayout);
bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout,
CastOp cast);
bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
CastOp cast);
bool processIfOp(scf::IfOp ifOp, int index, const Type &newType);
bool processForOp(scf::ForOp forOp, int index, const Type &newType);
bool processIfOpBackward(scf::IfOp ifOp, CastOp cast);
bool processForOpBackward(scf::ForOp forOp, CastOp cast);
bool processBlockArgBackward(BlockArgument arg, CastOp cast);
bool processForOpForward(scf::ForOp forOp, CastOp cast);
bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast);
bool processOpFallback(Operation *op);
bool processMultiUsersBackward(Value input, CastOp cast);
bool processMultiUsersForward(Value output, CastOp cast);
  // Indicates whether clusterInfo must be deleted in the destructor of
  // CTAPlanner. `ownInfo` is set to false when a non-null clusterInfo pointer
  // is passed to the constructor; otherwise a self-managed ClusterInfo is
  // created and `ownInfo` is set to true.
bool ownInfo;
ttng::ClusterInfo *clusterInfo;
bool tiled;
unsigned step;
unsigned stepUnchanged;
std::queue<CastOp> queue;
};
CTAPlanner::CTAPlanner(ttng::ClusterInfo *clusterInfo_)
: ownInfo(false), clusterInfo(clusterInfo_), tiled(false), step(0),
stepUnchanged(0) {
if (clusterInfo == nullptr) {
clusterInfo = new ttng::ClusterInfo();
ownInfo = true;
}
}
CTAPlanner::~CTAPlanner() {
if (ownInfo) {
delete clusterInfo;
    // Not strictly necessary, but safer.
ownInfo = false;
clusterInfo = nullptr;
}
}
void CTAPlanner::run(triton::FuncOp &funcOp) {
assert(!tiled && "Please create a new CTAPlanner");
static const unsigned maxSteps = 10000;
auto nextStep = [&]() {
++step;
assert(step < maxSteps && "Maximum number of steps exceeded");
};
processDot(funcOp);
nextStep();
processReduce(funcOp);
nextStep();
if (!tiled) {
processStoreLikeOps(funcOp);
nextStep();
}
while (!queue.empty()) {
CastOp cast = queue.front();
queue.pop();
bool changed = propagate(cast);
if (changed) {
stepUnchanged = 0;
} else {
queue.push(cast);
++stepUnchanged;
}
nextStep();
}
}
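// Direction convention: a "backward" cast produces the new layout and asks
// the producer of its operand to be rewritten; a "forward" cast produces the
// old layout and asks the users of its result to be rewritten.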
CastOp CTAPlanner::markBackward(CastOp cast) const {
cast->setAttr("direction", StringAttr::get(cast.getContext(), "backward"));
return cast;
}
CastOp CTAPlanner::markForward(CastOp cast) const {
cast->setAttr("direction", StringAttr::get(cast.getContext(), "forward"));
return cast;
}
bool CTAPlanner::isBackward(CastOp cast) const {
return cast->getAttrOfType<StringAttr>("direction") == "backward";
}
bool CTAPlanner::isForward(CastOp cast) const {
return cast->getAttrOfType<StringAttr>("direction") == "forward";
}
void CTAPlanner::setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA) {
assert(!tiled && "CTA tiling is already determinted");
assert(clusterInfo && "ClusterInfo pointer is null");
assert(CTAsPerCGA.size() <= 3 && "setTiling not implemented");
if (CTAsPerCGA.size() > 0)
clusterInfo->clusterDimX = CTAsPerCGA[0];
if (CTAsPerCGA.size() > 1)
clusterInfo->clusterDimY = CTAsPerCGA[1];
if (CTAsPerCGA.size() > 2)
clusterInfo->clusterDimZ = CTAsPerCGA[2];
tiled = true;
}
bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
// TODO: This is a naive implementation and should be refactored
auto getCTATiling = [](int64_t M, int64_t N, int64_t K,
unsigned numCTAs) -> std::pair<unsigned, unsigned> {
unsigned splitM = std::clamp<unsigned>(M / 64, 1, numCTAs);
unsigned splitN = numCTAs / splitM;
return {splitM, splitN};
};
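  // Illustrative tilings (assuming numCTAs == 4): M == 256 yields
  // splitM == 4, splitN == 1; M == 64 yields splitM == 1, splitN == 4.
  // Note that K is currently unused.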
funcOp.walk([&](triton::DotOp dot) {
MLIRContext *ctx = dot.getContext();
auto aTy = dot.getA().getType().cast<RankedTensorType>();
auto bTy = dot.getB().getType().cast<RankedTensorType>();
auto dTy = dot.getD().getType().cast<RankedTensorType>();
assert(aTy.getEncoding().isa<ttg::DotOperandEncodingAttr>() &&
bTy.getEncoding().isa<ttg::DotOperandEncodingAttr>() &&
dTy.getEncoding().isa<ttg::BlockedEncodingAttr>() &&
"PlanCTAPass should follow immediately after CoalescePass");
auto aLayout = aTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
auto bLayout = bTy.getEncoding().cast<ttg::DotOperandEncodingAttr>();
auto dLayout = dTy.getEncoding().cast<ttg::BlockedEncodingAttr>();
unsigned M = dTy.getShape()[0];
unsigned N = dTy.getShape()[1];
unsigned K = aTy.getShape()[1];
unsigned splitM, splitN;
std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout));
// FIXME: Should consider IR with more than one DotOps
setTiling({splitM, splitN, 1});
auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN},
{splitM, splitN}, {1, 0});
auto newDLayout = ttg::BlockedEncodingAttr::get(
ctx, dTy.getShape(), dLayout.getSizePerThread(), dLayout.getOrder(),
ttg::getNumWarpsPerCTA(dLayout), 32, newCTALayout);
auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(),
newDLayout, 0);
auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(),
newDLayout, 0);
insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout},
{newDLayout});
});
return true;
}
bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod);
funcOp.walk([&](triton::ReduceOp reduce) {
MLIRContext *context = reduce.getContext();
Value src = reduce.getOperands()[0];
unsigned axis = reduce.getAxis();
auto srcTy = src.getType().cast<RankedTensorType>();
auto srcShape = srcTy.getShape();
auto srcLayout = srcTy.getEncoding();
auto rank = srcShape.size();
auto order = ttg::getOrder(srcLayout);
auto sizePerThread = ttg::getSizePerThread(srcLayout);
auto CTAOrder = ttg::getCTAOrder(srcLayout);
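    // Greedily distribute the CTAs over the non-reduced dimensions, starting
    // from the slowest-varying one, and keep the reduced axis unsplit so that
    // no cross-CTA reduction is required.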
llvm::SmallVector<unsigned> CTAsPerCGA(rank, 0);
unsigned remainingCTAs = numCTAs;
for (int i = rank - 1; i >= 0; --i) {
unsigned dim = order[i];
if (dim == axis) {
CTAsPerCGA[dim] = 1;
} else {
CTAsPerCGA[dim] = std::min<unsigned>(srcShape[dim] / sizePerThread[dim],
remainingCTAs);
remainingCTAs /= CTAsPerCGA[dim];
}
}
    for (int i = rank - 1; i >= 0; --i) {
      unsigned dim = order[i];
      if (dim != axis) {
        // Give any leftover CTAs to the slowest-varying non-reduced
        // dimension, and mark them as consumed so that the duplication
        // fallback below does not apply them a second time.
        CTAsPerCGA[dim] *= remainingCTAs;
        remainingCTAs = 1;
        break;
      }
    }
    llvm::SmallVector<unsigned> CTASplitNum = CTAsPerCGA;
    // If numCTAs > 1 and the only dimension is the reduced dimension, after
    // the above two for-loops, CTAsPerCGA = [1] and remainingCTAs = numCTAs.
    // We set CTAsPerCGA[0] = numCTAs and keep CTASplitNum[0] = 1 to ensure
    // that no cross-CTA reduction is required, although this will introduce
    // duplicated calculation.
    if (remainingCTAs > 1)
      CTAsPerCGA[order[rank - 1]] *= remainingCTAs;
auto CTALayout =
ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
if (!tiled)
setTiling(CTALayout.getCTAsPerCGA());
auto newSrcLayout = replaceCTALayout(srcLayout, srcShape, CTALayout);
auto newResultLayout =
ttg::SliceEncodingAttr::get(context, axis, newSrcLayout);
unsigned numOperands = reduce.getNumOperands();
SmallVector<Attribute> newSrcLayoutVec(numOperands, newSrcLayout);
SmallVector<Attribute> newResultLayoutVec(numOperands, newResultLayout);
insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec);
});
return true;
}
void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
assert(!tiled && "CTA tiling is already determinted");
llvm::SmallVector<Operation *> stores;
funcOp.walk([&](Operation *op) {
if (llvm::isa<triton::StoreOp, triton::AtomicRMWOp, triton::AtomicCASOp>(
op))
stores.push_back(op);
});
assert(stores.size() > 0 && "Cannot find store-like ops");
ttg::CTALayoutAttr CTALayout;
for (Operation *store : stores) {
if (auto tensorTy =
store->getOperand(0).getType().dyn_cast<RankedTensorType>()) {
if (!tiled) {
// Use CTA tiling of the first store-like op as global CTA tiling
CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
setTiling(CTALayout.getCTAsPerCGA());
}
auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
tensorTy.getShape(), CTALayout);
processElementwise(store, newLayout);
}
}
// If all store-like ops are processing scalar values and no ReduceOp is
// found, we can conclude that this is an all-scalar computation, since
// ReduceOp is the only op that converts tensor values to scalar values.
if (!tiled)
setTiling({1, 1, 1});
}
bool CTAPlanner::propagate(CastOp cast) {
return isBackward(cast) ? propagateBackward(cast) : propagateForward(cast);
}
bool CTAPlanner::propagateBackward(CastOp cast) {
Value input = cast.getOperand(0);
Value output = cast.getResult(0);
unsigned numUsers = getNumUsers(input);
if (numUsers == 0) {
assert(0 && "Unreachable branch");
} else if (numUsers == 1) {
Type outTy = output.getType();
if (auto ptrTy = outTy.dyn_cast<triton::PointerType>())
outTy = ptrTy.getPointeeType();
Attribute layout = outTy.cast<RankedTensorType>().getEncoding();
Operation *op = input.getDefiningOp();
if (op == nullptr) {
assert(input.isa<BlockArgument>() &&
"Unexpected Value without defining op");
processBlockArgBackward(input.cast<BlockArgument>(), cast);
} else if (auto prevCast = llvm::dyn_cast<CastOp>(op)) {
eliminateAdjacentCasts(prevCast, cast);
} else if (isLoadStoreOp(op)) {
processLoadStore(op, layout);
} else if (isElementwiseOp(op)) {
processElementwise(op, layout);
} else if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
processConstant(constant, layout);
} else if (auto splat = llvm::dyn_cast<triton::SplatOp>(op)) {
processSplat(splat, layout);
} else if (auto makeRange = llvm::dyn_cast<triton::MakeRangeOp>(op)) {
processMakeRange(makeRange, layout);
} else if (auto makeTensorPtr =
llvm::dyn_cast<triton::MakeTensorPtrOp>(op)) {
processMakeTensorPtr(makeTensorPtr, layout);
} else if (llvm::isa<triton::AdvanceOp>(op)) {
// ptr operand and result have the same layout, while other operands are
// scalar values
processElementwise(op, layout);
} else if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
processBroadcast(broadcast, layout);
} else if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
processExpandDimsBackward(expandDims, layout);
} else if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
processIfOpBackward(ifOp, cast);
} else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
processForOpBackward(forOp, cast);
} else if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op)) {
return processConvertLayoutBackward(convertLayout, cast);
} else {
// Keep original layouts. This may result in a loss of performance.
return processOpFallback(op);
}
return true;
} else {
return processMultiUsersBackward(input, cast);
}
}
bool CTAPlanner::propagateForward(CastOp cast) {
Value input = cast.getOperand(0);
Value output = cast.getResult(0);
unsigned numUsers = getNumUsers(output);
if (numUsers == 0) {
cast.erase();
} else if (numUsers == 1) {
Type inTy = input.getType();
if (auto ptrTy = inTy.dyn_cast<triton::PointerType>())
inTy = ptrTy.getPointeeType();
Attribute layout = inTy.cast<RankedTensorType>().getEncoding();
Operation *op = *output.user_begin();
if (auto nextCast = llvm::dyn_cast<CastOp>(op)) {
eliminateAdjacentCasts(cast, nextCast);
} else if (isLoadStoreOp(op)) {
processLoadStore(op, layout);
} else if (isElementwiseOp(op)) {
processElementwise(op, layout);
} else if (llvm::isa<triton::AdvanceOp>(op)) {
// ptr operand and result have the same layout, while other operands are
// scalar values
processElementwise(op, layout);
} else if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op)) {
return processConvertLayoutForward(convertLayout, cast);
} else if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
processForOpForward(forOp, cast);
} else if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
processYieldOpForward(yieldOp, cast);
} else {
// Keep original layouts. This may result in a loss of performance.
return processOpFallback(op);
}
} else {
processMultiUsersForward(output, cast);
}
return true;
}
void CTAPlanner::eraseCastOp(CastOp cast) {
Value output = cast.getResult(0);
assert(getNumUsers(output) == 0 &&
"Cannot erase CastOp because it is still in use");
cast.erase();
}
void CTAPlanner::eraseCastOpFromQueue(CastOp cast) {
eraseCastOpsFromQueue({cast});
}
void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts) {
llvm::DenseSet<CastOp> erased;
for (CastOp cast : casts) {
eraseCastOp(cast);
erased.insert(cast);
}
decltype(queue) tempQueue;
std::swap(queue, tempQueue);
  // This is only a naive implementation; it should be refactored to use a
  // linked list.
while (!tempQueue.empty()) {
auto cast = tempQueue.front();
tempQueue.pop();
if (!erased.contains(cast))
queue.push(cast);
}
}
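// Rewrites `op` in place: each tensor(-pointer) operand is routed through a
// new backward cast that produces its new layout, and each tensor(-pointer)
// result is retyped to the new layout while its existing users are preserved
// behind a new forward cast back to the old type. All new casts are pushed to
// the worklist queue.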
void CTAPlanner::insertCasts(Operation *op,
llvm::ArrayRef<Attribute> newOperandLayouts,
llvm::ArrayRef<Attribute> newResultLayouts) {
assert(op->getNumOperands() == newOperandLayouts.size() &&
"NumOperands mismatched");
assert(op->getNumResults() == newResultLayouts.size() &&
"NumResults mismatched");
Location loc = op->getLoc();
OpBuilder builder(op->getContext());
builder.setInsertionPoint(op);
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
Value operand = op->getOperand(i);
auto operandTy = operand.getType();
if (triton::isTensorOrTensorPointerType(operandTy)) {
operandTy = replaceLayout(operandTy, newOperandLayouts[i]);
auto cast = markBackward(builder.create<CastOp>(loc, operandTy, operand));
op->setOperand(i, cast.getResult(0));
queue.push(cast);
}
}
builder.setInsertionPointAfter(op);
for (unsigned i = 0; i < op->getNumResults(); ++i) {
Value result = op->getResult(i);
auto resultTy = result.getType();
if (triton::isTensorOrTensorPointerType(resultTy)) {
resultTy = replaceLayout(resultTy, newResultLayouts[i]);
auto cast =
markForward(builder.create<CastOp>(loc, result.getType(), result));
result.setType(resultTy);
result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation());
queue.push(cast);
}
}
}
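// A forward cast feeding a backward cast is where two propagation fronts
// meet. If the two sides agree on the type, the pair simply cancels;
// otherwise it collapses into a single ttg::ConvertLayoutOp, which is where
// any cross-CTA data movement is expected to materialize.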
void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) {
assert(cast0.getResult(0) == cast1.getOperand(0) &&
"The two casts are not adjacent");
assert(isForward(cast0) && isBackward(cast1) &&
"Expected pattern of adjacent casts: forward + backward");
Value input = cast0.getOperand(0);
Value output = cast1.getResult(0);
if (input.getType() == output.getType()) {
output.replaceAllUsesWith(input);
eraseCastOpsFromQueue({cast1, cast0});
} else {
OpBuilder builder(cast1.getOperation());
auto cvt = builder.create<ttg::ConvertLayoutOp>(cast1.getLoc(),
output.getType(), input);
output.replaceAllUsesWith(cvt.getResult());
eraseCastOpsFromQueue({cast1, cast0});
}
}
bool CTAPlanner::isLoadStoreOp(Operation *op) const {
return llvm::isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
triton::AtomicCASOp>(op);
}
bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
// Special logic for:
// LoadOp -> SliceLayout
// Transform to:
// LoadOp -> originalLayout -> ConvertLayout(DSmem) -> SliceLayout
if (auto sliceLayout = layout.dyn_cast<ttg::SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
auto CTAsPerCGA = ttg::getCTAsPerCGA(sliceLayout.getParent());
if (CTAsPerCGA[dim] > 1) {
// Find an input or output value of LoadOp or StoreOp to get its layout
Value val =
op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0);
Attribute originalLayout =
val.getType().cast<RankedTensorType>().getEncoding();
// Insert casts using originalLayout. Adjacent casts will be eliminated
// and generate a ConvertLayoutOp with DSmem access
return processLoadStore(op, originalLayout);
}
}
auto CTALayout = ttg::getCTALayout(layout);
llvm::SmallVector<Attribute> newOperandLayouts;
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
auto type = op->getOperand(i).getType();
if (auto ptrTy = type.dyn_cast<triton::PointerType>())
type = ptrTy.getPointeeType();
auto tensorTy = type.cast<RankedTensorType>();
auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
tensorTy.getShape(), CTALayout);
newOperandLayouts.push_back(newLayout);
}
llvm::SmallVector<Attribute> newResultLayouts;
for (unsigned i = 0; i < op->getNumResults(); ++i) {
auto type = op->getResult(i).getType();
if (auto ptrTy = type.dyn_cast<triton::PointerType>())
type = ptrTy.getPointeeType();
auto tensorTy = type.cast<RankedTensorType>();
auto newLayout = replaceCTALayout(tensorTy.getEncoding(),
tensorTy.getShape(), CTALayout);
newResultLayouts.push_back(newLayout);
}
insertCasts(op, newOperandLayouts, newResultLayouts);
return true;
}
bool CTAPlanner::isElementwiseOp(Operation *op) const {
if (llvm::isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::CeilDivSIOp,
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
arith::UIToFPOp, arith::XOrIOp>(op))
return true;
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
math::CountLeadingZerosOp, math::CountTrailingZerosOp,
math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
math::ExpM1Op, math::FloorOp, math::FmaOp, math::LogOp,
math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
math::RsqrtOp, math::SqrtOp, math::TanhOp>(op))
return true;
if (llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
triton::FpToFpOp, triton::AddPtrOp>(op))
return true;
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
return externElementwiseOp.getPure();
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
return true;
return false;
}
bool CTAPlanner::processElementwise(Operation *op, Attribute layout) {
llvm::SmallVector<Attribute> newOperandLayouts(op->getNumOperands(), layout);
llvm::SmallVector<Attribute> newResultLayouts(op->getNumResults(), layout);
insertCasts(op, newOperandLayouts, newResultLayouts);
return true;
}
bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) {
if (auto tensorTy =
constant.getResult().getType().dyn_cast<RankedTensorType>()) {
if (auto attr = constant.getValue().dyn_cast<SplatElementsAttr>()) {
auto newTensorTy = RankedTensorType::get(
tensorTy.getShape(), tensorTy.getElementType(), layout);
constant.setValueAttr(
SplatElementsAttr::get(newTensorTy, attr.getSplatValue<Attribute>()));
}
}
insertCasts(constant.getOperation(), {}, {layout});
return true;
}
bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) {
insertCasts(splat.getOperation(), {{}}, {layout});
return true;
}
bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange,
Attribute layout) {
insertCasts(makeRange.getOperation(), {}, {layout});
return true;
}
bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
Attribute layout) {
// All inputs of `makeTensorPtr` are scalar types
llvm::SmallVector<Attribute> dummyInAttrs(makeTensorPtr.getNumOperands(), {});
insertCasts(makeTensorPtr.getOperation(), dummyInAttrs, {layout});
return true;
}
bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast,
Attribute layout) {
insertCasts(broadcast.getOperation(), {layout}, {layout});
return true;
}
bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims,
Attribute newResultLayout) {
auto newSrcLayout = ttg::SliceEncodingAttr::get(
newResultLayout.getContext(), expandDims.getAxis(), newResultLayout);
insertCasts(expandDims.getOperation(), {newSrcLayout}, {newResultLayout});
return true;
}
bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims,
Attribute newSrcLayout) {
assert(0 && "processExpandDimsForward not implemented yet");
return true;
}
bool CTAPlanner::processConvertLayoutBackward(
ttg::ConvertLayoutOp convertLayout, CastOp cast) {
Value src = convertLayout.getSrc();
Value result = convertLayout.getResult();
assert(getNumUsers(result) == 1 &&
"Expect to call processMultiUsersBackward first");
result.replaceAllUsesWith(src);
convertLayout.erase();
queue.push(cast);
return true;
}
bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
CastOp cast) {
Value src = convertLayout.getSrc();
Value result = convertLayout.getResult();
assert(getNumUsers(src) == 1 &&
"Expect to call processMultiUsersForward first");
src.setType(result.getType());
result.replaceAllUsesWith(src);
convertLayout.erase();
queue.push(cast);
return true;
}
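// Retypes result `index` of an scf.if: the result itself takes `newType`, a
// forward cast inserted after the op preserves the old type for existing
// users, and backward casts inserted before both yields request the new type
// from the then/else regions.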
bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) {
// Check index
assert(index < ifOp.getNumResults() && "Invalid result index of IfOp");
assert(index < ifOp.thenYield().getNumOperands() &&
"Invalid operand index of YieldOp");
assert(index < ifOp.elseYield().getNumOperands() &&
"Invalid operand index of YieldOp");
Location loc = ifOp.getLoc();
OpBuilder builder(ifOp.getContext());
// Insert forward cast after ifOp
Value result = ifOp.getResult(index);
builder.setInsertionPointAfter(ifOp.getOperation());
auto newCast =
markForward(builder.create<CastOp>(loc, result.getType(), result));
result.setType(newType);
result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
queue.push(newCast);
// Insert backward casts before yield
for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) {
Value yieldSrc = yield.getOperand(index);
builder.setInsertionPoint(yield.getOperation());
newCast = markBackward(builder.create<CastOp>(loc, newType, yieldSrc));
yield->setOperand(index, newCast.getResult(0));
queue.push(newCast);
}
return true;
}
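// Retypes iter-arg `index` of an scf.for. Four casts keep the loop
// consistent: a backward cast on the init operand, a forward cast on the
// block argument, a backward cast on the yielded value, and a forward cast
// on the loop result.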
bool CTAPlanner::processForOp(scf::ForOp forOp, int index,
const Type &newType) {
Block *body = forOp.getBody();
auto yield = llvm::cast<scf::YieldOp>(forOp.getBody()->getTerminator());
// Check index
assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() &&
"Invalid operand index of ForOp");
assert(index + forOp.getNumInductionVars() < body->getNumArguments() &&
"Invalid block arg index of ForOp");
assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp");
assert(index < forOp.getNumResults() && "Invalid result index of IfOp");
Location loc = forOp.getLoc();
OpBuilder builder(forOp.getContext());
// Insert backward cast before forOp
OpOperand &operand =
forOp->getOpOperand(index + forOp.getNumControlOperands());
builder.setInsertionPoint(forOp.getOperation());
auto newCast =
markBackward(builder.create<CastOp>(loc, newType, operand.get()));
operand.set(newCast.getResult(0));
queue.push(newCast);
// Insert forward cast after block arg
Value arg = body->getArgument(index + forOp.getNumInductionVars());
builder.setInsertionPointToStart(body);
newCast = markForward(builder.create<CastOp>(loc, arg.getType(), arg));
arg.setType(newType);
arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
queue.push(newCast);
// Insert backward cast before yield
Value yieldSrc = yield.getOperand(index);
builder.setInsertionPoint(yield.getOperation());
newCast = markBackward(builder.create<CastOp>(loc, newType, yieldSrc));
yield->setOperand(index, newCast.getResult(0));
queue.push(newCast);
// Insert forward cast after forOp
Value result = forOp.getResult(index);
builder.setInsertionPointAfter(forOp.getOperation());
newCast = markForward(builder.create<CastOp>(loc, result.getType(), result));
result.setType(newType);
result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
queue.push(newCast);
return true;
}
int findResultIndex(Operation *op, Value result) {
for (int i = 0; i < op->getNumResults(); ++i)
if (op->getResult(i) == result)
return i;
assert(0 && "Invalid index of op result");
return -1;
}
bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) {
int index = findResultIndex(ifOp.getOperation(), cast.getOperand(0));
auto newType = cast.getResult(0).getType();
return processIfOp(ifOp, index, newType);
}
bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) {
int index = findResultIndex(forOp.getOperation(), cast.getOperand(0));
auto newType = cast.getResult(0).getType();
return processForOp(forOp, index, newType);
}
bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
if (auto forOp = llvm::dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp())) {
int index = int(arg.getArgNumber()) - forOp.getNumInductionVars();
auto newType = cast.getResult(0).getType();
return processForOp(forOp, index, newType);
} else {
assert(0 && "Unexpected parent op of block argument");
return true;
}
}
bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) {
int index = cast.getResult(0).use_begin()->getOperandNumber() -
forOp.getNumControlOperands();
auto newType = cast.getOperand(0).getType();
return processForOp(forOp, index, newType);
}
bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
int index = cast.getResult(0).use_begin()->getOperandNumber();
auto newType = cast.getOperand(0).getType();
if (auto ifOp = llvm::dyn_cast<scf::IfOp>(yieldOp->getParentOp()))
return processIfOp(ifOp, index, newType);
else if (auto forOp = llvm::dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
return processForOp(forOp, index, newType);
else
assert(0 && "Unexpected parent op of YieldOp");
return true;
}
bool CTAPlanner::processOpFallback(Operation *op) {
Location loc = op->getLoc();
OpBuilder builder(op->getContext());
builder.setInsertionPoint(op);
for (unsigned i = 0; i < op->getNumOperands(); ++i) {
Value operand = op->getOperand(i);
auto operandTy = operand.getType();
if (triton::isTensorOrTensorPointerType(operandTy)) {
auto cast = markBackward(builder.create<CastOp>(loc, operandTy, operand));
op->setOperand(i, cast.getResult(0));
queue.push(cast);
}
}
builder.setInsertionPointAfter(op);
for (unsigned i = 0; i < op->getNumResults(); ++i) {
Value result = op->getResult(i);
auto resultTy = result.getType();
if (triton::isTensorOrTensorPointerType(resultTy)) {
auto cast = markForward(builder.create<CastOp>(loc, resultTy, result));
result.replaceAllUsesExcept(cast.getResult(0), cast.getOperation());
queue.push(cast);
}
}
return true;
}
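// Resolves a value whose users request conflicting layouts. Non-cast users
// are first wrapped in a backward+forward cast pair, then the backward casts
// are grouped by requested type; the defining op is cloned once per
// additional type so that each group ends up with its own copy behind a
// single backward cast.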
bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
Location loc = input.getLoc();
OpBuilder builder(input.getContext());
llvm::DenseMap<Type, llvm::SmallVector<CastOp>> typeToIndices;
for (OpOperand &operand : input.getUses()) {
auto brotherCast = llvm::dyn_cast<CastOp>(operand.getOwner());
if (!brotherCast) {
if (stepUnchanged <= queue.size())
return false;
builder.setInsertionPoint(operand.getOwner());
brotherCast = markBackward(
builder.create<CastOp>(loc, cast.getResult(0).getType(), input));
auto newCast = markForward(builder.create<CastOp>(
loc, input.getType(), brotherCast.getResult(0)));
operand.set(newCast.getResult(0));
queue.push(brotherCast);
queue.push(newCast);
}
auto type = brotherCast.getResult(0).getType();
typeToIndices[type].push_back(brotherCast);
}
bool first = true;
  for (auto &it : typeToIndices) {
    Type type = it.first;
    llvm::SmallVector<CastOp> &casts = it.second;
Value newInput = input;
if (!first) {
if (Operation *defOp = input.getDefiningOp()) {
builder.setInsertionPointAfter(defOp);
Operation *clonedOp = builder.clone(*defOp);
newInput = clonedOp->getResult(0);
} else {
assert(0 && "Layout conflict for block arg"); // TODO
}
}
first = false;
if (Operation *defOp = newInput.getDefiningOp()) {
builder.setInsertionPointAfter(defOp);
} else {
assert(newInput.isa<BlockArgument>() &&
"Unexpected Value without defining op");
builder.setInsertionPointToStart(
newInput.cast<BlockArgument>().getOwner());
}
auto newCast = markBackward(builder.create<CastOp>(loc, type, newInput));
queue.push(newCast);
auto newResult = newCast.getResult(0);
for (CastOp &brotherCast : casts) {
brotherCast.getResult(0).replaceAllUsesWith(newResult);
eraseCastOpFromQueue(brotherCast);
}
}
return true;
}
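// Splits a forward cast that has multiple users into one single-user forward
// cast per use so that each user can be processed independently.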
bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) {
Value castSrc = cast.getOperand(0);
Location loc = cast.getLoc();
OpBuilder builder(cast.getContext());
builder.setInsertionPointAfter(cast.getOperation());
while (!castResult.use_empty()) {
auto newCast =
markForward(builder.create<CastOp>(loc, castResult.getType(), castSrc));
castResult.use_begin()->set(newCast.getResult(0));
queue.push(newCast);
}
eraseCastOp(cast);
return true;
}
struct PlanCTAPass : public TritonGPUPlanCTAPassBase<PlanCTAPass> {
PlanCTAPass(ttng::ClusterInfo *clusterInfo_ = nullptr)
: clusterInfo(clusterInfo_) {}
void runOnOperation() override {
ModuleOp mod = getOperation();
// Skip PlanCTAPass when numCTAs == 1
if (ttg::TritonGPUDialect::getNumCTAs(mod) == 1)
return;
mod.walk([&](triton::FuncOp funcOp) {
CTAPlanner planner(clusterInfo);
planner.run(funcOp);
// FIXME: Clone funcOp so that the IR change can be identified after
// PlanCTAPass. Without this, the change after PlanCTAPass will not be
// displayed when MLIR_ENABLE_DUMP=1. This is not reasonable and should
// be fixed later.
OpBuilder builder(funcOp);
builder.clone(*funcOp.getOperation());
funcOp.erase();
});
}
ttng::ClusterInfo *clusterInfo;
};
} // namespace
std::unique_ptr<Pass>
mlir::createTritonNvidiaGPUPlanCTAPass(ttng::ClusterInfo *clusterInfo) {
return std::make_unique<PlanCTAPass>(clusterInfo);
}
/* TODO
* - Use ConvertLayoutOp instead of UnrealizedConversionCastOp.
* - Move PlanCTAPass to the front of CoalescePass.
* - Design better tiling strategy for DotOp and ReduceOp.
* - Consider cases where there are more than one DotOps.
* - Use better data structure for erasing CastOps from queue (linked list?).
* - Process eliminable CastOps in higher priority.
* - Fix the clone func bug in PlanCTAPass::runOnOperation.
* - Add some comments to introduce the overall idea of this pass.
* - Add some lit tests for this pass.
*/