Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117

Conflicts:
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Dialect/TritonGPU/IR/Dialect.cpp
	python/setup.py
	python/test/unit/language/assert_helper.py
	python/test/unit/operators/test_flash_attention.py
	python/test/unit/runtime/test_subproc.py
	python/triton/compiler/compiler.py
	python/triton/language/semantic.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/tutorials/03-matrix-multiplication.py
	python/tutorials/05-layer-norm.py
	python/tutorials/06-fused-attention.py
	python/tutorials/11-grouped-gemm.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-17 20:42:12 +00:00
179 changed files with 10116 additions and 6835 deletions

View File

@@ -52,7 +52,7 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return dotLayout.getTotalElemsPerThread(shape, eltTy);
} else {
assert(0 && "getElemsPerThread not implemented");
llvm::report_fatal_error("getElemsPerThread not implemented");
return 0;
}
}
@@ -68,7 +68,7 @@ SmallVector<unsigned> getElemsPerThread(Attribute layout,
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return mfmaLayout.getElemsPerThread(shape, eltTy);
} else {
assert(0 && "getElemsPerThread not implemented");
llvm::report_fatal_error("getElemsPerThread not implemented");
return SmallVector<unsigned>();
}
}
@@ -129,7 +129,7 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
return threadsPerWarp;
}
assert(0 && "getThreadsPerWarp not implemented");
llvm::report_fatal_error("getThreadsPerWarp not implemented");
return {};
}
@@ -180,15 +180,17 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
assert(parentWarpsPerCTA.size() == 2 &&
"getWarpsPerCTA only implemented for 2D slice layout");
assert(parentWarpsPerCTA.size() == 2 ||
parentWarpsPerCTA[sliceLayout.getDim()] == 1 &&
"getWarpsPerCTA only implemented for 2D slice layout or the "
"slice dim must have 1 warp in the parent layout");
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
return warpsPerCTA;
}
assert(0 && "getWarpsPerCTA not implemented");
llvm::report_fatal_error("getWarpsPerCTA not implemented");
return {};
}
@@ -264,7 +266,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
} else if (opIdx == 1) {
return {4, 1};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
return {};
}
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
@@ -278,12 +280,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
return {};
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
llvm::report_fatal_error(
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
return {};
}
} else {
assert(0 && "getSizePerThread not implemented");
llvm::report_fatal_error("getSizePerThread not implemented");
return {};
}
}
@@ -337,6 +340,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
4 * mmaLayout.getWarpsPerCTA()[1]};
} else
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (mfmaLayout.getNonKDim() == 32) {
@@ -346,8 +350,11 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
4 * mfmaLayout.getWarpsPerCTA()[1]};
}
} else {
assert(0 && "Unimplemented usage of getThreadsPerCTA");
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
}
return threads;
@@ -381,11 +388,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
return {16 * mmaLayout.getWarpsPerCTA()[0],
instrShape[1] * mmaLayout.getWarpsPerCTA()[1]};
}
llvm::report_fatal_error("Unexpected MMA layout version found");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
auto nonKDim = mfmaLayout.getNonKDim();
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -401,7 +412,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
} else if (opIdx == 1) {
return {16, parentShapePerCTATile[1]};
} else {
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
}
} else if (auto parentMfmaLayout =
parentLayout.dyn_cast<MfmaEncodingAttr>()) {
@@ -416,15 +427,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
}
} else {
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
llvm::report_fatal_error(
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
"supported yet");
}
} else {
assert(0 && "Unimplemented usage of getShapePerCTATile");
llvm::report_fatal_error("Unimplemented usage of getShapePerCTATile");
}
return shape;
}
bool isExpensiveView(Type srcType, Type dstType) {
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
}
namespace {
/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
@@ -473,7 +489,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
sharedLayout.getOrder().end());
} else {
assert(0 && "Unimplemented usage of getOrder");
llvm::report_fatal_error("Unimplemented usage of getOrder");
}
return {};
};
@@ -494,7 +510,7 @@ CTALayoutAttr getCTALayout(Attribute layout) {
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
return sharedLayout.getCTALayout();
else
assert(0 && "Unimplemented usage of getCTALayout");
llvm::report_fatal_error("Unimplemented usage of getCTALayout");
return {};
}
@@ -522,7 +538,8 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
* in the branch where layout is an instance of SliceEncodingAttr. This is
* inconvenient but safe.
*/
assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
llvm::report_fatal_error(
"getCTAsPerCGA for SliceEncodingAttr is not well-defined");
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
#ifdef USE_ROCM
@@ -534,7 +551,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
ref = sharedLayout.getCTALayout().getCTAsPerCGA();
else
assert(0 && "Unimplemented usage of getCTAsPerCGA");
llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA");
return SmallVector<unsigned>(ref.begin(), ref.end());
}
@@ -589,7 +606,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
ref = sharedLayout.getCTALayout().getCTAOrder();
} else {
assert(0 && "Unimplemented usage of getCTAOrder");
llvm::report_fatal_error("Unimplemented usage of getCTAOrder");
}
return SmallVector<unsigned>(ref.begin(), ref.end());
}
@@ -642,9 +659,9 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getNumWarpsPerCTA(dotLayout.getParent());
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
assert(0 && "Cannot get numWarps from SharedEncodingAttr");
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
else
assert(0 && "Unimplemented usage of getNumWarpsPerCTA");
llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA");
return product<unsigned>(warpsPerCTA);
}
@@ -665,7 +682,7 @@ unsigned getNumCTAs(Attribute layout) {
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
else
assert(0 && "Unimplemented usage of getNumCTAs");
llvm::report_fatal_error("Unimplemented usage of getNumCTAs");
return product<unsigned>(CTAsPerCGA);
}
@@ -1779,13 +1796,15 @@ struct CanonicalizeConvertFromView
Operation *arg = op->getOperand(0).getDefiningOp();
if (!arg)
return mlir::failure();
auto convert = dyn_cast<ConvertLayoutOp>(arg);
if (!convert)
return failure();
if (isExpensiveView(convert.getOperand().getType(), op.getType()))
return failure();
// view(convert) -> view
if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), convert.getOperand());
return mlir::success();
}
return mlir::failure();
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
convert.getOperand());
return mlir::success();
}
};
@@ -1831,6 +1850,8 @@ struct CanonicalizeConvertFromConvert
return mlir::failure();
// cvt(view) -> view
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
if (isExpensiveView(view.getOperand().getType(), op.getType()))
return failure();
rewriter.replaceOpWithNewOp<triton::ViewOp>(
op, op->getResult(0).getType(), view.getResult());
return mlir::success();

View File

@@ -70,10 +70,15 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, {filter});
auto slices = multiRootGetSlice(dotOp, {filter});
for (Operation *op : slices)
if (isa<tt::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
if (isa<tt::DotOp>(op) && (op != dotOp)) {
if (shape[0] >= shape[1]) {
return {(unsigned)numWarps, 1};
} else {
return {1, (unsigned)numWarps};
}
}
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
@@ -133,8 +138,18 @@ class BlockedToMMA : public mlir::RewritePattern {
mlir::TypeID::get<arith::ArithDialect>());
}
// finds the first different value bitwidth in the chain of
// shape-preserving unary ops that x depends on
// Finds the first different bitwidth in the chain of shape-preserving
// unary ops that x depends on.
// There are two primary scenarios:
// (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic
// operations, then bitcasting to fp32, and finally computing in fp32.
// (2) Downcasting: This might involve loading an fp32, performing arithmetic
// operations, bitcasting to fp16, and finally computing in fp16.
// In the upcasting scenario, element reordering converts the original
// elements distribution to the order of higher precision primitives. As a
// result, kwidth can be the bitwidth of the lower precision primitive.
// Conversely, in the downcasting scenario, no reordering is performed,
// making it directly use the lower precision primitive.
static int computeOrigBitWidth(Value x) {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
@@ -143,11 +158,17 @@ class BlockedToMMA : public mlir::RewritePattern {
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
if (RankedTensorType argTy = arg.getType().dyn_cast<RankedTensorType>())
origBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
for (auto op : slice) {
if (Value arg = op->getOperand(0))
if (RankedTensorType argTy =
arg.getType().dyn_cast<RankedTensorType>()) {
auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
if (argBitWidth != origBitWidth) {
origBitWidth = std::min<int>(origBitWidth, argBitWidth);
break;
}
}
}
return origBitWidth;
}

View File

@@ -5,7 +5,10 @@ add_mlir_dialect_library(TritonGPUTransforms
DecomposeConversions.cpp
OptimizeDotOperands.cpp
OptimizeEpilogue.cpp
Pipeline.cpp
OptimizeThreadLocality.cpp
Pipeliner/MatmulLoopPipeline.cpp
Pipeliner/PipelineExpander.cpp
Pipeliner/SoftwarePipeliner.cpp
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp

View File

@@ -0,0 +1,312 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include <memory>
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
using namespace mlir;
// Pass that rewrites a reduction accumulated inside an scf.for loop so that
// the per-thread part of the reduction stays thread-local during the loop,
// and the cross-thread combine plus the layout conversion are hoisted to
// after the loop. It does this by reshaping the reduce source to a
// rank+1 tensor (extra innermost "thread locality" dimension), reducing over
// that new axis inside the loop, and performing the original-axis reduction
// once on the loop result.
class TritonGPUOptimizeThreadLocalityPass
    : public TritonGPUOptimizeThreadLocalityBase<
          TritonGPUOptimizeThreadLocalityPass> {
  void runOnOperation() override {
    ModuleOp mod = getOperation();
    DenseSet<triton::ReduceOp> reduceOps;
    // Phase 1: collect reduce ops that match the (deliberately narrow)
    // pattern this optimization knows how to rewrite.
    mod.walk([&](triton::ReduceOp reduce) -> void {
      auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
      auto rank = srcType.getShape().size();
      auto srcEncoding = srcType.getEncoding();
      // Only single-op reduction bodies with a known neutral element
      // (add/max/min/mul on floats) are supported.
      auto reductionOp = getReductionOp(reduce);
      if (!reductionOp ||
          !isa<arith::AddFOp, arith::MaximumFOp, arith::MinimumFOp,
               arith::MulFOp>(reductionOp.value()))
        return;
      // TODO: relax this restriction
      if (!(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() && rank > 1))
        return;
      // Every reduce operand must come straight from a load; other producers
      // are not handled.
      for (auto operand : reduce->getOperands()) {
        auto def = operand.getDefiningOp();
        if (!isa<triton::LoadOp>(def))
          return;
      }
      auto elemsPerThread =
          triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
      // Not worth applying this optimization if there is only one element per
      // thread on the reduction axis
      if (elemsPerThread == 1)
        return;
      // The reduce result must feed exactly one accumulator-update op, whose
      // single use must be the loop's scf.yield.
      if (!reduce->hasOneUse())
        return;
      Operation *user = *(reduce->getUsers().begin());
      if (!user->hasOneUse())
        return;
      OpOperand &yieldOpOperand = *(user->getUses().begin());
      auto yieldOp = dyn_cast<scf::YieldOp>(yieldOpOperand.getOwner());
      if (!yieldOp)
        return;
      auto operandNumber = yieldOpOperand.getOperandNumber();
      // The reduce must live directly in the body of an scf.for.
      Block *block = reduce->getBlock();
      Operation *parentOp = block->getParentOp();
      auto forOp = dyn_cast<scf::ForOp>(parentOp);
      if (!forOp)
        return;
      // The accumulator's loop-carried init value must be a constant so we
      // can re-materialize a neutral-element init for the new accumulator.
      auto argNum = yieldOpOperand.getOperandNumber();
      auto oldAccum = forOp.getInitArgs()[argNum];
      auto cstOp = dyn_cast<arith::ConstantOp>(oldAccum.getDefiningOp());
      if (!cstOp)
        return;
      reduceOps.insert(reduce);
    });
    // Phase 2: rewrite each collected reduce.
    for (auto reduce : reduceOps) {
      OpBuilder builder(reduce);
      auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
      auto srcShape = srcType.getShape();
      auto srcEncoding = srcType.getEncoding();
      assert(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() &&
             "Thread locality optimization only supports blocked encoding");
      auto blocked = srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>();
      auto elemsPerThread =
          triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
      auto rank = srcShape.size();
      // create new layouts: a 3D blocked layout that isolates the per-thread
      // elements on a new innermost axis, and its slice for the loop-carried
      // accumulator.
      auto blocked3d = getThreadLocalityOptimizedEncoding(reduce);
      auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce);
      auto viewOpTensorType = RankedTensorType::get(
          viewOpTensorShape, srcType.getElementType(), blocked3d);
      auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank,
                                                         blocked3d);
      // Get forOp by walking from the reduce's single use (the accumulator
      // update) to the block argument it combines with.
      assert(reduce->hasOneUse());
      OpOperand &use = *(reduce->getUses().begin());
      auto operandNumber = use.getOperandNumber();
      auto oldUpdate = use.getOwner();
      assert(oldUpdate->getNumOperands() == 2);
      // The other operand of the update op is the loop-carried accumulator.
      auto accumOperandNumber = (operandNumber == 0) ? 1 : 0;
      auto accumOperand = oldUpdate->getOperand(accumOperandNumber);
      assert(accumOperand.isa<BlockArgument>());
      auto blockArg = accumOperand.dyn_cast<BlockArgument>();
      auto blockArgNum = blockArg.getArgNumber();
      auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
      // get oldAccum (the init value fed into the loop for this accumulator)
      auto oldAccum =
          forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()];
      // get old loop user (single consumer of the loop result, per phase 1)
      Value loopResult =
          forOp.getResult(blockArgNum - forOp.getNumInductionVars());
      assert(loopResult.hasOneUse());
      OpOperand &loopUse = *(loopResult.getUses().begin());
      Operation *loopUser = loopUse.getOwner();
      // get old loop yield
      auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      // create newAccum initialization (neutral element, sliced 3D layout)
      auto newAccum =
          createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d);
      // create new loop by copying the old for op signature and appending
      // newAccum to the block arguments
      auto newLoop = replaceForOpWithNewSignature(
          builder, forOp, ValueRange{newAccum->getResult(0)});
      // create thread local reduction (also adds viewOps)
      auto newReduce = createReduce(builder, reduce, viewOpTensorType);
      // create new accum update
      auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate);
      // create new yield
      auto newYield = createYield(builder, newLoop, oldYield,
                                  newUpdate->getResult(0), blockArgNum);
      // create post loop reduction on the original reduce axis
      auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce);
      // add convert_layout to get back to original layout, the result layout
      // should now match the layout of the old accumulator (%cst)
      Type destType = loopResult.getType();
      auto cvtLayout = createConvertLayout(builder, destType, newReduce2);
      // incorporate the original accumulator value into the final result
      auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate,
                                                         cvtLayout, oldAccum);
      // Replace the old loop user with the final result
      loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0));
      // cleanup
      oldYield.erase();
      forOp.erase();
    }
  };

private:
  // Returns the single op inside the reduce's combine region, or nullopt if
  // the region is not exactly one block containing exactly one op plus the
  // terminator.
  std::optional<Operation *> getReductionOp(triton::ReduceOp reduce) const {
    auto numRegions = reduce->getNumRegions();
    if (numRegions != 1)
      return std::nullopt;
    Region &region = reduce->getRegion(0);
    auto numBlocks = region.getBlocks().size();
    if (numBlocks != 1)
      return std::nullopt;
    Block &block = region.front();
    auto blockWithoutTerminator = block.without_terminator();
    auto blockSizeWithoutTerminator = std::distance(
        blockWithoutTerminator.begin(), blockWithoutTerminator.end());
    if (blockSizeWithoutTerminator != 1)
      return std::nullopt;
    Operation *op = &block.front();
    return std::optional<Operation *>(op);
  }

  // Clones the original accumulator-update op so the (constant) original init
  // value is combined into the post-loop result; keeps semantics when the old
  // init was not the neutral element.
  Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder,
                                                 Operation *oldUpdate,
                                                 Operation *cvtLayout,
                                                 Value oldAccum) const {
    builder.setInsertionPointAfter(cvtLayout);
    IRMapping mapping;
    mapping.map(oldUpdate->getOperand(0), oldAccum);
    mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0));
    auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping);
    return finalOp;
  }

  // Emits a convert_layout from the post-loop reduce result back to the
  // layout callers of the loop expect.
  Operation *createConvertLayout(OpBuilder &builder, Type destType,
                                 Operation *newReduce) const {
    builder.setInsertionPointAfter(newReduce);
    auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
        newReduce->getLoc(), destType, newReduce->getResult(0));
    return newCvt;
  }

  // Clones the original reduce after the loop, fed by the new loop result
  // (the thread-local partial accumulator appended as the last iter arg).
  Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop,
                                  triton::ReduceOp &reduce) const {
    auto resultIndex =
        loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars();
    auto newLoopResult = loop.getResult(resultIndex);
    builder.setInsertionPointAfter(loop);
    IRMapping mapping;
    mapping.map(*(reduce.getOperands().begin()), newLoopResult);
    auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping);
    return newReduce2;
  }

  // Rebuilds the loop terminator: the old accumulator slot now just forwards
  // its block argument (it is combined after the loop instead), and the new
  // thread-local update is yielded as the last operand.
  Operation *createYield(OpBuilder &builder, scf::ForOp &loop,
                         scf::YieldOp &oldYield, Value newUpdate,
                         int oldAccumBlockArgNum) const {
    builder.setInsertionPoint(oldYield);
    SmallVector<Value> yieldValues = llvm::to_vector(oldYield.getOperands());
    // -1 converts a block-argument index (which counts the induction
    // variable) into a yield-operand index.
    yieldValues[oldAccumBlockArgNum - 1] =
        loop.getBody()->getArgument(oldAccumBlockArgNum);
    yieldValues.push_back(newUpdate);
    auto newYield =
        builder.create<scf::YieldOp>(oldYield.getLoc(), yieldValues);
    return newYield;
  }

  // Clones the accumulator-update op so it combines the new (last) block
  // argument with the thread-local reduce result.
  Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop,
                          Operation *newReduce, Operation *oldUpdate) const {
    auto blockArgNum = loop.getBody()->getNumArguments() - 1;
    auto newArg = loop.getBody()->getArgument(blockArgNum);
    builder.setInsertionPointAfter(newReduce);
    IRMapping mapping;
    mapping.map(oldUpdate->getOperand(0), newArg);
    mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0));
    auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping);
    return newUpdate;
  }

  // Clones the reduce on top of rank+1 view ops of its operands, reducing
  // over the new innermost axis (== rank); re-runs type inference so the
  // result type matches the new operands.
  Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce,
                          Type viewOpTensorType) const {
    auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
    auto rank = srcType.getShape().size();
    builder.setInsertionPointAfter(reduce);
    IRMapping mapping;
    for (auto operand : reduce.getOperands()) {
      auto viewOp = builder.create<triton::ViewOp>(reduce.getLoc(),
                                                   viewOpTensorType, operand);
      mapping.map(operand, viewOp);
    }
    auto newReduce = cloneWithInferType(builder, &(*reduce), mapping);
    // Redirect the reduction to the appended thread-locality axis.
    newReduce->setAttr("axis", builder.getI32IntegerAttr(rank));
    auto typeInfer = dyn_cast<InferTypeOpInterface>(newReduce);
    if (typeInfer) {
      SmallVector<Type, 1> newTypes;
      auto success = typeInfer.inferReturnTypes(
          newReduce->getContext(), newReduce->getLoc(),
          newReduce->getOperands(), newReduce->getAttrDictionary(),
          newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes);
      if (succeeded(success)) {
        for (size_t i = 0; i < newTypes.size(); i++)
          newReduce->getResult(i).setType(newTypes[i]);
      }
    }
    return newReduce;
  }

  // Builds the init constant for the new thread-local accumulator, filled
  // with the reduction's neutral element in the sliced 3D layout.
  Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce,
                         Value &oldAccum, SmallVector<int64_t> &shape,
                         Attribute &slice2d) const {
    // Drop the last dimension (thread locality dimension)
    SmallVector<int64_t> accumShape(shape.begin(), shape.end() - 1);
    auto elemType =
        oldAccum.getType().cast<RankedTensorType>().getElementType();
    // Create tensor type for the new accumulator
    auto accumType = RankedTensorType::get(accumShape, elemType, slice2d);
    // Create new accumulator
    builder.setInsertionPointAfter(oldAccum.getDefiningOp());
    auto reductionOp = getReductionOp(reduce);
    assert(reductionOp && "Processing a reduce that is not supported!");
    auto neutralVal = mlir::arith::getNeutralElement(reductionOp.value());
    assert(neutralVal && "Could not find neutral value for reduction op!");
    auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value());
    auto newAccum = builder.create<arith::ConstantOp>(oldAccum.getLoc(),
                                                      accumType, denseAttr);
    return newAccum;
  }

  // Shape of the rank+1 view: the reduce axis is split into
  // (axis / elemsPerThread) x elemsPerThread, the latter appended last.
  SmallVector<int64_t>
  getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const {
    auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
    auto srcShape = srcType.getShape();
    auto rank = srcShape.size();
    auto elemsPerThread =
        triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
    auto viewOpTensorShape = insertValue(srcShape, rank, 1);
    viewOpTensorShape[reduce.getAxis()] /= elemsPerThread;
    viewOpTensorShape[rank] = elemsPerThread;
    return viewOpTensorShape;
  }

  // Blocked encoding matching the shape above: all per-thread elements of the
  // reduce axis move to the appended axis; threads/warps/CTAs get size 1
  // there, and the new axis is made fastest-varying in the order.
  Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
    auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
    auto rank = srcType.getShape().size();
    auto srcEncoding = srcType.getEncoding();
    auto blocked = srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>();
    auto sizePerThread3d =
        insertValue(blocked.getSizePerThread(), rank,
                    blocked.getSizePerThread()[reduce.getAxis()]);
    sizePerThread3d[reduce.getAxis()] = 1;
    auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1);
    auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1);
    auto order3d = insertValue(blocked.getOrder(), 0, rank);
    auto ctasPerCGA3d =
        insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1);
    auto ctasSplitNum3d =
        insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1);
    auto ctaOrder3d =
        insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank);
    auto ctaLayout3d = triton::gpu::CTALayoutAttr::get(
        reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d);
    auto blocked3d = triton::gpu::BlockedEncodingAttr::get(
        reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d,
        order3d, ctaLayout3d);
    return blocked3d;
  }

  // Returns a copy of `vec` with `value` inserted at `index`.
  template <typename T>
  SmallVector<T> insertValue(ArrayRef<T> vec, unsigned index, int value) const {
    SmallVector<T> res(vec.begin(), vec.end());
    res.insert(res.begin() + index, static_cast<T>(value));
    return res;
  }
};
/// Factory for the thread-locality optimization pass (registered via the
/// TritonGPU Passes table).
std::unique_ptr<Pass> mlir::createTritonGPUOptimizeThreadLocalityPass() {
  auto pass = std::make_unique<TritonGPUOptimizeThreadLocalityPass>();
  return pass;
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,814 @@
#include "PipelineExpander.h"
#include "Schedule.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/Support/Debug.h"
#define int_attr(num) builder.getI64IntegerAttr(num)
using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttng = mlir::triton::nvidia_gpu;
// TODO: We can extract some helpers into common utilities once we add more
// schedules.
/// Replace the yield with a new one with the given operands appended.
/// Replace the terminator of `forOp`'s body with a new scf.yield carrying the
/// original operands followed by `newOperands`.
static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
  Operation *oldYield = forOp.getBody()->getTerminator();
  // Start from the existing yield operands and extend them in place.
  SmallVector<Value> extendedOperands;
  for (Value v : oldYield->getOperands())
    extendedOperands.push_back(v);
  for (Value v : newOperands)
    extendedOperands.push_back(v);
  OpBuilder rewriter(oldYield);
  rewriter.create<scf::YieldOp>(oldYield->getLoc(), extendedOperands);
  oldYield->erase();
}
/// Lower a pipelined tt.load into an async shared-memory copy: an
/// insert_slice_async into buffer `alloc` at `insertIdx`, an
/// async_commit_group, then an extract_slice at `extractIdx` feeding the
/// load's convert_layout user. The insert result is appended to the loop
/// yield so the buffer is carried across iterations.
/// NOTE(review): assumes the load has exactly one user and that user is a
/// convert_layout — callers are expected to have checked this (see
/// loadDotOperand).
static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx) {
  OpBuilder builder(forOp);
  // Replace the load with insert/extract slice.
  builder.setInsertionPoint(loadOp);
  Location loc = loadOp.getLoc();
  auto insertOp = builder.create<ttg::InsertSliceAsyncOp>(
      loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(),
      loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(),
      loadOp.getIsVolatile(), /*axis*/ 0);
  // Created only for its side effect; the result is intentionally unused
  // (the original bound it to a misspelled, never-read local `commmit`).
  builder.create<ttg::AsyncCommitGroupOp>(loc);
  // Extract part: slice one buffer (drop the leading buffer dimension).
  auto allocType = alloc.getType().cast<RankedTensorType>();
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<ttg::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
  // Re-create the convert_layout on the extracted slice and retire the old
  // convert + load.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();
  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
/// Lower a pipelined tensor-pointer tt.load into a TMA (Hopper) copy guarded
/// by mbarriers: allocate an array of mbarriers (one per buffer) outside the
/// loop, have thread 0 arrive with the expected byte count, issue the
/// insert_slice_async_v2, then wait on the barrier at `extractIdx`/`phase`
/// before the extract_slice that feeds the load's convert_layout user.
static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                          Value insertIdx, Value extractIdx, Value phase) {
  OpBuilder builder(forOp);
  Location loc = loadOp.getLoc();
  // The mbarrier array lives in shared memory with a trivial 1D layout.
  auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(),
                                           /*CTAsPerCGA*/ {1},
                                           /*CTASplitNum*/ {1},
                                           /*CTAOrder*/ {0});
  auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1,
                                                     1, {0}, CTALayout, false);
  // One barrier per pipeline buffer (leading dim of `alloc`).
  int64_t numBuffers = alloc.getType().cast<RankedTensorType>().getShape()[0];
  auto mBarriersTy = RankedTensorType::get(
      {numBuffers}, builder.getIntegerType(64), sharedEncoding);
  // Allocate an array of mbarrier objects outside the loop.
  Value barrierArray =
      builder.create<ttng::AllocMBarrierOp>(loc, mBarriersTy, 1);
  // Extract the barrier and emit the arrive/copy/wait/extract code sequence.
  builder.setInsertionPoint(loadOp);
  auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3);
  Value barrier = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, insertIdx);
  // Only thread 0 arrives on the barrier (pred = threadId == 0).
  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  Value threadId = builder.create<ttng::GetThreadIdOp>(loc);
  Value pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             threadId, zero);
  // Compute the number of bytes the TMA transfer will deposit, per CTA slice;
  // the barrier tracks transaction-byte completion.
  auto loadTy = loadOp.getType().dyn_cast<RankedTensorType>();
  auto loadShape = loadTy.getShape();
  auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding());
  auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape);
  auto elemTy = loadTy.getElementType();
  unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(),
                                   1, std::multiplies{});
  elems *= (elemTy.getIntOrFloatBitWidth() / 8);
  builder.create<ttng::MBarrierArriveOp>(loc, barrier, pred,
                                         /*remoteCtaId*/ nullptr,
                                         /*trackAsyncOp*/ false, elems);
  auto allocType = alloc.getType().cast<RankedTensorType>();
  auto insertOp = builder.create<ttng::InsertSliceAsyncV2Op>(
      loc, allocType, loadOp.getPtr(), alloc,
      /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(),
      loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(),
      /*axis*/ 0);
  // Slice one buffer out of the multi-buffer allocation.
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<mlir::triton::gpu::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
  // Wait for the data of the buffer being extracted before consuming it.
  Value barrierWait = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, extractIdx);
  builder.create<ttng::MBarrierWaitOp>(loc, barrierWait, phase);
  // Re-create the convert_layout on the extracted slice; retire the old
  // convert + load. Assumes the load's single user is a convert_layout.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();
  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
/// Create an async load equivalent to the given load.
/// Create an async load equivalent to the given load. Tensor-pointer loads
/// take the TMA/mbarrier path; all other loads use the cp.async-based copy,
/// which does not need `phase`.
static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx, Value phase) {
  if (!isLoadFromTensorPtr(loadOp)) {
    createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx);
    return;
  }
  createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase);
}
// Return the transitive use of the load which is a dot operand.
static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
// We only pipeline loads that have one covert_layout (to dot_op) use
// TODO: lift this constraint in the future
bool isCandidate = false;
if (!loadOp.getResult().hasOneUse())
return Value();
Operation *use = *loadOp.getResult().getUsers().begin();
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
auto tensorType =
convertLayout.getResult().getType().cast<RankedTensorType>();
if (auto sharedEnc =
tensorType.getEncoding().dyn_cast<ttg::SharedEncodingAttr>()) {
if (sharedEnc.getHasLeadingOffset()) {
// MMA V3 case.
auto newOrder = sharedEnc.getOrder();
auto ty = loadOp.getType().cast<RankedTensorType>();
auto oldOrder = ttg::getOrder(ty.getEncoding());
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
// The operand of MMAv3 is in SharedEncoding and it's order should
// not be changed after FuseTranspositions Pass. So we only pipeline
// the load if the order of the loaded BlockedEncoding is the same
// as the order of the SharedEncoding it is converted to.
// TODO: remove this constraint once the LoadOp supports transpose
// fusion
hasMMAV3 = true;
return convertLayout.getResult();
}
}
}
}
// Advance to the first conversion as long as the use resides in shared
// memory and it has a single use itself
while (use) {
if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
break;
auto tensorType = use->getResult(0).getType().dyn_cast<RankedTensorType>();
if (!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
break;
use = *use->getResult(0).getUsers().begin();
}
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType =
convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
return convertLayout.getResult();
}
}
}
return Value();
}
namespace {
// Pairs a pipeline-able load with the transitive dot-operand value it feeds
// (the result of loadDotOperand); the dot operand is later used to derive the
// shared-memory encoding of the multi-buffered staging alloc.
struct LoadDotOperand {
  LoadDotOperand(tt::LoadOp load, Value dotOperand)
      : load(load), dotOperand(dotOperand) {}
  tt::LoadOp load;
  Value dotOperand;
};
} // namespace
/// Collect loads to pipeline into `ops`. A load qualifies when it is either a
/// tensor-pointer (TMA) load, or a wide-enough regular load (>= 32 bits per
/// thread) whose result transitively feeds a dot operand. `hasMMAV3` is set
/// by loadDotOperand when an MMAv3-style operand is found.
static void collectOpsToPipeline(scf::ForOp forOp,
                                 SmallVectorImpl<LoadDotOperand> &ops,
                                 bool &hasMMAV3) {
  ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
  ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
  // We cannot use forOp.walk(...) here because we only want to visit the
  // operations in the loop body block. Nested blocks are handled separately.
  for (Operation &op : forOp) {
    if (auto loadOp = dyn_cast<tt::LoadOp>(&op)) {
      bool candidate = false;
      if (isLoadFromTensorPtr(loadOp)) {
        // Map to TMA load.
        candidate = true;
      } else {
        auto ptr = loadOp.getPtr();
        // Effective vector width is limited by both pointer contiguity and
        // mask alignment.
        unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
        if (auto mask = loadOp.getMask())
          vec =
              std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
        auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
        if (!tensorTy || tensorTy.getRank() < 2)
          continue;
        auto ty =
            tensorTy.getElementType().cast<tt::PointerType>().getPointeeType();
        unsigned width = vec * ty.getIntOrFloatBitWidth();
        // We do not pipeline all loads for the following reasons:
        // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
        // 2. It's likely that pipelining small loads won't offer much
        // performance improvement and may even hurt performance by increasing
        // register pressure.
        if (width >= 32)
          candidate = true;
      }
      if (!candidate)
        continue;
      // Only keep loads that actually feed a dot operand.
      Value dotOperand = loadDotOperand(loadOp, hasMMAV3);
      if (!dotOperand)
        continue;
      ops.emplace_back(loadOp, dotOperand);
    }
  }
}
// Create an allocation that can hold `distance` number of loadOp shapes
// (i.e. a multi-buffered shared-memory staging buffer with a leading
// buffer-index dimension).
static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand,
                         unsigned distance) {
  OpBuilder builder(forOp);
  auto ty = loadOp.getType().cast<RankedTensorType>();
  if (!loadOp.getResult().hasOneUse())
    return Value();
  Attribute sharedEnc;
  auto CTALayout = ttg::getCTALayout(ty.getEncoding());
  auto tensorType = dotOperand.getType().cast<RankedTensorType>();
  if (auto dotOpEnc =
          tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>()) {
    // Pre-MMAv3 path: derive the shared encoding from the dot-operand
    // encoding. NOTE(review): dotOperand is assumed to be produced by a
    // convert_layout (as returned by loadDotOperand); a transpose feeding
    // that conversion flags needTrans.
    auto convertLayout = dotOperand.getDefiningOp<ttg::ConvertLayoutOp>();
    bool needTrans = dyn_cast_or_null<tt::TransOp>(
        convertLayout->getOperand(0).getDefiningOp());
    unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
    sharedEnc = ttg::SharedEncodingAttr::get(
        ty.getContext(), dotOpEnc, ty.getShape(),
        ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
  } else {
    // MMAv3
    sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),
                                             ttg::getOrder(ty.getEncoding()),
                                             CTALayout, ty.getElementType());
  }
  // Prepend the buffer count so the alloc holds `distance` copies.
  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
  bufferShape.insert(bufferShape.begin(), distance);
  Type allocType =
      RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
  Value alloc = builder.create<mlir::triton::gpu::AllocTensorOp>(
      loadOp.getLoc(), allocType);
  return alloc;
}
// Convert load ops into their async version and apply multi-buffering based on
// the number of stages. This patches the loop to carry the insert/extract
// buffer indices (and, for TMA loads, the mbarrier phase) as iter args.
static void createAsynOps(scf::ForOp &forOp, ArrayRef<LoadDotOperand> loads,
                          int numStages, bool hasMMAV3) {
  // Pairs each load with the multi-buffered alloc that will back it.
  struct AsyncLoad {
    AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {}
    tt::LoadOp loadOp;
    Value alloc;
  };
  int numBuffers = numStages - 1;
  // For MMAv3 we need an extra buffer as this is assumed in the wgmma
  // pipelining post-processing.
  // TODO: Improve modeling of wgmma pipelining.
  if (hasMMAV3)
    numBuffers++;
  SmallVector<AsyncLoad> asyncLoads;
  SmallVector<Value> newOperands;
  // TMA loads synchronize through an mbarrier phase carried by the loop.
  bool needsMbarrierPhase = false;
  for (const LoadDotOperand &loadOperand : loads) {
    tt::LoadOp loadOp = loadOperand.load;
    Value dotOperand = loadOperand.dotOperand;
    Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers);
    assert(alloc && "Failed to create alloc for the async load.");
    newOperands.push_back(alloc);
    asyncLoads.emplace_back(loadOp, alloc);
    if (isLoadFromTensorPtr(loadOp))
      needsMbarrierPhase = true;
  }
  OpBuilder builder(forOp);
  Location loc = forOp.getLoc();
  // Create two new counters to index into the allocs.
  Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
  // Start at -1 so the first in-loop increment yields index 0.
  Value insertIdx = minusOne;
  Value extractIdx = minusOne;
  Value numBuffersVal =
      builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
  newOperands.push_back(insertIdx);
  newOperands.push_back(extractIdx);
  Value phase;
  if (needsMbarrierPhase) {
    phase = builder.create<arith::ConstantIntOp>(loc, 0, 1);
    newOperands.push_back(phase);
  }
  unsigned newOperandIndex = forOp.getBody()->getNumArguments();
  // Patch the loop to add the new loop carried dependencies.
  scf::ForOp newForOp =
      replaceForOpWithNewSignature(builder, forOp, newOperands);
  forOp.erase();
  forOp = newForOp;
  for (size_t i = 0; i < asyncLoads.size(); i++) {
    asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i);
  }
  insertIdx =
      newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size());
  extractIdx =
      newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1);
  // Create two counters for the insert and extract indices to avoid creating
  // long liverange. Each counter wraps around to 0 when reaching numBuffers.
  builder.setInsertionPoint(asyncLoads.front().loadOp);
  insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
  Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                               insertIdx, numBuffersVal);
  insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
  extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
  Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                               extractIdx, numBuffersVal);
  extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
  if (needsMbarrierPhase) {
    // Flip the mbarrier phase each time the extract index wraps around.
    phase = newForOp.getBody()->getArgument(newOperandIndex +
                                            asyncLoads.size() + 2);
    Value oneI1 = builder.create<arith::ConstantIntOp>(loc, 1, 1);
    Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, oneI1);
    phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
  }
  for (AsyncLoad &asyncLoad : asyncLoads) {
    createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx,
                    extractIdx, phase);
  }
  // Insert a waitOp after the first async copy. This does make the assumption
  // that the wait will be scheduled in a different stage than all the async
  // copies but we cannot guarantee that one wait is enough otherwise.
  for (auto &op : forOp.getBody()->without_terminator()) {
    if (isa<ttg::InsertSliceAsyncOp>(op)) {
      OpBuilder builder(op.getContext());
      builder.setInsertionPointAfter(&op);
      builder.create<ttg::AsyncWaitOp>(op.getLoc(), 0);
      break;
    }
  }
  SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
  if (needsMbarrierPhase)
    newYieldOperands.push_back(phase);
  // Patch the yield with the updated counters.
  appendToYield(forOp, newYieldOperands);
}
// Combine the current mask with the given predicate: splat the scalar
// predicate to the mask shape when needed and AND it with any existing mask.
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
                         Value currentMask, Value pred) {
  Location loc = pred.getLoc();
  Type maskType = tt::getI1SameShape(typeLike);
  Value combined = pred;
  if (maskType.isa<RankedTensorType>())
    combined = rewriter.create<tt::SplatOp>(loc, maskType, pred);
  if (currentMask)
    combined = rewriter.create<arith::AndIOp>(loc, combined, currentMask);
  return combined;
}
// Function to mask operations during scheduling. Used as the pipeliner's
// predicateFn: ops that may have side effects get their mask/predicate
// operand combined with `pred`; side-effect-free and wait-like ops are
// returned unchanged.
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
                              Value pred) {
  OpBuilder::InsertionGuard guard(rewriter);
  // Pure ops need no predication.
  if (mlir::isMemoryEffectFree(op))
    return op;
  // Commit/wait ops only sequence previously issued copies; nothing to mask.
  if (isa<ttg::AsyncCommitGroupOp>(op))
    return op;
  if (isa<ttg::AsyncWaitOp>(op))
    return op;
  // cp.async insert: AND the predicate into the existing tensor mask.
  if (auto insertOp = dyn_cast<ttg::InsertSliceAsyncOp>(op)) {
    rewriter.setInsertionPoint(insertOp);
    Value mask = getPredMask(rewriter, insertOp.getSrc().getType(),
                             insertOp.getMask(), pred);
    insertOp.getMaskMutable().assign(mask);
    return op;
  }
  // TMA insert: the mask shape is derived from the pointee type of the
  // tensor pointer source.
  if (auto insertOp = dyn_cast<ttng::InsertSliceAsyncV2Op>(op)) {
    rewriter.setInsertionPoint(insertOp);
    Value mask = getPredMask(
        rewriter,
        insertOp.getSrc().getType().cast<tt::PointerType>().getPointeeType(),
        insertOp.getMask(), pred);
    insertOp.getMaskMutable().assign(mask);
    return op;
  }
  // mbarrier arrive carries a scalar i1 predicate.
  if (auto arriveOp = dyn_cast<ttng::MBarrierArriveOp>(op)) {
    rewriter.setInsertionPoint(arriveOp);
    Value mask = getPredMask(rewriter, rewriter.getIntegerType(1),
                             arriveOp.getPred(), pred);
    arriveOp.getPredMutable().assign(mask);
    return op;
  }
  if (isa<ttng::MBarrierWaitOp>(op)) {
    return op;
  }
  // Plain loads: AND the predicate into the load mask.
  if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
    rewriter.setInsertionPoint(loadOp);
    Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
                             loadOp.getMask(), pred);
    loadOp.getMaskMutable().assign(mask);
    return op;
  }
  assert("don't know how to predicate this op" && false);
  return op;
}
// Annotation callback for the pipeline expander: fix up the number of
// outstanding async groups on every AsyncWaitOp. `part` and `iteration`
// are part of the callback signature but unused here.
static void setWaitNum(Operation *op,
                       mlir::triton::PipeliningOption::PipelinerPart part,
                       unsigned iteration, unsigned numLoadsInStage) {
  auto waitOp = dyn_cast<ttg::AsyncWaitOp>(op);
  if (!waitOp)
    return;
  waitOp.setNum(numLoadsInStage);
}
/// Helper to recursively add `op` and its in-block producers to `deps`.
/// When `includeArg` is set, loop-carried block arguments are followed
/// through the terminator (distance-1 dependencies); `filter` stops the
/// recursion at ops already assigned elsewhere.
static void addDep(Operation *op, DenseSet<Operation *> &deps,
                   bool includeArg = true,
                   DenseSet<Operation *> *filter = nullptr) {
  if (filter && filter->count(op))
    return;
  if (!deps.insert(op).second)
    return;
  for (Value operand : op->getOperands()) {
    Value current = operand;
    // Guard against cycles through iter args.
    llvm::SmallDenseSet<Value> visited;
    while (auto blockArg = current.dyn_cast<BlockArgument>()) {
      if (!includeArg || !visited.insert(current).second)
        break;
      // Only follow iter args of the enclosing loop (arg 0 is the IV).
      if (blockArg.getArgNumber() == 0 ||
          blockArg.getOwner() != op->getBlock())
        break;
      auto terminator = op->getBlock()->getTerminator();
      current = terminator->getOperand(blockArg.getArgNumber() - 1);
    }
    Operation *producer = current.getDefiningOp();
    if (producer && producer->getBlock() == op->getBlock())
      addDep(producer, deps, includeArg, filter);
  }
}
// Append (op, stage) pairs to the schedule for every non-terminator op in
// the loop body that satisfies the filter function.
static void addOps(scf::ForOp forOp, int stage,
                   std::vector<std::pair<Operation *, unsigned>> &schedule,
                   std::function<bool(Operation *)> filter) {
  for (Operation &op : forOp.getBody()->without_terminator())
    if (filter(&op))
      schedule.emplace_back(&op, stage);
}
// create the schedule for a matmul loop. This is ad hoc based on how we know
// matmul loops should be pipelined and is not a generic scheduler.
static std::vector<std::pair<Operation *, unsigned>>
createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) {
  SmallVector<Operation *> insertOps;
  SmallVector<Operation *> extractOps;
  // Find the insert/extract ops that will go respectively in stage 0 and stage
  // `numStages - 2`. All the other operations will go in stage `numStages - 1`.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (isa<ttg::InsertSliceAsyncOp, ttg::AsyncCommitGroupOp,
            ttng::MBarrierArriveOp, ttng::InsertSliceAsyncV2Op>(op))
      insertOps.emplace_back(&op);
    if (prefetchExtract) {
      if (isa<ttg::ExtractSliceOp, ttg::AsyncWaitOp>(op))
        extractOps.emplace_back(&op);
    }
  }
  // Inserts and their (same-iteration) producers all go to stage 0.
  DenseSet<Operation *> insertAndDeps;
  for (Operation *op : insertOps) {
    addDep(op, insertAndDeps, false);
  }
  // Find dependencies with distance of 1.
  SmallVector<Operation *> distanceOneUsers;
  for (Operation *op : insertAndDeps) {
    for (Value operand : op->getOperands()) {
      if (auto arg = operand.dyn_cast<BlockArgument>()) {
        if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
          auto yieldOp = op->getBlock()->getTerminator();
          Value v = yieldOp->getOperand(arg.getArgNumber() - 1);
          Operation *defOp = v.getDefiningOp();
          if (defOp && insertAndDeps.count(defOp) == 0) {
            distanceOneUsers.push_back(defOp);
          }
        }
      }
    }
  }
  // Schedule loads with a distance of 1 in stage 0
  for (Operation *op : distanceOneUsers) {
    if (isa<tt::LoadOp>(op)) {
      addDep(op, insertAndDeps, true);
    }
  }
  // For the rest of the ops we can move them into stage 1 so that they can be
  // closer to their uses.
  DenseSet<Operation *> stage1deps;
  for (Operation *op : distanceOneUsers) {
    if (!isa<tt::LoadOp>(op)) {
      addDep(op, stage1deps, true, &insertAndDeps);
    }
  }
  // Extracts and their producers (excluding anything already in stage 0).
  DenseSet<Operation *> extractAndDeps;
  for (Operation *op : extractOps) {
    addDep(op, extractAndDeps, true, &insertAndDeps);
  }
  std::vector<std::pair<Operation *, unsigned>> schedule;
  // Schedule stage `numStages - 1` first.
  addOps(forOp, numStages - 1, schedule, [&](Operation *op) {
    return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 &&
           extractAndDeps.count(op) == 0;
  });
  // Schedule some dependencies with distance of 1 into stage 1 to reduce
  // pressure.
  addOps(forOp, 1, schedule,
         [&](Operation *op) { return stage1deps.count(op); });
  // Then Schedule stage 0.
  addOps(forOp, 0, schedule,
         [&](Operation *op) { return insertAndDeps.count(op); });
  // Finally schedule the extract ops in stage `numStages - 2` so that they get
  // pre-fetched and play well with prefetch pass.
  addOps(forOp, numStages - 2, schedule,
         [&](Operation *op) { return extractAndDeps.count(op); });
  return schedule;
}
// Entry point: prepare `forOp` for software pipelining (async loads,
// multi-buffering, schedule) and fill `options` for the pipeline expander.
// Returns false when there is nothing to pipeline.
bool mlir::triton::preProcessLoopAndGetSchedule(
    scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
  // 1. First collect "interesting" operations with a stage where to schedule
  // them. This gives a coarse scheduling for the loop.
  SmallVector<LoadDotOperand> loads;
  bool hasMMAV3 = false;
  collectOpsToPipeline(forOp, loads, hasMMAV3);
  if (loads.empty())
    return false;
  // True when at least one load uses cp.async (as opposed to TMA).
  bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) {
    return !isLoadFromTensorPtr(load.load);
  });
  // 2. Convert the loads into async loads and create the allocs.
  createAsynOps(forOp, loads, numStages, hasMMAV3);
  // 3. Create the final schedule for the kernel loop. This will dictate the
  // stages and order of operations to the pipeline expander.
  std::vector<std::pair<Operation *, unsigned>> schedule =
      createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3);
  // 4. Fill out the pipeline options.
  // NOTE: the lambda captures `schedule` by copy; std::move on the const
  // capture degrades to a copy, which keeps the lambda safely re-invocable.
  options.getScheduleFn =
      [schedule](scf::ForOp forOp,
                 std::vector<std::pair<Operation *, unsigned>> &s) {
        s = std::move(schedule);
      };
  options.peelEpilogue = false;
  options.predicateFn = predicateOp;
  options.supportDynamicLoops = true;
  // Number of async-copy groups still in flight when a wait executes.
  unsigned numLoadsInStage = (numStages - 2) * loads.size();
  options.annotateFn =
      [numLoadsInStage](Operation *op,
                        mlir::triton::PipeliningOption::PipelinerPart part,
                        unsigned iteration) {
        return setWaitNum(op, part, iteration, numLoadsInStage);
      };
  if (hasAsynCp) {
    // Insert a wait 0 after the loop
    OpBuilder builder(forOp);
    builder.setInsertionPointAfter(forOp);
    builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
  }
  return true;
}
/// MMA V3 post-processing.
/// Returns true when `dotOp` (transitively) depends on its own result from the
/// previous iteration, i.e. one of its operands reaches the iter arg that the
/// dot's result is yielded into. On success `*firstUse` is set to the first
/// user of that iter arg.
static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp,
                       Operation **firstUse) {
  // Recursively check whether `v` reaches iter arg number `argId`.
  std::function<bool(Value, int, scf::ForOp)> dependOn =
      [&dependOn](Value v, int argId, scf::ForOp forOp) {
        auto op = v.getDefiningOp();
        if (isa<BlockArgument>(v)) {
          auto iterArgs = forOp.getRegionIterArgs();
          auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
          if (iter != iterArgs.end())
            return std::distance(iterArgs.begin(), iter) == argId;
        } else {
          if (!op)
            return false;
          for (auto operand : op->getOperands()) {
            if (dependOn(operand, argId, forOp))
              return true;
          }
        }
        return false;
      };
  // Find which yield operand (iter arg) carries the dot's result.
  auto result = dotOp.getResult();
  auto yieldOp = forOp.getBody()->getTerminator();
  int argIdx = -1;
  auto iter = std::find(yieldOp->getOperands().begin(),
                        yieldOp->getOperands().end(), result);
  if (iter != yieldOp->getOperands().end())
    argIdx = std::distance(yieldOp->getOperands().begin(), iter);
  if (argIdx == -1)
    return false;
  for (auto operand : dotOp.getOperands()) {
    if (dependOn(operand, argIdx, forOp)) {
      auto iterArgs = forOp.getRegionIterArgs();
      *firstUse = iterArgs[argIdx].use_begin().getUser();
      return true;
    }
  }
  return false;
}
// Remove the per-iteration dot_wait inserted after the last dot when a wait
// with zero pendings already exists in the loop (that one subsumes it).
static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
                            bool hasDotWait0) {
  if (!hasDotWait0)
    return;
  dotWaitOp->erase();
}
// Convert eligible Hopper (MMAv3) tt.DotOp in the loop into asynchronous
// DotAsyncOp, inserting the dot_wait ops needed to synchronize with shared
// memory reuse and with uses of the results after the loop.
void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
  Block *loop = forOp.getBody();
  // Position (block number) of `op`'s enclosing block inside the loop region,
  // or -1 for a null op.
  auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
    if (!op)
      return -1l;
    auto lastOp = op;
    while (op->getBlock()->getParentOp() != forOp) {
      lastOp = op;
      op = op->getBlock()->getParentOp();
    }
    return std::distance(lastOp->getBlock()->getParent()->begin(),
                         lastOp->getBlock()->getIterator());
  };
  /// XXX(Keren): Clean up the following duplicate code with checkDotOp
  /// dots to be pipelined
  bool hasSyncDot = false;
  bool hasDotWait0 = false;
  SmallVector<tt::DotOp> allDots;
  SmallVector<tt::DotOp> dots;
  SmallVector<unsigned> resultNeedSync;
  // First pass: collect all dots and detect an existing wait-0.
  for (Operation &op : *loop) {
    if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
      auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
      auto pendingCount = attr.getInt();
      if (pendingCount == 0)
        hasDotWait0 = true;
    }
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      allDots.push_back(dotOp);
    }
  }
  // Second pass: keep only Hopper dots whose result is yielded and which are
  // (directly or indirectly) accumulated across iterations.
  for (Operation &op : *loop) {
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
      if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
        if (resEnc && resEnc.isHopper()) {
          auto dot = dotOp.getResult();
          bool valid = true;
          // all users of dot should be scf.yield
          if (!dot.hasOneUse())
            valid = false;
          if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
            valid = false;
          Operation *firstUse = nullptr;
          auto depend = selfDepend(dotOp, forOp, &firstUse);
          bool selfDirectDepend = (dotOp == firstUse);
          // Any other dot executing before the first use of the accumulator
          // forces an extra synchronization.
          for (auto tempInAll : allDots) {
            auto iter = std::find(dots.begin(), dots.end(), tempInAll);
            if (iter != dots.end())
              continue;
            auto db = getBlockNumInFor(tempInAll, forOp);
            auto fb = getBlockNumInFor(firstUse, forOp);
            if (db < fb ||
                (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
              hasSyncDot = true;
          }
          auto CArg = dotOp.getOperand(2);
          if (!(selfDirectDepend ||
                (depend && !selfDirectDepend && hasSyncDot)) ||
              !CArg.hasOneUse())
            valid = false;
          if (valid) {
            dots.push_back(dotOp);
            resultNeedSync.push_back(
                dotOp->getUses().begin()->getOperandNumber());
          }
        }
      }
    }
  }
  // Early stop: no need to continue if there is no valid dot in the loop.
  if (dots.empty())
    return;
  OpBuilder builder(forOp);
  // 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
  // wgmma ops by one stage.
  // This is needed to prevent shared memory inputs to be overridden before the
  // operation is completed.
  // TODO: merge this with the rest of the pipelining transformation and look at
  // a better representation for async dots.
  tt::DotOp lastDot = dots.back();
  auto loc = lastDot.getLoc();
  builder.setInsertionPointAfter(lastDot);
  auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
      lastDot.getLoc(), lastDot.getResult(), dots.size());
  // 1. replace Dot with DotAsync
  for (size_t idx = 0; idx < dots.size(); ++idx) {
    tt::DotOp dotOp = dots[idx];
    builder.setInsertionPoint(dotOp);
    auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
        dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
        dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
    dotOp.replaceAllUsesWith(dotAsync.getResult());
    dotOp->erase();
  }
  hasDotWait0 = hasDotWait0 || hasSyncDot;
  // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
  builder.setInsertionPointAfter(forOp);
  SmallVector<Value> waitOperands;
  for (size_t i = 0; i < resultNeedSync.size(); ++i) {
    Value result = forOp->getResult(resultNeedSync[i]);
    if (result.use_empty())
      continue;
    waitOperands.push_back(result);
  }
  if (!waitOperands.empty()) {
    auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
                                                             waitOperands, 0);
    // The wait only has results for the loop results that were actually
    // added to waitOperands; keep a separate index so skipped (unused)
    // results do not shift the mapping out of range.
    size_t waitResultIdx = 0;
    for (size_t i = 0; i < resultNeedSync.size(); ++i) {
      Value result = forOp->getResult(resultNeedSync[i]);
      if (result.use_empty())
        continue;
      result.replaceAllUsesExcept(dotWait.getResult(waitResultIdx++), dotWait);
    }
  }
  // 3. potentially remove redundant dot_wait after dot_async if having multiple
  // DotOp in the loop
  removeExtraWait(dotWait, hasDotWait0);
}

View File

@@ -0,0 +1,704 @@
//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements loop software pipelining
//
//===----------------------------------------------------------------------===//
// Fork of upstream pipeliner. This will be merged upstream once things are
// stable. Modifications so far are:
// -Bug fix for def with a distance of 1 scheduled in stage 0.
// -Support dynamic loops and predicate operations in the prologue.
// -Support for non-index type for induction variable.
// -Support source with distance of 1 used multiple stages later.
// -Fix bug when a value yield is used outside the loop and the value def is not
// in the last stage. If we are not peeling the epilgue we need to remap the
// output correctly.
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
#include "PipelineExpander.h"
#define DEBUG_TYPE "triton-loop-pipelining"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
using namespace mlir;
using namespace mlir::scf;
using namespace mlir::triton;
namespace {
/// Helper to keep internal information during pipelining transformation.
struct LoopPipelinerInternal {
  /// Coarse liverange information for ops used across stages.
  struct LiverangeInfo {
    unsigned lastUseStage = 0;
    unsigned defStage = 0;
  };

protected:
  // Loop being pipelined.
  ForOp forOp;
  // Highest stage number appearing in the schedule.
  unsigned maxStage = 0;
  // Stage assigned to each loop-body operation.
  DenseMap<Operation *, unsigned> stages;
  // Operations in the order dictated by the schedule.
  std::vector<Operation *> opOrder;
  Value ub;
  Value lb;
  Value step;
  // True when bounds/step are not compile-time constants (or the trip count
  // does not exceed maxStage) and prologue ops must be predicated.
  bool dynamicLoop;
  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
  bool peelEpilogue;
  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;

  // When peeling the kernel we generate several version of each value for
  // different stage of the prologue. This map tracks the mapping between
  // original Values in the loop and the different versions
  // peeled from the loop.
  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;

  /// Assign a value to `valueMapping`, this means `val` represents the version
  /// `idx` of `key` in the epilogue.
  void setValueMapping(Value key, Value el, int64_t idx);

  /// Return the operation defining `value` together with its loop-carried
  /// distance (0 for same-iteration defs, 1 when reached through an iter arg).
  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

public:
  /// Initialize the information for the given `op`, return true if it
  /// satisfies the pre-condition to apply pipelining.
  bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options);
  /// Emits the prologue, this creates `maxStage - 1` part which will contain
  /// operations from stages [0; i], where i is the part index.
  void emitPrologue(RewriterBase &rewriter);
  /// Gather liverange information for Values that are used in a different stage
  /// than its definition.
  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
  scf::ForOp createKernelLoop(
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      RewriterBase &rewriter,
      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
  /// Emits the pipelined kernel. This clones loop operations following user
  /// order and remaps operands defined in a different stage as their use.
  LogicalResult createKernel(
      scf::ForOp newForOp,
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
      RewriterBase &rewriter);
  /// Emits the epilogue, this creates `maxStage - 1` part which will contain
  /// operations from stages [i; maxStage], where i is the part index.
  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
};
bool LoopPipelinerInternal::initializeLoopInfo(
    ForOp op, const triton::PipeliningOption &options) {
  LDBG("Start initializeLoopInfo");
  forOp = op;
  ub = forOp.getUpperBound();
  lb = forOp.getLowerBound();
  step = forOp.getStep();

  // Retrieve the schedule and compute `maxStage` first: the constant
  // trip-count check below compares against maxStage, so computing the
  // schedule afterwards (as before) made the comparison use the default
  // maxStage of 0 and mis-classified short static loops as non-dynamic.
  std::vector<std::pair<Operation *, unsigned>> schedule;
  options.getScheduleFn(forOp, schedule);
  if (schedule.empty()) {
    LDBG("--empty schedule -> BAIL");
    return false;
  }
  opOrder.reserve(schedule.size());
  for (auto &opSchedule : schedule) {
    maxStage = std::max(maxStage, opSchedule.second);
    stages[opSchedule.first] = opSchedule.second;
    opOrder.push_back(opSchedule.first);
  }

  // A loop is treated as static only when all bounds are constant AND the
  // trip count exceeds maxStage; otherwise prologue ops must be predicated.
  dynamicLoop = true;
  auto upperBoundCst = ub.getDefiningOp<arith::ConstantIndexOp>();
  auto lowerBoundCst = lb.getDefiningOp<arith::ConstantIndexOp>();
  auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>();
  if (!upperBoundCst || !lowerBoundCst || !stepCst) {
    if (!options.supportDynamicLoops) {
      LDBG("--dynamic loop not supported -> BAIL");
      return false;
    }
  } else {
    int64_t ubImm = upperBoundCst.value();
    int64_t lbImm = lowerBoundCst.value();
    int64_t stepImm = stepCst.value();
    int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
    if (numIteration > maxStage) {
      dynamicLoop = false;
    } else if (!options.supportDynamicLoops) {
      LDBG("--fewer loop iterations than pipeline stages -> BAIL");
      return false;
    }
  }
  peelEpilogue = options.peelEpilogue;
  predicateFn = options.predicateFn;
  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
    LDBG("--no epilogue or predicate set -> BAIL");
    return false;
  }

  // All operations need to have a stage.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!stages.contains(&op)) {
      op.emitOpError("not assigned a pipeline stage");
      LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
      return false;
    }
  }

  // Currently, we do not support assigning stages to ops in nested regions. The
  // block of all operations assigned a stage should be the single `scf.for`
  // body block.
  for (const auto &[op, stageNum] : stages) {
    (void)stageNum;
    if (op == forOp.getBody()->getTerminator()) {
      op->emitError("terminator should not be assigned a stage");
      LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
      return false;
    }
    if (op->getBlock() != forOp.getBody()) {
      op->emitOpError("the owning Block of all operations assigned a stage "
                      "should be the loop body block");
      LDBG("--the owning Block of all operations assigned a stage "
           "should be the loop body block: "
           << *op << " -> BAIL");
      return false;
    }
  }

  // Only support loop carried dependency with a distance of 1. This means the
  // source of all the scf.yield operands needs to be defined by operations in
  // the loop.
  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
                   [this](Value operand) {
                     Operation *def = operand.getDefiningOp();
                     return !def || !stages.contains(def);
                   })) {
    LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
    return false;
  }
  annotateFn = options.annotateFn;
  return true;
}
/// Clone `op` and call `callback` on the cloned op's operands as well as any
/// operands of nested ops that:
/// 1) aren't defined within the new op or
/// 2) are block arguments.
static Operation *
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
                       function_ref<void(OpOperand *newOperand)> callback) {
  Operation *newOp = rewriter.clone(*op);
  for (OpOperand &opOperand : newOp->getOpOperands())
    callback(&opOperand);
  newOp->walk([&](Operation *nested) {
    for (OpOperand &opOperand : nested->getOpOperands()) {
      Operation *producer = opOperand.get().getDefiningOp();
      bool escapesClone = producer && !newOp->isAncestor(producer);
      if (escapesClone || isa<BlockArgument>(opOperand.get()))
        callback(&opOperand);
    }
  });
  return newOp;
}
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Initialize the iteration argument to the loop initial values.
  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Location loc = forOp.getLoc();
  // Peel `maxStage` iterations: peel i contains clones of all ops whose
  // stage is <= i.
  for (int64_t i = 0; i < maxStage; i++) {
    // For dynamic loops, guard each peeled iteration with a predicate so
    // out-of-range iterations have no effect.
    Value predicate;
    if (dynamicLoop) {
      Type t = ub.getType();
      // pred = ub > lb + (i * step)
      Value iv = rewriter.create<arith::AddIOp>(
          loc, lb,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, i))));
      predicate = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 iv, ub);
    }
    // special handling for induction variable as the increment is implicit.
    // iv = lb + i * step
    Type t = lb.getType();
    Value iv = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    for (Operation *op : opOrder) {
      // Ops in later stages do not execute yet in this peel.
      if (stages[op] > i)
        continue;
      // Clone the op, remapping operands to the version matching this peel.
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      if (predicate) {
        newOp = predicateFn(rewriter, newOp, predicate);
        assert(newOp && "failed to predicate op.");
      }
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
        // If the value is a loop carried dependency update the loop argument
        // mapping.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          newOp->getResult(destId), i - stages[op] + 1);
        }
      }
    }
  }
}
std::pair<Operation *, int64_t>
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
  // Resolve the defining op of `value`, looking through at most one
  // loop-carried iter arg (distance 1). The induction variable and arguments
  // of other blocks have no trackable definition.
  int64_t dist = 0;
  Value current = value;
  if (auto blockArg = dyn_cast<BlockArgument>(current)) {
    if (blockArg.getOwner() != forOp.getBody() || blockArg.getArgNumber() == 0)
      return {nullptr, 0};
    dist++;
    current = forOp.getBody()->getTerminator()->getOperand(
        blockArg.getArgNumber() - 1);
  }
  Operation *producer = current.getDefiningOp();
  if (!producer)
    return {nullptr, 0};
  return {producer, dist};
}
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
LoopPipelinerInternal::analyzeCrossStageValues() {
  // Collect, for every value consumed in a stage later than the one where it
  // is defined (accounting for loop-carried distance), the defining stage and
  // the last stage that uses it.
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
  for (Operation *op : opOrder) {
    unsigned stage = stages[op];
    auto analyzeOperand = [&](OpOperand &operand) {
      auto [def, distance] = getDefiningOpAndDistance(operand.get());
      if (!def)
        return;
      // Values defined in the same effective stage never cross a stage
      // boundary and need no extra loop arguments.
      auto defStage = stages.find(def);
      if (defStage == stages.end() || defStage->second == stage ||
          defStage->second == stage + distance)
        return;
      assert(stage > defStage->second);
      LiverangeInfo &info = crossStageValues[operand.get()];
      info.defStage = defStage->second;
      info.lastUseStage = std::max(info.lastUseStage, stage);
    };
    for (OpOperand &operand : op->getOpOperands())
      analyzeOperand(operand);
    // Also account for values captured by ops nested inside op's regions.
    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
      analyzeOperand(*operand);
    });
  }
  return crossStageValues;
}
/// Create the new `scf.for` op that will hold the pipelined kernel. The new
/// loop carries the original iter args (initialized with the prologue version
/// produced at stage `maxStage - defStage`) plus one extra iter arg per
/// version of each value that lives across stages. `loopArgMap` is filled
/// with the mapping from (original value, version distance) to the index of
/// the corresponding new loop argument.
scf::ForOp LoopPipelinerInternal::createKernelLoop(
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    RewriterBase &rewriter,
    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
  // Creates the list of initial values associated to values used across
  // stages. The initial values come from the prologue created above.
  // Keep track of the kernel argument associated to each version of the
  // values passed to the kernel.
  llvm::SmallVector<Value> newLoopArg;
  // For existing loop argument initialize them with the right version from the
  // prologue.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
                                     [maxStage - defStage];
    assert(valueVersion);
    newLoopArg.push_back(valueVersion);
  }
  // One extra loop argument per version of each cross-stage value. Iterate by
  // const reference: the map element was copied per iteration before
  // (clang-tidy performance-for-range-copy).
  for (const auto &[value, info] : crossStageValues) {
    for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
         stageIdx++) {
      Value valueVersion =
          valueMapping[value][maxStage - info.lastUseStage + stageIdx];
      assert(valueVersion);
      newLoopArg.push_back(valueVersion);
      loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
                                           stageIdx)] = newLoopArg.size() - 1;
    }
  }
  // Create the new kernel loop. When we peel the epilogue we need to peel
  // `numStages - 1` iterations. Then we adjust the upper bound to remove those
  // iterations.
  Value newUb = forOp.getUpperBound();
  if (peelEpilogue) {
    Type t = ub.getType();
    Location loc = forOp.getLoc();
    // newUb = ub - maxStage * step
    newUb = rewriter.create<arith::AddIOp>(
        loc, ub,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getIntegerAttr(t, -maxStage))));
  }
  auto newForOp =
      rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                  forOp.getStep(), newLoopArg);
  // When there are no iter args, the loop body terminator will be created.
  // Since we always create it below, remove the terminator if it was created.
  if (!newForOp.getBody()->empty())
    rewriter.eraseOp(newForOp.getBody()->getTerminator());
  return newForOp;
}
/// Populate the body of the pipelined kernel loop `newForOp`. Ops are cloned
/// in the order given by the user schedule (`opOrder`); operands that cross a
/// stage boundary are remapped onto the extra loop arguments created by
/// createKernelLoop (via `loopArgMap`). When the epilogue is not peeled,
/// per-stage predicates are created and ops are predicated through
/// `predicateFn`. Returns failure if an op could not be predicated.
LogicalResult LoopPipelinerInternal::createKernel(
    scf::ForOp newForOp,
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
    RewriterBase &rewriter) {
  valueMapping.clear();
  // Create the kernel, we clone instruction based on the order given by
  // user and remap operands coming from a previous stages.
  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
  IRMapping mapping;
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  }
  SmallVector<Value> predicates(maxStage + 1, nullptr);
  if (!peelEpilogue) {
    // Create a predicate for each stage except the last stage.
    Location loc = newForOp.getLoc();
    Type t = ub.getType();
    for (unsigned i = 0; i < maxStage; i++) {
      // c = ub - (maxStage - i) * step
      Value c = rewriter.create<arith::AddIOp>(
          loc, ub,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i)))));
      // Ops of stage i run only while iv < c.
      Value pred = rewriter.create<arith::CmpIOp>(
          newForOp.getLoc(), arith::CmpIPredicate::slt,
          newForOp.getInductionVar(), c);
      predicates[i] = pred;
    }
  }
  for (Operation *op : opOrder) {
    int64_t useStage = stages[op];
    auto *newOp = rewriter.clone(*op, mapping);
    SmallVector<OpOperand *> operands;
    // Collect all the operands for the cloned op and its nested ops.
    op->walk([&operands](Operation *nestedOp) {
      for (OpOperand &operand : nestedOp->getOpOperands()) {
        operands.push_back(&operand);
      }
    });
    for (OpOperand *operand : operands) {
      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
      // Special case for the induction variable uses. We replace it with a
      // version incremented based on the stage where it is used.
      if (operand->get() == forOp.getInductionVar()) {
        rewriter.setInsertionPoint(newOp);
        // offset = (maxStage - stages[op]) * step
        Type t = step.getType();
        Value offset = rewriter.create<arith::MulIOp>(
            forOp.getLoc(), step,
            rewriter.create<arith::ConstantOp>(
                forOp.getLoc(),
                rewriter.getIntegerAttr(t, maxStage - stages[op])));
        Value iv = rewriter.create<arith::AddIOp>(
            forOp.getLoc(), newForOp.getInductionVar(), offset);
        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
        rewriter.setInsertionPointAfter(newOp);
        continue;
      }
      Value source = operand->get();
      auto arg = dyn_cast<BlockArgument>(source);
      if (arg && arg.getOwner() == forOp.getBody()) {
        // Loop-carried use: look through the iter arg at the value that is
        // yielded into it by the previous iteration.
        Value ret = forOp.getBody()->getTerminator()->getOperand(
            arg.getArgNumber() - 1);
        Operation *dep = ret.getDefiningOp();
        if (!dep)
          continue;
        auto stageDep = stages.find(dep);
        if (stageDep == stages.end() || stageDep->second == useStage)
          continue;
        // If the value is a loop carried value coming from stage N + 1 remap,
        // it will become a direct use.
        if (stageDep->second == useStage + 1) {
          nestedNewOp->setOperand(operand->getOperandNumber(),
                                  mapping.lookupOrDefault(ret));
          continue;
        }
        source = ret;
      }
      // For operands defined in a previous stage we need to remap it to use
      // the correct region argument. We look for the right version of the
      // Value based on the stage where it is used.
      Operation *def = source.getDefiningOp();
      if (!def)
        continue;
      auto stageDef = stages.find(def);
      if (stageDef == stages.end() || stageDef->second == useStage)
        continue;
      auto remap = loopArgMap.find(
          std::make_pair(operand->get(), useStage - stageDef->second));
      assert(remap != loopArgMap.end());
      nestedNewOp->setOperand(operand->getOperandNumber(),
                              newForOp.getRegionIterArgs()[remap->second]);
    }
    if (predicates[useStage]) {
      // predicateFn may return nullptr when the op cannot be predicated; this
      // aborts the whole transformation.
      newOp = predicateFn(rewriter, newOp, predicates[useStage]);
      if (!newOp)
        return failure();
      // Remap the results to the new predicated one.
      for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
        mapping.map(std::get<0>(values), std::get<1>(values));
    }
    rewriter.setInsertionPointAfter(newOp);
    if (annotateFn)
      annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0);
  }
  // Collect the Values that need to be returned by the forOp. For each
  // value we need to have `LastUseStage - DefStage` number of versions
  // returned.
  // We create a mapping between original values and the associated loop
  // returned values that will be needed by the epilogue.
  llvm::SmallVector<Value> yieldOperands;
  for (OpOperand &yielOperand :
       forOp.getBody()->getTerminator()->getOpOperands()) {
    Value source = mapping.lookupOrDefault(yielOperand.get());
    // When we don't peel the epilogue the yield value is used outside the loop
    // we need to make sure we return the version from numStages - defStage.
    if (!peelEpilogue &&
        !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
      auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
      if (def) {
        auto defStage = stages.find(def);
        if (defStage != stages.end()) {
          Value pred = predicates[defStage->second];
          if (pred) {
            // When the stage is predicated off, keep the previous iteration's
            // value (the matching iter arg) so the loop result stays valid.
            source = rewriter.create<arith::SelectOp>(
                pred.getLoc(), pred, source,
                newForOp.getBody()
                    ->getArguments()[yielOperand.getOperandNumber() + 1]);
          }
        }
      }
    }
    yieldOperands.push_back(source);
  }
  for (auto &it : crossStageValues) {
    int64_t version = maxStage - it.second.lastUseStage + 1;
    unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
    // add the original version to yield ops.
    // If there is a live range spanning across more than 2 stages we need to
    // add extra arg.
    for (unsigned i = 1; i < numVersionReturned; i++) {
      setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                      version++);
      // NOTE(review): the argument index assumes iter args follow the
      // induction variable(s); confirm the extra +1 offset is intended.
      yieldOperands.push_back(
          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
                                             newForOp.getNumInductionVars()]);
    }
    setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                    version++);
    yieldOperands.push_back(mapping.lookupOrDefault(it.first));
  }
  // Map the yield operand to the forOp returned value.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    if (defStage > 0) {
      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                      newForOp->getResult(retVal.index()),
                      maxStage - defStage + 1);
    }
  }
  rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
  return success();
}
/// Emit the epilogue: peel the last `maxStage` partial iterations after the
/// kernel loop, cloning at step `i` every op belonging to a stage in
/// [i, maxStage]. Returns, for each original loop result, the value that must
/// replace it after pipelining.
llvm::SmallVector<Value>
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
  Location loc = forOp.getLoc();
  Type ivType = lb.getType();
  // Materialize the induction-variable value of each of the last `maxStage`
  // iterations. Unused versions are removed later by dead-code elimination.
  for (int64_t iter = 0; iter < maxStage; iter++) {
    Value minusOne = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(ivType, -1));
    // number of iterations = ((ub - 1) - lb) / step
    Value numIterations = rewriter.create<arith::DivUIOp>(
        loc,
        rewriter.create<arith::SubIOp>(
            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
        step);
    // iv(version) = lb + step * (numIterations - iter)
    Value minusIter = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIntegerAttr(ivType, -iter));
    Value lastIterIV = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::AddIOp>(loc, numIterations, minusIter)));
    setValueMapping(forOp.getInductionVar(), lastIterIV, maxStage - iter);
  }
  // Peel `maxStage` epilogue steps; step `i` clones ops scheduled in stages
  // [i; maxStage].
  for (int64_t i = 1; i <= maxStage; i++) {
    for (Operation *op : opOrder) {
      int64_t opStage = stages[op];
      if (opStage < i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            auto it = valueMapping.find(newOperand->get());
            if (it == valueMapping.end())
              return;
            // Pick the version matching this epilogue step.
            newOperand->set(it->second[maxStage - opStage + i]);
          });
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue,
                   i - 1);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        maxStage - opStage + i);
        // When this result feeds the loop yield, also update the mapping of
        // the corresponding iter arg; the very last version becomes the
        // replacement for the original forOp result.
        for (OpOperand &operand :
             forOp.getBody()->getTerminator()->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          unsigned version = maxStage - opStage + i + 1;
          if (version > maxStage) {
            // Maps to the original forOp returned value.
            returnValues[operand.getOperandNumber()] =
                newOp->getResult(destId);
            continue;
          }
          setValueMapping(
              forOp.getRegionIterArgs()[operand.getOperandNumber()],
              newOp->getResult(destId), version);
        }
      }
    }
  }
  return returnValues;
}
/// Record `el` as version `idx` of `key`. On first use for a given key, a
/// vector with one slot per version (maxStage + 1 in total) is allocated.
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
  auto it = valueMapping.find(key);
  if (it == valueMapping.end()) {
    // First version recorded for this value: make room for every version.
    llvm::SmallVector<Value> versions(maxStage + 1);
    it = valueMapping.insert(std::make_pair(key, std::move(versions))).first;
  }
  it->second[idx] = el;
}
} // namespace
/// Entry point: mechanically pipeline `forOp` according to the schedule
/// carried by `options`. See PipelineExpander.h for the shape of the emitted
/// prologue/kernel/epilogue. `modifiedIR`, when provided, reports whether any
/// IR was changed before a potential failure.
FailureOr<ForOp>
mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                              const triton::PipeliningOption &options,
                              bool *modifiedIR) {
  if (modifiedIR)
    *modifiedIR = false;
  LoopPipelinerInternal pipeliner;
  if (!pipeliner.initializeLoopInfo(forOp, options))
    return failure();
  // Past this point the IR starts changing.
  if (modifiedIR)
    *modifiedIR = true;

  // 1. Peel the prologue iterations.
  pipeliner.emitPrologue(rewriter);

  // 2. Find the values whose liverange crosses a stage boundary; these must
  // be threaded through the kernel loop as extra iteration arguments.
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> liveranges =
      pipeliner.analyzeCrossStageValues();

  // Mapping between (original value, version) pairs and the kernel block
  // argument carrying that version. A value whose liverange spans more than
  // two stages maps to several arguments.
  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> argIndexMap;

  // 3. Build the new kernel loop with the extended signature.
  ForOp kernelLoop =
      pipeliner.createKernelLoop(liveranges, rewriter, argIndexMap);

  // 4. Populate the kernel body: clone ops in schedule order and remap
  // cross-stage operands onto the new block arguments.
  if (failed(pipeliner.createKernel(kernelLoop, liveranges, argIndexMap,
                                    rewriter)))
    return failure();

  llvm::SmallVector<Value> replacements =
      kernelLoop.getResults().take_front(forOp->getNumResults());
  if (options.peelEpilogue) {
    // 5. Peel the epilogue iterations after the new loop.
    rewriter.setInsertionPointAfter(kernelLoop);
    replacements = pipeliner.emitEpilogue(rewriter);
  }
  // 6. Remove the original loop, rewiring its uses to the replacements.
  if (forOp->getNumResults() > 0)
    rewriter.replaceOp(forOp, replacements);
  else
    rewriter.eraseOp(forOp);
  return kernelLoop;
}

View File

@@ -0,0 +1,101 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
// This is a fork of upstream pipeline transformation. This will be merged back
// upstream once we have a stable solution.
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/ADT/ArrayRef.h"
namespace mlir {
class RewriterBase;
class Operation;
class Value;
namespace scf {
class ForOp;
}
namespace triton {
/// Options to dictate how loops should be pipelined.
struct PipeliningOption {
  /// Lambda returning all the operation in the forOp, with their stage, in the
  /// order picked for the pipelined loop.
  using GetScheduleFnType = std::function<void(
      scf::ForOp, std::vector<std::pair<Operation *, unsigned>> &)>;
  GetScheduleFnType getScheduleFn = nullptr;
  /// Which part of the pipelined loop an operation was emitted into.
  enum class PipelinerPart {
    Prologue,
    Kernel,
    Epilogue,
  };
  /// Lambda called by the pipeliner to allow the user to annotate the IR while
  /// it is generated.
  /// The callback passes the operation created along with the part of the
  /// pipeline and the iteration index. The iteration index is always 0 for the
  /// kernel. For the prologue and epilogue, it corresponds to the iteration
  /// peeled out of the loop in the range [0, maxStage).
  using AnnotationlFnType =
      std::function<void(Operation *, PipelinerPart, unsigned)>;
  AnnotationlFnType annotateFn = nullptr;
  /// Control whether the epilogue should be peeled out of the loop or
  /// operations should be predicated to skip the early stages in the last loop
  /// iterations. If the epilogue is predicated, the user needs to provide a
  /// lambda to generate the predicated version of operations.
  bool peelEpilogue = true;
  /// Control whether the transformation checks that the number of iterations is
  /// greater or equal to the number of stages and skip the transformation if
  /// this is not the case. If the loop is dynamic and this is set to true the
  /// pipeliner will have to predicate operations in the prologue/epilogue.
  bool supportDynamicLoops = false;
  // Callback to predicate operations when the prologue or epilogue are not
  // peeled. This takes the original operation, an i1 predicate value and the
  // pattern rewriter. It is expected to replace the given operation with
  // the predicated equivalent and return it, or return nullptr if the
  // predication is impossible. In the latter case, pipelining will fail and
  // may leave IR in a partially transformed state.
  using PredicateOpFnType =
      std::function<Operation *(RewriterBase &, Operation *, Value)>;
  PredicateOpFnType predicateFn = nullptr;
  // TODO: add option to decide if the prologue should be peeled.
};
/// Generate a pipelined version of the scf.for loop based on the schedule given
/// as option. This applies the mechanical transformation of changing the loop
/// and generating the prologue/epilogue for the pipelining and doesn't make any
/// decision regarding the schedule.
/// Based on the options the loop is split into several stages.
/// The transformation assumes that the scheduling given by user is valid.
/// For example if we break a loop into 3 stages named S0, S1, S2 we would
/// generate the following code with the number in parenthesis as the iteration
/// index:
///
/// S0(0) // Prologue
/// S0(1) S1(0) // Prologue
/// scf.for %I = %C0 to %N - 2 {
/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
/// }
/// S1(N) S2(N-1) // Epilogue
/// S2(N) // Epilogue
///
/// If `modifiedIR` is provided, it will be set to a value that indicates
/// whether pipelining modified the IR before failing, signaling to the caller
/// whether they can proceed with different transformations.
FailureOr<scf::ForOp> pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp,
const PipeliningOption &options,
bool *modifiedIR = nullptr);
} // namespace triton
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_

View File

@@ -0,0 +1,27 @@
#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
#include "PipelineExpander.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ArrayRef.h"
#include <vector>
namespace mlir {
namespace triton {
/// This fill out the pipelining options including schedule and annotations for
/// wait ops. This also does pre-processing by converting some of the loads into
/// async loads so that the IR is ready to be pipelined.
bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
mlir::triton::PipeliningOption &options);
/// This does post-processing on the pipelined loop to try to pipeline wgmma
/// ops.
// TODO: this should be included as part of the pipeline but currently the wgmma
// wait modeling is problematic.
void asyncLaunchDots(scf::ForOp forOp);
} // namespace triton
} // namespace mlir
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_

View File

@@ -0,0 +1,88 @@
#include "PipelineExpander.h"
#include "Schedule.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/Debug.h"
//===----------------------------------------------------------------------===//
// This file will create a schedule that will be handed over to the pipeline
// expander.
// Software pipeliners are usually separated into two pieces, one that create a
// modulo schedule and an expander that rewrites the loop and emits a prologue
// and epilogue. This pass first calls a helper that will pre-process the IR
// to create async operations and create a modulo schedule. Then we call the
// expander to generate the prologue and new loop.
//===----------------------------------------------------------------------===//
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Pipeline a single scf.for loop with `numStages` stages: pre-process the
/// loop to build a schedule, run the pipeline expander, and finally try to
/// make the dot ops in the pipelined loop asynchronous.
static void pipelineLoop(scf::ForOp forOp, int numStages) {
  mlir::triton::PipeliningOption options;
  // Skip loop with distance > 1 for now: a yielded value without a defining
  // op is a block argument, i.e. a dependency carried across more than one
  // iteration.
  // TODO: relax the constraint in the expander.
  if (llvm::any_of(
          forOp.getBody()->getTerminator()->getOperands(),
          [](Value operand) { return !operand.getDefiningOp(); }))
    return;
  // TODO: add more pipelines strategy.
  bool foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options);
  if (!foundSchedule)
    return;
  IRRewriter rewriter(forOp->getContext());
  rewriter.setInsertionPoint(forOp);
  FailureOr<scf::ForOp> newForOp =
      mlir::triton::pipelineForLoop(rewriter, forOp, options);
  // Post-process only if pipelining actually happened.
  if (succeeded(newForOp))
    mlir::triton::asyncLaunchDots(newForOp.value());
}
namespace {
// Pass driving software pipelining of scf.for loops in TritonGPU.
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
  PipelinePass() = default;
  // Sets the options declared by the generated pass base class.
  PipelinePass(int numStages, int numWarps, int numCTAs,
               int computeCapability) {
    this->numStages = numStages;
    this->numWarps = numWarps;
    this->numCTAs = numCTAs;
    this->computeCapability = computeCapability;
  }
  void runOnOperation() override {
    // With a single stage there is nothing to pipeline.
    if (this->numStages <= 1)
      return;
    // Collect the candidate loops up front: pipelineLoop rewrites (and may
    // erase) loops, so the IR must not be mutated during the walk itself.
    SmallVector<scf::ForOp> loops;
    getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
    for (scf::ForOp forOp : loops) {
      pipelineLoop(forOp, numStages);
    }
  }
};
} // anonymous namespace
/// Factory for the TritonGPU software-pipelining pass with the given
/// configuration.
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages,
                                                        int numWarps,
                                                        int numCTAs,
                                                        int computeCapability) {
  return std::make_unique<PipelinePass>(numStages, numWarps, numCTAs,
                                        computeCapability);
}

View File

@@ -332,9 +332,6 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
setEncoding({afterArg, result}, info, changed, user);
continue;
}
// Workaround: don't propagate through truncI
if (isa<arith::TruncIOp>(user))
continue;
if (user->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() ||
user->hasTrait<mlir::OpTrait::Elementwise>() ||
isa<triton::ReduceOp, triton::ExpandDimsOp,
@@ -755,7 +752,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
map(oldResult, newResult);
return newOp;
}
assert(0 && "unexpected op in rewrite");
llvm::report_fatal_error("unexpected op in rewrite");
return nullptr;
}
@@ -772,34 +769,6 @@ static bool canBeRemat(Operation *op) {
return true;
}
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
// updated and needs to be updated separatly for the loop to be correct.
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,
scf::ForOp loop,
ValueRange newIterOperands) {
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(loop);
// Create a new loop before the existing one, with the extra operands.
rewriter.setInsertionPoint(loop);
auto operands = llvm::to_vector<4>(loop.getInitArgs());
operands.append(newIterOperands.begin(), newIterOperands.end());
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
operands);
newLoop.getBody()->erase();
newLoop.getRegion().getBlocks().splice(
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
for (Value operand : newIterOperands)
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
loop.getNumResults())))
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
return newLoop;
}
static void rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp, IRMapping &mapping) {

View File

@@ -98,8 +98,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
// Some ops from SCF are illegal
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
scf::ReduceReturnOp>();
// We have custom versions of some arith operators
addIllegalOp<arith::CmpIOp, arith::CmpFOp>();
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
triton::TritonDialect, cf::ControlFlowDialect,

View File

@@ -232,8 +232,10 @@ std::string GraphLayoutMarker::getColor(const Type &type) const {
return "orange";
else if (layout.isa<triton::gpu::SharedEncodingAttr>())
return "orangered";
else
assert(0 && "Unrecognized layout");
else {
llvm::report_fatal_error("Unrecognized layout");
return "unknown";
}
} else {
return "white";
}
@@ -342,11 +344,39 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
}
return true;
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
auto viewDstType = view.getType().cast<RankedTensorType>();
RankedTensorType newDstType = RankedTensorType::get(
viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
return !triton::gpu::isExpensiveView(view.getOperand().getType(),
newDstType);
}
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
triton::MakeRangeOp, triton::SplatOp>(op);
}
//
// Replace the ForOp with a new ForOp that has `newIterOperands` appended as
// extra iter operands. The loop body is spliced (moved, not cloned) into the
// new loop, and uses of the old results are redirected to the new ones. The
// YieldOp is NOT updated here and needs to be updated separately for the new
// loop to be correct; the original (now body-less) loop is also left for the
// caller to erase.
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
                                        ValueRange newIterOperands) {
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(loop);
  // Create a new loop before the existing one, with the extra operands.
  auto operands = llvm::to_vector<4>(loop.getInitArgs());
  operands.append(newIterOperands.begin(), newIterOperands.end());
  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
      operands);
  // Steal the original body instead of cloning it.
  newLoop.getBody()->erase();
  newLoop.getRegion().getBlocks().splice(
      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
  // The spliced block has no arguments for the new iter operands yet.
  for (Value operand : newIterOperands)
    newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
                               loop.getNumResults())))
    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
  return newLoop;
}
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping) {

View File

@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
build(builder, state, MutexType::get(builder.getContext()));
}
///--- DotWaitOp ---
// InferTypeOpInterface hook: DotWaitOp produces one result per operand, each
// with the same type as the corresponding operand, so simply propagate the
// operand types.
LogicalResult DotWaitOp::inferReturnTypes(
    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
  for (Value operand : operands)
    inferredReturnTypes.push_back(operand.getType());
  return mlir::success();
}
} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

View File

@@ -68,7 +68,8 @@ Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef<int64_t> shape,
replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout));
} else {
// Other layouts are generated by passes after PlanCTAPass
assert(0 && "replaceCTALayout not implemented");
llvm::report_fatal_error("replaceCTALayout not implemented");
return layout;
}
}
@@ -393,7 +394,8 @@ bool CTAPlanner::propagateBackward(CastOp cast) {
Value output = cast.getResult(0);
unsigned numUsers = getNumUsers(input);
if (numUsers == 0) {
assert(0 && "Unreachable branch");
llvm::report_fatal_error("Unreachable branch");
return false;
} else if (numUsers == 1) {
Type outTy = output.getType();
if (auto ptrTy = outTy.dyn_cast<triton::PointerType>())
@@ -649,7 +651,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
return true;
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
return externElementwiseOp.getPure();
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
return true;
return false;
}
@@ -711,7 +713,7 @@ bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims,
bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims,
Attribute newSrcLayout) {
assert(0 && "processExpandDimsForward not implemented yet");
llvm::report_fatal_error("processExpandDimsForward not implemented yet");
return true;
}
@@ -827,7 +829,7 @@ int findResultIndex(Operation *op, Value result) {
for (int i = 0; i < op->getNumResults(); ++i)
if (op->getResult(i) == result)
return i;
assert(0 && "Invalid index of op result");
llvm::report_fatal_error("Invalid index of op result");
return -1;
}
@@ -849,7 +851,7 @@ bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
auto newType = cast.getResult(0).getType();
return processForOp(forOp, index, newType);
} else {
assert(0 && "Unexpected parent op of block argument");
llvm::report_fatal_error("Unexpected parent op of block argument");
return true;
}
}
@@ -869,7 +871,7 @@ bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
else if (auto forOp = llvm::dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
return processForOp(forOp, index, newType);
else
assert(0 && "Unexpected parent op of YieldOp");
llvm::report_fatal_error("Unexpected parent op of YieldOp");
return true;
}
@@ -936,7 +938,8 @@ bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
Operation *clonedOp = builder.clone(*defOp);
newInput = clonedOp->getResult(0);
} else {
assert(0 && "Layout conflict for block arg"); // TODO
llvm::report_fatal_error("Layout conflict for block arg"); // TODO
return false;
}
}
first = false;

View File

@@ -55,7 +55,7 @@ bool isDivisible(Value v, unsigned divisor) {
auto func = dyn_cast<tt::FuncOp>(parentOp);
assert(func);
if (auto attr = func.getArgAttrOfType<IntegerAttr>(blockArg.getArgNumber(),
"tt.max_divisibility"))
"tt.divisibility"))
return attr.getValue().getZExtValue() % divisor == 0;
return false;
} else if (v.getParentBlock()->isEntryBlock() && (!v.isa<BlockArgument>())) {
@@ -98,13 +98,8 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) {
return !(boxDimSwizzle && strideDivisible && enableTMA);
}
// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't
// play well with encoding attributes. Move back to arith::CmpIOp when this pass
// moves back to triton IR level.
Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type,
arith::CmpIPredicate pred, Value lhs, Value rhs) {
if (type.getEncoding())
return builder.create<ttg::CmpIOp>(loc, type, pred, lhs, rhs);
return builder.create<arith::CmpIOp>(loc, type, pred, lhs, rhs);
}
@@ -358,12 +353,17 @@ class TritonGPURewriteTensorPointerPass
: public TritonGPURewriteTensorPointerBase<
TritonGPURewriteTensorPointerPass> {
private:
int computeCapability;
// int computeCapability;
DenseMap<Value, RewritedInfo> rewritedInfo;
public:
explicit TritonGPURewriteTensorPointerPass(int computeCapability)
: computeCapability(computeCapability) {}
// explicit TritonGPURewriteTensorPointerPass(int computeCapability)
// : computeCapability(computeCapability) {}
TritonGPURewriteTensorPointerPass() = default;
TritonGPURewriteTensorPointerPass(int computeCapability) {
this->computeCapability = computeCapability;
}
static bool needRewrite(Operation *op, const DenseSet<Value> &valueToRemove) {
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
@@ -763,17 +763,16 @@ public:
ModuleOp mod = getOperation();
DenseSet<Value> valueToRemove;
mod.walk([&valueToRemove,
computeCapability = this->computeCapability](Operation *op) {
mod.walk([&valueToRemove, this](Operation *op) {
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
if (shouldRemove(makeTensorPtrOp, computeCapability))
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
valueToRemove.insert(op->getResult(0));
}
if (llvm::isa<tt::AdvanceOp>(op)) {
auto src = op->getOperand(0);
if (tt::isTensorPointerType(src.getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
if (shouldRemove(makeTensorPtrOp, computeCapability)) {
if (shouldRemove(makeTensorPtrOp, this->computeCapability)) {
valueToRemove.insert(op->getResult(0));
}
}
@@ -782,7 +781,7 @@ public:
auto src = op->getOperand(0);
if (tt::isTensorPointerType(src.getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
if (shouldRemove(makeTensorPtrOp, computeCapability))
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
valueToRemove.insert(src);
}
}
@@ -791,7 +790,7 @@ public:
for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) {
if (tt::isTensorPointerType(iterOperands[i].getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]);
if (shouldRemove(makeTensorPtrOp, computeCapability))
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
valueToRemove.insert(iterOperands[i]);
}
}
@@ -800,7 +799,7 @@ public:
for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) {
if (tt::isTensorPointerType(operands[i].getType())) {
auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]);
if (shouldRemove(makeTensorPtrOp, computeCapability))
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
valueToRemove.insert(operands[i]);
}
}

View File

@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
Value result = forOp->getResult(resultIndex);
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);