mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit 'cb3d79a185e40c9d8a579bea07747a8a8d157d52' into ifu-231117
Conflicts: lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp lib/Dialect/TritonGPU/IR/Dialect.cpp python/setup.py python/test/unit/language/assert_helper.py python/test/unit/operators/test_flash_attention.py python/test/unit/runtime/test_subproc.py python/triton/compiler/compiler.py python/triton/language/semantic.py python/triton/runtime/autotuner.py python/triton/runtime/jit.py python/tutorials/03-matrix-multiplication.py python/tutorials/05-layer-norm.py python/tutorials/06-fused-attention.py python/tutorials/11-grouped-gemm.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -52,7 +52,7 @@ unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return dotLayout.getTotalElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
llvm::report_fatal_error("getElemsPerThread not implemented");
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
@@ -68,7 +68,7 @@ SmallVector<unsigned> getElemsPerThread(Attribute layout,
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
return mfmaLayout.getElemsPerThread(shape, eltTy);
|
||||
} else {
|
||||
assert(0 && "getElemsPerThread not implemented");
|
||||
llvm::report_fatal_error("getElemsPerThread not implemented");
|
||||
return SmallVector<unsigned>();
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
|
||||
return threadsPerWarp;
|
||||
}
|
||||
assert(0 && "getThreadsPerWarp not implemented");
|
||||
llvm::report_fatal_error("getThreadsPerWarp not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -180,15 +180,17 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parent = sliceLayout.getParent();
|
||||
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
|
||||
assert(parentWarpsPerCTA.size() == 2 &&
|
||||
"getWarpsPerCTA only implemented for 2D slice layout");
|
||||
assert(parentWarpsPerCTA.size() == 2 ||
|
||||
parentWarpsPerCTA[sliceLayout.getDim()] == 1 &&
|
||||
"getWarpsPerCTA only implemented for 2D slice layout or the "
|
||||
"slice dim must have 1 warp in the parent layout");
|
||||
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
|
||||
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
|
||||
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
|
||||
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
|
||||
return warpsPerCTA;
|
||||
}
|
||||
assert(0 && "getWarpsPerCTA not implemented");
|
||||
llvm::report_fatal_error("getWarpsPerCTA not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -264,7 +266,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
} else if (opIdx == 1) {
|
||||
return {4, 1};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
return {};
|
||||
}
|
||||
} else if (parentLayout.isa<MfmaEncodingAttr>()) {
|
||||
@@ -278,12 +280,13 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
llvm::report_fatal_error(
|
||||
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
return {};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "getSizePerThread not implemented");
|
||||
llvm::report_fatal_error("getSizePerThread not implemented");
|
||||
return {};
|
||||
}
|
||||
}
|
||||
@@ -337,6 +340,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else
|
||||
<<<<<<< HEAD
|
||||
assert(0 && "Unimplemented usage of MmaEncodingAttr");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
@@ -346,8 +350,11 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
}
|
||||
=======
|
||||
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getThreadsPerCTA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
|
||||
}
|
||||
|
||||
return threads;
|
||||
@@ -381,11 +388,15 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
return {16 * mmaLayout.getWarpsPerCTA()[0],
|
||||
instrShape[1] * mmaLayout.getWarpsPerCTA()[1]};
|
||||
}
|
||||
<<<<<<< HEAD
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
=======
|
||||
llvm::report_fatal_error("Unexpected MMA layout version found");
|
||||
>>>>>>> cb3d79a185e40c9d8a579bea07747a8a8d157d52
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -401,7 +412,7 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
} else if (opIdx == 1) {
|
||||
return {16, parentShapePerCTATile[1]};
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
}
|
||||
} else if (auto parentMfmaLayout =
|
||||
parentLayout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
@@ -416,15 +427,20 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
assert(0 && "DotOperandEncodingAttr opIdx must be 0 or 1");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
llvm::report_fatal_error(
|
||||
"DotOperandEncodingAttr non-MmaEncodingAttr parent not "
|
||||
"supported yet");
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getShapePerCTATile");
|
||||
llvm::report_fatal_error("Unimplemented usage of getShapePerCTATile");
|
||||
}
|
||||
return shape;
|
||||
}
|
||||
|
||||
bool isExpensiveView(Type srcType, Type dstType) {
|
||||
return getTotalElemsPerThread(srcType) != getTotalElemsPerThread(dstType);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/* Utility function used by getOrder and getCTAOrder of SliceEncodingAttr.
|
||||
@@ -473,7 +489,7 @@ SmallVector<unsigned> getOrder(Attribute layout) {
|
||||
return SmallVector<unsigned>(sharedLayout.getOrder().begin(),
|
||||
sharedLayout.getOrder().end());
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getOrder");
|
||||
llvm::report_fatal_error("Unimplemented usage of getOrder");
|
||||
}
|
||||
return {};
|
||||
};
|
||||
@@ -494,7 +510,7 @@ CTALayoutAttr getCTALayout(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
return sharedLayout.getCTALayout();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getCTALayout");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTALayout");
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -522,7 +538,8 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
|
||||
* in the branch where layout is an instance of SliceEncodingAttr. This is
|
||||
* inconvenient but safe.
|
||||
*/
|
||||
assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
|
||||
llvm::report_fatal_error(
|
||||
"getCTAsPerCGA for SliceEncodingAttr is not well-defined");
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
|
||||
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#ifdef USE_ROCM
|
||||
@@ -534,7 +551,7 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
ref = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getCTAsPerCGA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA");
|
||||
return SmallVector<unsigned>(ref.begin(), ref.end());
|
||||
}
|
||||
|
||||
@@ -589,7 +606,7 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
|
||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
ref = sharedLayout.getCTALayout().getCTAOrder();
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getCTAOrder");
|
||||
llvm::report_fatal_error("Unimplemented usage of getCTAOrder");
|
||||
}
|
||||
return SmallVector<unsigned>(ref.begin(), ref.end());
|
||||
}
|
||||
@@ -642,9 +659,9 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
|
||||
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
|
||||
return getNumWarpsPerCTA(dotLayout.getParent());
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
assert(0 && "Cannot get numWarps from SharedEncodingAttr");
|
||||
llvm::report_fatal_error("Cannot get numWarps from SharedEncodingAttr");
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getNumWarpsPerCTA");
|
||||
llvm::report_fatal_error("Unimplemented usage of getNumWarpsPerCTA");
|
||||
return product<unsigned>(warpsPerCTA);
|
||||
}
|
||||
|
||||
@@ -665,7 +682,7 @@ unsigned getNumCTAs(Attribute layout) {
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
|
||||
else
|
||||
assert(0 && "Unimplemented usage of getNumCTAs");
|
||||
llvm::report_fatal_error("Unimplemented usage of getNumCTAs");
|
||||
return product<unsigned>(CTAsPerCGA);
|
||||
}
|
||||
|
||||
@@ -1779,13 +1796,15 @@ struct CanonicalizeConvertFromView
|
||||
Operation *arg = op->getOperand(0).getDefiningOp();
|
||||
if (!arg)
|
||||
return mlir::failure();
|
||||
auto convert = dyn_cast<ConvertLayoutOp>(arg);
|
||||
if (!convert)
|
||||
return failure();
|
||||
if (isExpensiveView(convert.getOperand().getType(), op.getType()))
|
||||
return failure();
|
||||
// view(convert) -> view
|
||||
if (auto convert = dyn_cast<ConvertLayoutOp>(arg)) {
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(
|
||||
op, op->getResult(0).getType(), convert.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
return mlir::failure();
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(op, op->getResult(0).getType(),
|
||||
convert.getOperand());
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -1831,6 +1850,8 @@ struct CanonicalizeConvertFromConvert
|
||||
return mlir::failure();
|
||||
// cvt(view) -> view
|
||||
if (auto view = dyn_cast<triton::ViewOp>(arg)) {
|
||||
if (isExpensiveView(view.getOperand().getType(), op.getType()))
|
||||
return failure();
|
||||
rewriter.replaceOpWithNewOp<triton::ViewOp>(
|
||||
op, op->getResult(0).getType(), view.getResult());
|
||||
return mlir::success();
|
||||
|
||||
@@ -70,10 +70,15 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, {filter});
|
||||
auto slices = multiRootGetSlice(dotOp, {filter});
|
||||
for (Operation *op : slices)
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
if (isa<tt::DotOp>(op) && (op != dotOp)) {
|
||||
if (shape[0] >= shape[1]) {
|
||||
return {(unsigned)numWarps, 1};
|
||||
} else {
|
||||
return {1, (unsigned)numWarps};
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
@@ -133,8 +138,18 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
mlir::TypeID::get<arith::ArithDialect>());
|
||||
}
|
||||
|
||||
// finds the first different value bitwidth in the chain of
|
||||
// shape-preserving unary ops that x depends on
|
||||
// Finds the first different bitwidth in the chain of shape-preserving
|
||||
// unary ops that x depends on.
|
||||
// There are two primary scenarios:
|
||||
// (1) Upcasting: A sequence such as loading an fp16, followed by arithmetic
|
||||
// operations, then bitcasting to fp32, and finally computing in fp32.
|
||||
// (2) Downcasting: This might involve loading an fp32, performing arithmetic
|
||||
// operations, bitcasting to fp16, and finally computing in fp16.
|
||||
// In the upcasting scenario, element reordering converts the original
|
||||
// elements distribution to the order of higher precision primitives. As a
|
||||
// result, kwidth can be the bitwidth of the lower precision primitive.
|
||||
// Conversely, in the downcasting scenario, no reordering is performed,
|
||||
// making it directory use the lower precision primitive.
|
||||
static int computeOrigBitWidth(Value x) {
|
||||
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
|
||||
int origBitWidth = finalBitWidth;
|
||||
@@ -143,11 +158,17 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = bwdFilter;
|
||||
getBackwardSlice(x, &slice, opt);
|
||||
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
|
||||
if (firstOp)
|
||||
if (Value arg = firstOp->getOperand(0))
|
||||
if (RankedTensorType argTy = arg.getType().dyn_cast<RankedTensorType>())
|
||||
origBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
|
||||
for (auto op : slice) {
|
||||
if (Value arg = op->getOperand(0))
|
||||
if (RankedTensorType argTy =
|
||||
arg.getType().dyn_cast<RankedTensorType>()) {
|
||||
auto argBitWidth = argTy.getElementType().getIntOrFloatBitWidth();
|
||||
if (argBitWidth != origBitWidth) {
|
||||
origBitWidth = std::min<int>(origBitWidth, argBitWidth);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return origBitWidth;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
DecomposeConversions.cpp
|
||||
OptimizeDotOperands.cpp
|
||||
OptimizeEpilogue.cpp
|
||||
Pipeline.cpp
|
||||
OptimizeThreadLocality.cpp
|
||||
Pipeliner/MatmulLoopPipeline.cpp
|
||||
Pipeliner/PipelineExpander.cpp
|
||||
Pipeliner/SoftwarePipeliner.cpp
|
||||
Prefetch.cpp
|
||||
RemoveLayoutConversions.cpp
|
||||
ReorderInstructions.cpp
|
||||
|
||||
312
lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
Normal file
312
lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp
Normal file
@@ -0,0 +1,312 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include <memory>
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
class TritonGPUOptimizeThreadLocalityPass
|
||||
: public TritonGPUOptimizeThreadLocalityBase<
|
||||
TritonGPUOptimizeThreadLocalityPass> {
|
||||
void runOnOperation() override {
|
||||
ModuleOp mod = getOperation();
|
||||
DenseSet<triton::ReduceOp> reduceOps;
|
||||
mod.walk([&](triton::ReduceOp reduce) -> void {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto rank = srcType.getShape().size();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
auto reductionOp = getReductionOp(reduce);
|
||||
if (!reductionOp ||
|
||||
!isa<arith::AddFOp, arith::MaximumFOp, arith::MinimumFOp,
|
||||
arith::MulFOp>(reductionOp.value()))
|
||||
return;
|
||||
// TODO: relax this restriction
|
||||
if (!(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() && rank > 1))
|
||||
return;
|
||||
for (auto operand : reduce->getOperands()) {
|
||||
auto def = operand.getDefiningOp();
|
||||
if (!isa<triton::LoadOp>(def))
|
||||
return;
|
||||
}
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
// Not worth applying this optimization if there is only one element per
|
||||
// thread on the reduction axis
|
||||
if (elemsPerThread == 1)
|
||||
return;
|
||||
if (!reduce->hasOneUse())
|
||||
return;
|
||||
Operation *user = *(reduce->getUsers().begin());
|
||||
if (!user->hasOneUse())
|
||||
return;
|
||||
OpOperand &yieldOpOperand = *(user->getUses().begin());
|
||||
auto yieldOp = dyn_cast<scf::YieldOp>(yieldOpOperand.getOwner());
|
||||
if (!yieldOp)
|
||||
return;
|
||||
auto operandNumber = yieldOpOperand.getOperandNumber();
|
||||
Block *block = reduce->getBlock();
|
||||
Operation *parentOp = block->getParentOp();
|
||||
auto forOp = dyn_cast<scf::ForOp>(parentOp);
|
||||
if (!forOp)
|
||||
return;
|
||||
auto argNum = yieldOpOperand.getOperandNumber();
|
||||
auto oldAccum = forOp.getInitArgs()[argNum];
|
||||
auto cstOp = dyn_cast<arith::ConstantOp>(oldAccum.getDefiningOp());
|
||||
if (!cstOp)
|
||||
return;
|
||||
reduceOps.insert(reduce);
|
||||
});
|
||||
|
||||
for (auto reduce : reduceOps) {
|
||||
OpBuilder builder(reduce);
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
assert(srcEncoding.isa<triton::gpu::BlockedEncodingAttr>() &&
|
||||
"Thread locality optimization only supports blocked encoding");
|
||||
auto blocked = srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
auto rank = srcShape.size();
|
||||
// create new layouts
|
||||
auto blocked3d = getThreadLocalityOptimizedEncoding(reduce);
|
||||
auto viewOpTensorShape = getThreadLocalityOptimizedShape(reduce);
|
||||
auto viewOpTensorType = RankedTensorType::get(
|
||||
viewOpTensorShape, srcType.getElementType(), blocked3d);
|
||||
auto slice2d = triton::gpu::SliceEncodingAttr::get(mod.getContext(), rank,
|
||||
blocked3d);
|
||||
// Get forOp
|
||||
assert(reduce->hasOneUse());
|
||||
OpOperand &use = *(reduce->getUses().begin());
|
||||
auto operandNumber = use.getOperandNumber();
|
||||
auto oldUpdate = use.getOwner();
|
||||
assert(oldUpdate->getNumOperands() == 2);
|
||||
auto accumOperandNumber = (operandNumber == 0) ? 1 : 0;
|
||||
auto accumOperand = oldUpdate->getOperand(accumOperandNumber);
|
||||
assert(accumOperand.isa<BlockArgument>());
|
||||
auto blockArg = accumOperand.dyn_cast<BlockArgument>();
|
||||
auto blockArgNum = blockArg.getArgNumber();
|
||||
auto forOp = dyn_cast<scf::ForOp>(blockArg.getOwner()->getParentOp());
|
||||
// get oldAccum
|
||||
auto oldAccum =
|
||||
forOp.getInitArgs()[blockArgNum - forOp.getNumInductionVars()];
|
||||
// get old loop user
|
||||
Value loopResult =
|
||||
forOp.getResult(blockArgNum - forOp.getNumInductionVars());
|
||||
assert(loopResult.hasOneUse());
|
||||
OpOperand &loopUse = *(loopResult.getUses().begin());
|
||||
Operation *loopUser = loopUse.getOwner();
|
||||
// get old loop yield
|
||||
auto oldYield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
// create newAccum initialization
|
||||
auto newAccum =
|
||||
createAccum(builder, reduce, oldAccum, viewOpTensorShape, slice2d);
|
||||
// create new loop by copying the old for op signature and appending
|
||||
// newAccum to the block arguments
|
||||
auto newLoop = replaceForOpWithNewSignature(
|
||||
builder, forOp, ValueRange{newAccum->getResult(0)});
|
||||
// create thread local reduction (also adds viewOps)
|
||||
auto newReduce = createReduce(builder, reduce, viewOpTensorType);
|
||||
|
||||
// create new accum update
|
||||
auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate);
|
||||
// create new yield
|
||||
auto newYield = createYield(builder, newLoop, oldYield,
|
||||
newUpdate->getResult(0), blockArgNum);
|
||||
// create post loop reduction on the original reduce axis
|
||||
auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce);
|
||||
// add convert_layout to get back to original layout, the result layout
|
||||
// should now match the layout of the old accumulator (%cst)
|
||||
Type destType = loopResult.getType();
|
||||
auto cvtLayout = createConvertLayout(builder, destType, newReduce2);
|
||||
// incorporate the original accumulator value into the final result
|
||||
auto finalOp = incorporateOriginalAccumulatorValue(builder, oldUpdate,
|
||||
cvtLayout, oldAccum);
|
||||
// Replace the old loop user with the final result
|
||||
loopUser->setOperand(loopUse.getOperandNumber(), finalOp->getResult(0));
|
||||
|
||||
// cleanup
|
||||
oldYield.erase();
|
||||
forOp.erase();
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
std::optional<Operation *> getReductionOp(triton::ReduceOp reduce) const {
|
||||
auto numRegions = reduce->getNumRegions();
|
||||
if (numRegions != 1)
|
||||
return std::nullopt;
|
||||
Region ®ion = reduce->getRegion(0);
|
||||
auto numBlocks = region.getBlocks().size();
|
||||
if (numBlocks != 1)
|
||||
return std::nullopt;
|
||||
Block &block = region.front();
|
||||
auto blockWithoutTerminator = block.without_terminator();
|
||||
auto blockSizeWithoutTerminator = std::distance(
|
||||
blockWithoutTerminator.begin(), blockWithoutTerminator.end());
|
||||
if (blockSizeWithoutTerminator != 1)
|
||||
return std::nullopt;
|
||||
Operation *op = &block.front();
|
||||
return std::optional<Operation *>(op);
|
||||
}
|
||||
Operation *incorporateOriginalAccumulatorValue(OpBuilder &builder,
|
||||
Operation *oldUpdate,
|
||||
Operation *cvtLayout,
|
||||
Value oldAccum) const {
|
||||
builder.setInsertionPointAfter(cvtLayout);
|
||||
IRMapping mapping;
|
||||
mapping.map(oldUpdate->getOperand(0), oldAccum);
|
||||
mapping.map(oldUpdate->getOperand(1), cvtLayout->getResult(0));
|
||||
auto finalOp = cloneWithInferType(builder, &(*oldUpdate), mapping);
|
||||
return finalOp;
|
||||
}
|
||||
Operation *createConvertLayout(OpBuilder &builder, Type destType,
|
||||
Operation *newReduce) const {
|
||||
builder.setInsertionPointAfter(newReduce);
|
||||
auto newCvt = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
newReduce->getLoc(), destType, newReduce->getResult(0));
|
||||
return newCvt;
|
||||
}
|
||||
|
||||
Operation *createPostLoopReduce(OpBuilder &builder, scf::ForOp &loop,
|
||||
triton::ReduceOp &reduce) const {
|
||||
auto resultIndex =
|
||||
loop.getBody()->getNumArguments() - 1 - loop.getNumInductionVars();
|
||||
auto newLoopResult = loop.getResult(resultIndex);
|
||||
builder.setInsertionPointAfter(loop);
|
||||
IRMapping mapping;
|
||||
mapping.map(*(reduce.getOperands().begin()), newLoopResult);
|
||||
auto newReduce2 = cloneWithInferType(builder, &(*reduce), mapping);
|
||||
return newReduce2;
|
||||
}
|
||||
|
||||
Operation *createYield(OpBuilder &builder, scf::ForOp &loop,
|
||||
scf::YieldOp &oldYield, Value newUpdate,
|
||||
int oldAccumBlockArgNum) const {
|
||||
builder.setInsertionPoint(oldYield);
|
||||
SmallVector<Value> yieldValues = llvm::to_vector(oldYield.getOperands());
|
||||
yieldValues[oldAccumBlockArgNum - 1] =
|
||||
loop.getBody()->getArgument(oldAccumBlockArgNum);
|
||||
yieldValues.push_back(newUpdate);
|
||||
auto newYield =
|
||||
builder.create<scf::YieldOp>(oldYield.getLoc(), yieldValues);
|
||||
return newYield;
|
||||
}
|
||||
|
||||
Operation *createUpdate(OpBuilder &builder, scf::ForOp &loop,
|
||||
Operation *newReduce, Operation *oldUpdate) const {
|
||||
auto blockArgNum = loop.getBody()->getNumArguments() - 1;
|
||||
auto newArg = loop.getBody()->getArgument(blockArgNum);
|
||||
builder.setInsertionPointAfter(newReduce);
|
||||
IRMapping mapping;
|
||||
mapping.map(oldUpdate->getOperand(0), newArg);
|
||||
mapping.map(oldUpdate->getOperand(1), newReduce->getResult(0));
|
||||
auto newUpdate = cloneWithInferType(builder, oldUpdate, mapping);
|
||||
return newUpdate;
|
||||
}
|
||||
|
||||
Operation *createReduce(OpBuilder &builder, triton::ReduceOp reduce,
|
||||
Type viewOpTensorType) const {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto rank = srcType.getShape().size();
|
||||
builder.setInsertionPointAfter(reduce);
|
||||
IRMapping mapping;
|
||||
for (auto operand : reduce.getOperands()) {
|
||||
auto viewOp = builder.create<triton::ViewOp>(reduce.getLoc(),
|
||||
viewOpTensorType, operand);
|
||||
mapping.map(operand, viewOp);
|
||||
}
|
||||
|
||||
auto newReduce = cloneWithInferType(builder, &(*reduce), mapping);
|
||||
newReduce->setAttr("axis", builder.getI32IntegerAttr(rank));
|
||||
auto typeInfer = dyn_cast<InferTypeOpInterface>(newReduce);
|
||||
if (typeInfer) {
|
||||
SmallVector<Type, 1> newTypes;
|
||||
auto success = typeInfer.inferReturnTypes(
|
||||
newReduce->getContext(), newReduce->getLoc(),
|
||||
newReduce->getOperands(), newReduce->getAttrDictionary(),
|
||||
newReduce->getPropertiesStorage(), newReduce->getRegions(), newTypes);
|
||||
if (succeeded(success)) {
|
||||
for (size_t i = 0; i < newTypes.size(); i++)
|
||||
newReduce->getResult(i).setType(newTypes[i]);
|
||||
}
|
||||
}
|
||||
return newReduce;
|
||||
}
|
||||
|
||||
Operation *createAccum(OpBuilder &builder, triton::ReduceOp reduce,
|
||||
Value &oldAccum, SmallVector<int64_t> &shape,
|
||||
Attribute &slice2d) const {
|
||||
// Drop the last dimension (thread locality dimension)
|
||||
SmallVector<int64_t> accumShape(shape.begin(), shape.end() - 1);
|
||||
auto elemType =
|
||||
oldAccum.getType().cast<RankedTensorType>().getElementType();
|
||||
// Create tensor type for the new accumulator
|
||||
auto accumType = RankedTensorType::get(accumShape, elemType, slice2d);
|
||||
// Create new accumulator
|
||||
builder.setInsertionPointAfter(oldAccum.getDefiningOp());
|
||||
auto reductionOp = getReductionOp(reduce);
|
||||
assert(reductionOp && "Processing a reduce that is not supported!");
|
||||
auto neutralVal = mlir::arith::getNeutralElement(reductionOp.value());
|
||||
assert(neutralVal && "Could not find neutral value for reduction op!");
|
||||
auto denseAttr = DenseElementsAttr::get(accumType, neutralVal.value());
|
||||
auto newAccum = builder.create<arith::ConstantOp>(oldAccum.getLoc(),
|
||||
accumType, denseAttr);
|
||||
return newAccum;
|
||||
}
|
||||
|
||||
SmallVector<int64_t>
|
||||
getThreadLocalityOptimizedShape(triton::ReduceOp reduce) const {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto srcShape = srcType.getShape();
|
||||
auto rank = srcShape.size();
|
||||
auto elemsPerThread =
|
||||
triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()];
|
||||
auto viewOpTensorShape = insertValue(srcShape, rank, 1);
|
||||
viewOpTensorShape[reduce.getAxis()] /= elemsPerThread;
|
||||
viewOpTensorShape[rank] = elemsPerThread;
|
||||
return viewOpTensorShape;
|
||||
}
|
||||
|
||||
Attribute getThreadLocalityOptimizedEncoding(triton::ReduceOp reduce) const {
|
||||
auto srcType = reduce.getOperands()[0].getType().cast<RankedTensorType>();
|
||||
auto rank = srcType.getShape().size();
|
||||
auto srcEncoding = srcType.getEncoding();
|
||||
auto blocked = srcEncoding.dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
auto sizePerThread3d =
|
||||
insertValue(blocked.getSizePerThread(), rank,
|
||||
blocked.getSizePerThread()[reduce.getAxis()]);
|
||||
sizePerThread3d[reduce.getAxis()] = 1;
|
||||
auto threadsPerWarp3d = insertValue(blocked.getThreadsPerWarp(), rank, 1);
|
||||
auto warsPerCTA3d = insertValue(blocked.getWarpsPerCTA(), rank, 1);
|
||||
auto order3d = insertValue(blocked.getOrder(), 0, rank);
|
||||
auto ctasPerCGA3d =
|
||||
insertValue(blocked.getCTALayout().getCTAsPerCGA(), rank, 1);
|
||||
auto ctasSplitNum3d =
|
||||
insertValue(blocked.getCTALayout().getCTASplitNum(), rank, 1);
|
||||
auto ctaOrder3d =
|
||||
insertValue(blocked.getCTALayout().getCTAOrder(), rank, rank);
|
||||
auto ctaLayout3d = triton::gpu::CTALayoutAttr::get(
|
||||
reduce.getContext(), ctasPerCGA3d, ctasSplitNum3d, ctaOrder3d);
|
||||
auto blocked3d = triton::gpu::BlockedEncodingAttr::get(
|
||||
reduce.getContext(), sizePerThread3d, threadsPerWarp3d, warsPerCTA3d,
|
||||
order3d, ctaLayout3d);
|
||||
return blocked3d;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
SmallVector<T> insertValue(ArrayRef<T> vec, unsigned index, int value) const {
|
||||
SmallVector<T> res(vec.begin(), vec.end());
|
||||
res.insert(res.begin() + index, static_cast<T>(value));
|
||||
return res;
|
||||
}
|
||||
};
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUOptimizeThreadLocalityPass() {
|
||||
return std::make_unique<TritonGPUOptimizeThreadLocalityPass>();
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,814 @@
|
||||
#include "PipelineExpander.h"
|
||||
#include "Schedule.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#define int_attr(num) builder.getI64IntegerAttr(num)
|
||||
|
||||
using namespace mlir;
|
||||
namespace tt = mlir::triton;
|
||||
namespace ttg = mlir::triton::gpu;
|
||||
namespace ttng = mlir::triton::nvidia_gpu;
|
||||
|
||||
// TODO: We can extra some helpers into common utilities once we add more
|
||||
// schedules.
|
||||
|
||||
/// Replace the yield with a new one with the given operands appended.
|
||||
static void appendToYield(scf::ForOp forOp, ArrayRef<Value> newOperands) {
|
||||
// Fix up the yield op.
|
||||
Operation *yieldOp = forOp.getBody()->getTerminator();
|
||||
SmallVector<Value> operands(yieldOp->getOperands().begin(),
|
||||
yieldOp->getOperands().end());
|
||||
operands.append(newOperands.begin(), newOperands.end());
|
||||
OpBuilder builder(yieldOp);
|
||||
builder.create<scf::YieldOp>(yieldOp->getLoc(), operands);
|
||||
yieldOp->erase();
|
||||
}
|
||||
|
||||
/// Replace a synchronous tt.load with an async copy staged through the
/// multi-buffered shared-memory tensor `alloc`:
///   insert_slice_async -> async_commit_group -> extract_slice
/// The load's single convert_layout user (guaranteed by loadDotOperand) is
/// re-created on top of the extracted slice, and the insert result is added
/// to the loop's yield so the buffer is carried across iterations.
static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                            Value insertIdx, Value extractIdx) {
  OpBuilder builder(forOp);
  // Replace the load with insert/extract slice.
  builder.setInsertionPoint(loadOp);
  Location loc = loadOp.getLoc();
  auto insertOp = builder.create<ttg::InsertSliceAsyncOp>(
      loc, alloc.getType(), loadOp.getPtr(), alloc, insertIdx, loadOp.getMask(),
      loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(),
      loadOp.getIsVolatile(), /*axis*/ 0);
  // Close the async group; the op's result is intentionally unused (the
  // original bound it to a typo'd, never-read local `commmit`).
  builder.create<ttg::AsyncCommitGroupOp>(loc);

  // Extract part: slice one 2-D buffer (dims 1 and 2) out of the 3-D
  // multi-buffer at position `extractIdx`.
  auto allocType = alloc.getType().cast<RankedTensorType>();
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<ttg::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
  // Re-anchor the load's single convert_layout user onto the extracted slice.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();

  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
|
||||
|
||||
/// Replace a tt.load from a tensor pointer with a TMA-based async copy
/// synchronized through an array of mbarrier objects (Hopper path):
///   alloc_mbarrier (outside loop) -> mbarrier_arrive -> insert_slice_async_v2
///   -> extract_slice -> mbarrier_wait.
/// The load's single convert_layout user is re-created on the extracted
/// slice and the insert result is carried through the loop yield.
static void createTMALoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
                          Value insertIdx, Value extractIdx, Value phase) {
  OpBuilder builder(forOp);
  Location loc = loadOp.getLoc();
  // 1-D layout for the barrier array itself (one i64 mbarrier per buffer).
  auto CTALayout = ttg::CTALayoutAttr::get(loadOp.getContext(),
                                           /*CTAsPerCGA*/ {1},
                                           /*CTASplitNum*/ {1},
                                           /*CTAOrder*/ {0});
  auto sharedEncoding = ttg::SharedEncodingAttr::get(loadOp.getContext(), 1, 1,
                                                     1, {0}, CTALayout, false);
  // Leading dim of the alloc is the number of pipeline buffers.
  int64_t numBuffers = alloc.getType().cast<RankedTensorType>().getShape()[0];
  auto mBarriersTy = RankedTensorType::get(
      {numBuffers}, builder.getIntegerType(64), sharedEncoding);
  // Allocate an array of mbarrier objects outside the loop.
  Value barrierArray =
      builder.create<ttng::AllocMBarrierOp>(loc, mBarriersTy, 1);
  // Extract the barrier and emit the arrive/copy/wait/extract code sequence.
  builder.setInsertionPoint(loadOp);
  // Shared-memory (address space 3) pointer to a single i64 mbarrier.
  auto mBarTy = tt::PointerType::get(builder.getIntegerType(64), 3);
  Value barrier = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, insertIdx);
  // Only thread 0 performs the mbarrier arrive.
  Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
  Value threadId = builder.create<ttng::GetThreadIdOp>(loc);
  Value pred = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
                                             threadId, zero);

  // Compute the number of bytes the TMA copy will transfer for this CTA:
  // elements per CTA slice times element size in bytes.
  auto loadTy = loadOp.getType().dyn_cast<RankedTensorType>();
  auto loadShape = loadTy.getShape();
  auto CTASplitNum = ttg::getCTASplitNum(loadTy.getEncoding());
  auto shapePerSlice = ttg::getShapePerCTA(CTASplitNum, loadShape);
  auto elemTy = loadTy.getElementType();
  unsigned elems = std::accumulate(shapePerSlice.begin(), shapePerSlice.end(),
                                   1, std::multiplies{});
  elems *= (elemTy.getIntOrFloatBitWidth() / 8);
  builder.create<ttng::MBarrierArriveOp>(loc, barrier, pred,
                                         /*remoteCtaId*/ nullptr,
                                         /*trackAsyncOp*/ false, elems);
  auto allocType = alloc.getType().cast<RankedTensorType>();
  auto insertOp = builder.create<ttng::InsertSliceAsyncV2Op>(
      loc, allocType, loadOp.getPtr(), alloc,
      /*index*/ insertIdx, barrier, loadOp.getMask(), loadOp.getOther(),
      loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile(),
      /*axis*/ 0);

  // Slice one 2-D buffer (dims 1 and 2) out of the 3-D multi-buffer.
  RankedTensorType sliceType = RankedTensorType::get(
      {allocType.getShape()[1], allocType.getShape()[2]},
      allocType.getElementType(), allocType.getEncoding());
  auto extract = builder.create<mlir::triton::gpu::ExtractSliceOp>(
      loc, sliceType, insertOp.getResult(),
      SmallVector<OpFoldResult>{extractIdx, int_attr(0), int_attr(0)},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(sliceType.getShape()[0]),
                                int_attr(sliceType.getShape()[1])},
      SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});

  // Wait on the barrier of the buffer we are about to consume; `phase`
  // alternates each wrap-around of the extract index.
  Value barrierWait = builder.create<ttng::ExtractMBarrierOp>(
      loc, mBarTy, barrierArray, extractIdx);
  builder.create<ttng::MBarrierWaitOp>(loc, barrierWait, phase);

  // Re-anchor the load's single convert_layout user onto the slice.
  Operation *user = *loadOp.getResult().getUsers().begin();
  auto convertLayout = llvm::cast<ttg::ConvertLayoutOp>(user);
  auto newCvt = builder.create<ttg::ConvertLayoutOp>(
      convertLayout->getLoc(), convertLayout.getType(), extract.getResult());
  convertLayout->replaceAllUsesWith(newCvt->getResults());
  convertLayout->erase();
  loadOp.erase();

  // Fix up the yield op.
  appendToYield(forOp, {insertOp});
}
|
||||
|
||||
/// Create an async load equivalent to the given load.
|
||||
static void createAsyncLoad(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc,
|
||||
Value insertIdx, Value extractIdx, Value phase) {
|
||||
if (isLoadFromTensorPtr(loadOp)) {
|
||||
createTMALoad(forOp, loadOp, alloc, insertIdx, extractIdx, phase);
|
||||
} else {
|
||||
createAsyncCopy(forOp, loadOp, alloc, insertIdx, extractIdx);
|
||||
}
|
||||
}
|
||||
|
||||
// Return the transitive use of the load which is a dot operand.
|
||||
static Value loadDotOperand(tt::LoadOp loadOp, bool &hasMMAV3) {
|
||||
// We only pipeline loads that have one covert_layout (to dot_op) use
|
||||
// TODO: lift this constraint in the future
|
||||
bool isCandidate = false;
|
||||
if (!loadOp.getResult().hasOneUse())
|
||||
return Value();
|
||||
|
||||
Operation *use = *loadOp.getResult().getUsers().begin();
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().cast<RankedTensorType>();
|
||||
if (auto sharedEnc =
|
||||
tensorType.getEncoding().dyn_cast<ttg::SharedEncodingAttr>()) {
|
||||
if (sharedEnc.getHasLeadingOffset()) {
|
||||
// MMA V3 case.
|
||||
auto newOrder = sharedEnc.getOrder();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
auto oldOrder = ttg::getOrder(ty.getEncoding());
|
||||
if (newOrder[0] == oldOrder[0] || newOrder[1] == oldOrder[1]) {
|
||||
// The operand of MMAv3 is in SharedEncoding and it's order should
|
||||
// not be changed after FuseTranspositions Pass. So we only pipeline
|
||||
// the load if the order of the loaded BlockedEncoding is the same
|
||||
// as the order of the SharedEncoding it is converted to.
|
||||
// TODO: remove this constraint once the LoadOp supports transpose
|
||||
// fusion
|
||||
hasMMAV3 = true;
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Advance to the first conversion as long as the use resides in shared
|
||||
// memory and it has a single use itself
|
||||
while (use) {
|
||||
if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
|
||||
break;
|
||||
auto tensorType = use->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
|
||||
break;
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType =
|
||||
convertLayout.getResult().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto dotOpEnc = tensorType.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
return convertLayout.getResult();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Value();
|
||||
}
|
||||
|
||||
namespace {
/// Pair of a pipelineable load and the transitive dot-operand value it feeds
/// (the result of the convert_layout found by loadDotOperand).
struct LoadDotOperand {
  LoadDotOperand(tt::LoadOp load, Value dotOperand)
      : load(load), dotOperand(dotOperand) {}
  // The load op that will be converted into an async copy.
  tt::LoadOp load;
  // Dot-operand (or MMAv3 shared-encoding) value fed by the load.
  Value dotOperand;
};
} // namespace
|
||||
|
||||
/// Collect loads to pipeline. A load qualifies if it is a tensor-pointer
/// (TMA) load, or a regular load whose access width is at least 32 bits,
/// and it transitively feeds a dot operand (see loadDotOperand).
static void collectOpsToPipeline(scf::ForOp forOp,
                                 SmallVectorImpl<LoadDotOperand> &ops,
                                 bool &hasMMAV3) {
  ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
  // Axis analysis provides pointer contiguity / mask alignment used to
  // estimate the vectorized access width.
  ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);

  // We cannot use forOp.walk(...) here because we only want to visit the
  // operations in the loop body block. Nested blocks are handled separately.
  for (Operation &op : forOp) {
    if (auto loadOp = dyn_cast<tt::LoadOp>(&op)) {
      bool candidate = false;
      if (isLoadFromTensorPtr(loadOp)) {
        // Map to TMA load.
        candidate = true;
      } else {
        auto ptr = loadOp.getPtr();
        unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
        // A mask can only reduce the usable vector width.
        if (auto mask = loadOp.getMask())
          vec =
              std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));

        auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
        if (!tensorTy || tensorTy.getRank() < 2)
          continue;
        auto ty =
            tensorTy.getElementType().cast<tt::PointerType>().getPointeeType();
        // Width of one vectorized access in bits.
        unsigned width = vec * ty.getIntOrFloatBitWidth();
        // We do not pipeline all loads for the following reasons:
        // 1. On nvidia GPUs, cp.async's cp-size can only be 4, 8 and 16.
        // 2. It's likely that pipelining small loads won't offer much
        //    performance improvement and may even hurt performance by
        //    increasing register pressure.
        if (width >= 32)
          candidate = true;
      }
      if (!candidate)
        continue;
      // Only keep loads that actually feed a dot operand.
      Value dotOperand = loadDotOperand(loadOp, hasMMAV3);
      if (!dotOperand)
        continue;
      ops.emplace_back(loadOp, dotOperand);
    }
  }
}
|
||||
|
||||
// Create an allocation that can hold `distance` copies of the loadOp shape
// (a 3-D buffer whose leading dimension is the buffer count), with a shared
// encoding suitable for the consuming dot operand.
static Value createAlloc(scf::ForOp &forOp, tt::LoadOp loadOp, Value dotOperand,
                         unsigned distance) {
  OpBuilder builder(forOp);
  auto ty = loadOp.getType().cast<RankedTensorType>();
  // Bail out (null Value) if the load does not have exactly one user.
  if (!loadOp.getResult().hasOneUse())
    return Value();
  Attribute sharedEnc;
  auto CTALayout = ttg::getCTALayout(ty.getEncoding());
  auto tensorType = dotOperand.getType().cast<RankedTensorType>();
  if (auto dotOpEnc =
          tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>()) {
    // NOTE(review): assumes `dotOperand` is produced by a convert_layout
    // (guaranteed by loadDotOperand); `convertLayout` would be null otherwise.
    auto convertLayout = dotOperand.getDefiningOp<ttg::ConvertLayoutOp>();
    // Transposed operand: the convert's source comes from a tt.trans.
    bool needTrans = dyn_cast_or_null<tt::TransOp>(
        convertLayout->getOperand(0).getDefiningOp());
    unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth();
    sharedEnc = ttg::SharedEncodingAttr::get(
        ty.getContext(), dotOpEnc, ty.getShape(),
        ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans);
  } else {
    // MMAv3
    sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(),
                                             ttg::getOrder(ty.getEncoding()),
                                             CTALayout, ty.getElementType());
  }
  // Prepend the buffer count as the leading dimension of the allocation.
  SmallVector<int64_t> bufferShape(ty.getShape().begin(), ty.getShape().end());
  bufferShape.insert(bufferShape.begin(), distance);
  Type allocType =
      RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
  Value alloc = builder.create<mlir::triton::gpu::AllocTensorOp>(
      loadOp.getLoc(), allocType);
  return alloc;
}
|
||||
|
||||
// Convert load ops into their asyn version and apply multi-buffering based on
|
||||
// the number of stages.
|
||||
static void createAsynOps(scf::ForOp &forOp, ArrayRef<LoadDotOperand> loads,
|
||||
int numStages, bool hasMMAV3) {
|
||||
struct AsyncLoad {
|
||||
AsyncLoad(tt::LoadOp loadOp, Value alloc) : loadOp(loadOp), alloc(alloc) {}
|
||||
tt::LoadOp loadOp;
|
||||
Value alloc;
|
||||
};
|
||||
int numBuffers = numStages - 1;
|
||||
// For MMAv3 we need an extra buffer as this is assumed in the wgmma
|
||||
// pipelining post-processing.
|
||||
// TODO: Improve modeling of wgmma pipelining.
|
||||
if (hasMMAV3)
|
||||
numBuffers++;
|
||||
SmallVector<AsyncLoad> asyncLoads;
|
||||
SmallVector<Value> newOperands;
|
||||
bool needsMbarrierPhase = false;
|
||||
bool needsAsyncWait = false;
|
||||
for (const LoadDotOperand &loadOperand : loads) {
|
||||
tt::LoadOp loadOp = loadOperand.load;
|
||||
Value dotOperand = loadOperand.dotOperand;
|
||||
Value alloc = createAlloc(forOp, loadOp, dotOperand, numBuffers);
|
||||
assert(alloc && "Failed to create alloc for the async load.");
|
||||
newOperands.push_back(alloc);
|
||||
asyncLoads.emplace_back(loadOp, alloc);
|
||||
if (isLoadFromTensorPtr(loadOp))
|
||||
needsMbarrierPhase = true;
|
||||
else
|
||||
needsAsyncWait = true;
|
||||
}
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
Location loc = forOp.getLoc();
|
||||
// Create two new counters to index into the allocs.
|
||||
Value minusOne = builder.create<arith::ConstantIntOp>(loc, -1, 32);
|
||||
Value zero = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value one = builder.create<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value insertIdx = minusOne;
|
||||
Value extractIdx = minusOne;
|
||||
Value numBuffersVal =
|
||||
builder.create<arith::ConstantIntOp>(loc, numBuffers, 32);
|
||||
newOperands.push_back(insertIdx);
|
||||
newOperands.push_back(extractIdx);
|
||||
Value phase;
|
||||
if (needsMbarrierPhase) {
|
||||
phase = builder.create<arith::ConstantIntOp>(loc, 0, 1);
|
||||
newOperands.push_back(phase);
|
||||
}
|
||||
unsigned newOperandIndex = forOp.getBody()->getNumArguments();
|
||||
// Patch the loop to add the new loop carried dependencies.
|
||||
scf::ForOp newForOp =
|
||||
replaceForOpWithNewSignature(builder, forOp, newOperands);
|
||||
forOp.erase();
|
||||
forOp = newForOp;
|
||||
for (int i = 0; i < asyncLoads.size(); i++) {
|
||||
asyncLoads[i].alloc = newForOp.getBody()->getArgument(newOperandIndex + i);
|
||||
}
|
||||
insertIdx =
|
||||
newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size());
|
||||
extractIdx =
|
||||
newForOp.getBody()->getArgument(newOperandIndex + asyncLoads.size() + 1);
|
||||
|
||||
// Create two counters for the insert and extract indices to avoid creating
|
||||
// long liverange.
|
||||
builder.setInsertionPoint(asyncLoads.front().loadOp);
|
||||
insertIdx = builder.create<arith::AddIOp>(loc, insertIdx, one);
|
||||
Value cndIns = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
insertIdx, numBuffersVal);
|
||||
insertIdx = builder.create<arith::SelectOp>(loc, cndIns, insertIdx, zero);
|
||||
|
||||
extractIdx = builder.create<arith::AddIOp>(loc, extractIdx, one);
|
||||
Value cndExt = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
|
||||
extractIdx, numBuffersVal);
|
||||
extractIdx = builder.create<arith::SelectOp>(loc, cndExt, extractIdx, zero);
|
||||
|
||||
if (needsMbarrierPhase) {
|
||||
phase = newForOp.getBody()->getArgument(newOperandIndex +
|
||||
asyncLoads.size() + 2);
|
||||
Value oneI1 = builder.create<arith::ConstantIntOp>(loc, 1, 1);
|
||||
Value nextPhase = builder.create<arith::XOrIOp>(loc, phase, oneI1);
|
||||
phase = builder.create<arith::SelectOp>(loc, cndExt, phase, nextPhase);
|
||||
}
|
||||
|
||||
bool firstLoad = true;
|
||||
for (AsyncLoad &asyncLoad : asyncLoads) {
|
||||
createAsyncLoad(forOp, asyncLoad.loadOp, asyncLoad.alloc, insertIdx,
|
||||
extractIdx, phase);
|
||||
firstLoad = false;
|
||||
}
|
||||
// Insert a waitOp after the first async copy. This does make the assumption
|
||||
// that the wait will be scheduled in a different stage that all the async
|
||||
// copy but we cannot guarantee that one wait is enough otherwise.
|
||||
for (auto &op : forOp.getBody()->without_terminator()) {
|
||||
if (isa<ttg::InsertSliceAsyncOp>(op)) {
|
||||
OpBuilder builder(op.getContext());
|
||||
builder.setInsertionPointAfter(&op);
|
||||
builder.create<ttg::AsyncWaitOp>(op.getLoc(), 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
SmallVector<Value> newYieldOperands = {insertIdx, extractIdx};
|
||||
if (needsMbarrierPhase)
|
||||
newYieldOperands.push_back(phase);
|
||||
// Patch the yield with the updated counters.
|
||||
appendToYield(forOp, newYieldOperands);
|
||||
}
|
||||
|
||||
// Combine the current mask with the given predicate.
|
||||
static Value getPredMask(RewriterBase &rewriter, Type typeLike,
|
||||
Value currentMask, Value pred) {
|
||||
Type maskType = tt::getI1SameShape(typeLike);
|
||||
Location loc = pred.getLoc();
|
||||
Value mask = pred;
|
||||
if (maskType.isa<RankedTensorType>()) {
|
||||
mask = rewriter.create<tt::SplatOp>(loc, maskType, pred);
|
||||
}
|
||||
if (currentMask) {
|
||||
mask = rewriter.create<arith::AndIOp>(loc, mask, currentMask);
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
// Function to mask operations during scheduling.
|
||||
static Operation *predicateOp(RewriterBase &rewriter, Operation *op,
|
||||
Value pred) {
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
if (mlir::isMemoryEffectFree(op))
|
||||
return op;
|
||||
if (isa<ttg::AsyncCommitGroupOp>(op))
|
||||
return op;
|
||||
if (isa<ttg::AsyncWaitOp>(op))
|
||||
return op;
|
||||
if (auto insertOp = dyn_cast<ttg::InsertSliceAsyncOp>(op)) {
|
||||
rewriter.setInsertionPoint(insertOp);
|
||||
Value mask = getPredMask(rewriter, insertOp.getSrc().getType(),
|
||||
insertOp.getMask(), pred);
|
||||
insertOp.getMaskMutable().assign(mask);
|
||||
return op;
|
||||
}
|
||||
if (auto insertOp = dyn_cast<ttng::InsertSliceAsyncV2Op>(op)) {
|
||||
rewriter.setInsertionPoint(insertOp);
|
||||
Value mask = getPredMask(
|
||||
rewriter,
|
||||
insertOp.getSrc().getType().cast<tt::PointerType>().getPointeeType(),
|
||||
insertOp.getMask(), pred);
|
||||
insertOp.getMaskMutable().assign(mask);
|
||||
return op;
|
||||
}
|
||||
if (auto arriveOp = dyn_cast<ttng::MBarrierArriveOp>(op)) {
|
||||
rewriter.setInsertionPoint(arriveOp);
|
||||
Value mask = getPredMask(rewriter, rewriter.getIntegerType(1),
|
||||
arriveOp.getPred(), pred);
|
||||
arriveOp.getPredMutable().assign(mask);
|
||||
return op;
|
||||
}
|
||||
if (isa<ttng::MBarrierWaitOp>(op)) {
|
||||
return op;
|
||||
}
|
||||
if (auto loadOp = dyn_cast<tt::LoadOp>(op)) {
|
||||
rewriter.setInsertionPoint(loadOp);
|
||||
Value mask = getPredMask(rewriter, loadOp.getPtr().getType(),
|
||||
loadOp.getMask(), pred);
|
||||
loadOp.getMaskMutable().assign(mask);
|
||||
return op;
|
||||
}
|
||||
|
||||
assert("don't know how to predicate this op" && false);
|
||||
return op;
|
||||
}
|
||||
|
||||
// Annotation callback for the pipeline expander: rewrite the wait count of
// async_wait ops so they wait for the in-flight loads of one stage. `part`
// and `iteration` are part of the callback signature but unused here.
static void setWaitNum(Operation *op,
                       mlir::triton::PipeliningOption::PipelinerPart part,
                       unsigned iteration, unsigned numLoadsInStage) {
  auto waitOp = dyn_cast<ttg::AsyncWaitOp>(op);
  if (!waitOp)
    return;
  waitOp.setNum(numLoadsInStage);
}
|
||||
|
||||
/// Helper to recursively add dependencies to the same stage: inserts `op` and
/// every in-block transitive def of its operands into `deps`. When
/// `includeArg` is set, loop-carried block arguments are followed through the
/// yield (distance-1 dependencies). Ops already in `filter` are not visited.
static void addDep(Operation *op, DenseSet<Operation *> &deps,
                   bool includeArg = true,
                   DenseSet<Operation *> *filter = nullptr) {
  if (filter && filter->count(op))
    return;
  // Already visited — avoids exponential re-walks and infinite recursion.
  if (!deps.insert(op).second)
    return;
  for (Value operand : op->getOperands()) {
    Value v = operand;
    // `seen` breaks cycles of block arguments fed by each other via the yield.
    llvm::SmallDenseSet<Value> seen;
    while (auto arg = v.dyn_cast<BlockArgument>()) {
      if (!includeArg)
        break;
      if (!seen.insert(v).second)
        break;
      // Skip the induction variable (argument 0); map an iter_arg to the
      // value yielded for it on the previous iteration.
      if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
        auto yieldOp = op->getBlock()->getTerminator();
        v = yieldOp->getOperand(arg.getArgNumber() - 1);
        continue;
      }
      break;
    }
    Operation *defOp = v.getDefiningOp();
    // Only recurse on defs inside the same block (the loop body).
    if (defOp && defOp->getBlock() == op->getBlock()) {
      addDep(defOp, deps, includeArg, filter);
    }
  }
}
|
||||
|
||||
// Add operations to the shedule with the given stage based on the filter
|
||||
// function.
|
||||
static void addOps(scf::ForOp forOp, int stage,
|
||||
std::vector<std::pair<Operation *, unsigned>> &schedule,
|
||||
std::function<bool(Operation *)> filter) {
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
if (!filter(&op))
|
||||
continue;
|
||||
schedule.emplace_back(&op, stage);
|
||||
}
|
||||
}
|
||||
|
||||
// Create the schedule for a matmul loop. This is ad hoc based on how we know
// matmul loops should be pipelined and is not a generic scheduler.
// Stages: inserts (async copies) and their deps -> stage 0; some distance-1
// deps -> stage 1; extracts -> stage numStages-2; everything else ->
// stage numStages-1.
static std::vector<std::pair<Operation *, unsigned>>
createSchedule(scf::ForOp forOp, int numStages, bool prefetchExtract) {
  SmallVector<Operation *> insertOps;
  SmallVector<Operation *> extractOps;
  // Find the insert/extract ops that will go respectively in stage 0 and stage
  // `numStages - 2`. All the other operations will go in stage `numStages - 1`.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (isa<ttg::InsertSliceAsyncOp, ttg::AsyncCommitGroupOp,
            ttng::MBarrierArriveOp, ttng::InsertSliceAsyncV2Op>(op))
      insertOps.emplace_back(&op);
    if (prefetchExtract) {
      if (isa<ttg::ExtractSliceOp, ttg::AsyncWaitOp>(op))
        extractOps.emplace_back(&op);
    }
  }
  // Same-iteration dependencies of the inserts (block args not followed).
  DenseSet<Operation *> insertAndDeps;
  for (Operation *op : insertOps) {
    addDep(op, insertAndDeps, false);
  }

  // Find dependencies with distance of 1: defs reached through a
  // loop-carried argument of an op already in insertAndDeps.
  SmallVector<Operation *> distanceOneUsers;
  for (Operation *op : insertAndDeps) {
    for (Value operand : op->getOperands()) {
      if (auto arg = operand.dyn_cast<BlockArgument>()) {
        if (arg.getArgNumber() > 0 && arg.getOwner() == op->getBlock()) {
          auto yieldOp = op->getBlock()->getTerminator();
          Value v = yieldOp->getOperand(arg.getArgNumber() - 1);
          Operation *defOp = v.getDefiningOp();
          if (defOp && insertAndDeps.count(defOp) == 0) {
            distanceOneUsers.push_back(defOp);
          }
        }
      }
    }
  }
  // Schedule loads with a distance of 1 in stage 0
  for (Operation *op : distanceOneUsers) {
    if (isa<tt::LoadOp>(op)) {
      addDep(op, insertAndDeps, true);
    }
  }
  // For the rest of the ops we can move them into stage 1 so that they can be
  // closer to their uses.
  DenseSet<Operation *> stage1deps;
  for (Operation *op : distanceOneUsers) {
    if (!isa<tt::LoadOp>(op)) {
      addDep(op, stage1deps, true, &insertAndDeps);
    }
  }

  // Extracts and their remaining deps (those not already claimed by stage 0).
  DenseSet<Operation *> extractAndDeps;
  for (Operation *op : extractOps) {
    addDep(op, extractAndDeps, true, &insertAndDeps);
  }
  std::vector<std::pair<Operation *, unsigned>> schedule;
  // Schedule stage `numStage - 1` first.
  addOps(forOp, numStages - 1, schedule, [&](Operation *op) {
    return insertAndDeps.count(op) == 0 && stage1deps.count(op) == 0 &&
           extractAndDeps.count(op) == 0;
  });

  // Schedule some dependencies with distance of 1 into stage 1 to reduce
  // pressure.
  addOps(forOp, 1, schedule,
         [&](Operation *op) { return stage1deps.count(op); });

  // Then Schedule stage 0.
  addOps(forOp, 0, schedule,
         [&](Operation *op) { return insertAndDeps.count(op); });

  // Finally schedule the extract ops in stage `numStage - 2` so that they get
  // pre-fetched and play well with the prefetch pass.
  addOps(forOp, numStages - 2, schedule,
         [&](Operation *op) { return extractAndDeps.count(op); });
  return schedule;
}
|
||||
|
||||
/// Entry point of the matmul pipeliner pre-processing: collects pipelineable
/// loads, rewrites them as async copies with multi-buffering, and fills
/// `options` with the schedule/predication/annotation callbacks for the
/// pipeline expander. Returns false when nothing in the loop can be
/// pipelined (the loop is left untouched in that case).
bool mlir::triton::preProcessLoopAndGetSchedule(
    scf::ForOp &forOp, int numStages, mlir::triton::PipeliningOption &options) {
  // 1. First collect "interesting" operations with a stage where to schedule
  // them. This gives a coarse scheduling for the loop.
  SmallVector<LoadDotOperand> loads;
  bool hasMMAV3 = false;
  collectOpsToPipeline(forOp, loads, hasMMAV3);
  if (loads.empty())
    return false;
  // True when at least one load uses the cp.async path (not TMA).
  bool hasAsynCp = llvm::any_of(loads, [](LoadDotOperand &load) {
    return !isLoadFromTensorPtr(load.load);
  });
  // 2. Convert the loads into async loads and create the allocs.
  createAsynOps(forOp, loads, numStages, hasMMAV3);

  // 3. Create the final schedule for the kernel loop. This will dictate the
  // stages and order of operations to the pipeline expander.
  std::vector<std::pair<Operation *, unsigned>> schedule =
      createSchedule(forOp, numStages, /*prefetchExtract=*/!hasMMAV3);

  // 4. Fill out the pipeline options.
  options.getScheduleFn =
      [schedule](scf::ForOp forOp,
                 std::vector<std::pair<Operation *, unsigned>> &s) {
        s = std::move(schedule);
      };
  options.peelEpilogue = false;
  options.predicateFn = predicateOp;
  options.supportDynamicLoops = true;
  // Number of loads still in flight when a wait executes, used to patch the
  // `num` operand of async_wait ops cloned by the expander.
  unsigned numLoadsInStage = (numStages - 2) * loads.size();
  options.annotateFn =
      [numLoadsInStage](Operation *op,
                        mlir::triton::PipeliningOption::PipelinerPart part,
                        unsigned iteration) {
        return setWaitNum(op, part, iteration, numLoadsInStage);
      };

  if (hasAsynCp) {
    // Insert a wait 0 after the loop to drain outstanding async copies.
    OpBuilder builder(forOp);
    builder.setInsertionPointAfter(forOp);
    builder.create<ttg::AsyncWaitOp>(forOp.getLoc(), 0);
  }
  return true;
}
|
||||
|
||||
/// MMA V3 post-processing.
/// Return true when `dotOp`'s result is yielded by the loop and one of its
/// operands (transitively) reads the corresponding iter_arg, i.e. the dot
/// accumulates across iterations. On success `*firstUse` is set to the first
/// user of that iter_arg.
static bool selfDepend(tt::DotOp dotOp, scf::ForOp forOp,
                       Operation **firstUse) {
  // Recursively check whether `v` (or anything it is computed from) is the
  // iter_arg at position `argId`.
  std::function<bool(Value, int, scf::ForOp)> dependOn =
      [&dependOn](Value v, int argId, scf::ForOp forOp) {
        auto op = v.getDefiningOp();
        if (isa<BlockArgument>(v)) {
          auto iterArgs = forOp.getRegionIterArgs();
          auto iter = std::find(iterArgs.begin(), iterArgs.end(), v);
          if (iter != iterArgs.end())
            return std::distance(iterArgs.begin(), iter) == argId;
        } else {
          if (!op)
            return false;
          for (auto operand : op->getOperands()) {
            if (dependOn(operand, argId, forOp))
              return true;
          }
        }
        return false;
      };
  // Find which yield operand (if any) carries the dot result.
  auto result = dotOp.getResult();
  auto yieldOp = forOp.getBody()->getTerminator();
  int argIdx = -1;
  auto iter = std::find(yieldOp->getOperands().begin(),
                        yieldOp->getOperands().end(), result);
  if (iter != yieldOp->getOperands().end())
    argIdx = std::distance(yieldOp->getOperands().begin(), iter);
  if (argIdx == -1)
    return false;
  for (auto operand : dotOp.getOperands()) {
    if (dependOn(operand, argIdx, forOp)) {
      auto iterArgs = forOp.getRegionIterArgs();
      *firstUse = iterArgs[argIdx].use_begin().getUser();
      return true;
    }
  }
  return false;
}
|
||||
|
||||
// Drop the in-loop dot_wait when a wait with zero pendings already exists;
// that stronger wait makes this one redundant.
static void removeExtraWait(tt::nvidia_gpu::DotWaitOp dotWaitOp,
                            bool hasDotWait0) {
  if (!hasDotWait0)
    return;
  dotWaitOp->erase();
}
|
||||
|
||||
/// Convert eligible Hopper (MMAv3) tt.dot ops in the loop into async wgmma
/// dots (DotAsyncOp), insert the dot_wait ops that fence them inside and
/// after the loop, and clean up redundant waits.
void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
  Block *loop = forOp.getBody();
  // Index of the (possibly nested) block containing `op` within the loop
  // region; -1 for a null op.
  auto getBlockNumInFor = [](Operation *op, scf::ForOp forOp) {
    if (!op)
      return -1l;
    auto lastOp = op;
    while (op->getBlock()->getParentOp() != forOp) {
      lastOp = op;
      op = op->getBlock()->getParentOp();
    }
    return std::distance(lastOp->getBlock()->getParent()->begin(),
                         lastOp->getBlock()->getIterator());
  };
  /// XXX(Keren): Clean up the following duplicate code with checkDotOp
  /// dots to be pipelined
  bool hasSyncDot = false;
  bool hasDotWait0 = false;
  SmallVector<tt::DotOp> allDots;
  SmallVector<tt::DotOp> dots;
  SmallVector<unsigned> resultNeedSync;
  // First pass: record existing wait-0 ops and all dots in the loop body.
  for (Operation &op : *loop) {
    if (auto dotWaitOp = dyn_cast<tt::nvidia_gpu::DotWaitOp>(&op)) {
      auto attr = dotWaitOp->getAttrOfType<IntegerAttr>("pendings");
      auto pendingCount = attr.getInt();
      if (pendingCount == 0)
        hasDotWait0 = true;
    }
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      allDots.push_back(dotOp);
    }
  }
  // Second pass: select the dots that can safely become async.
  for (Operation &op : *loop) {
    if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
      auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
      if (auto resEnc = resTy.getEncoding().dyn_cast<ttg::MmaEncodingAttr>()) {
        if (resEnc && resEnc.isHopper()) {
          auto dot = dotOp.getResult();
          bool valid = true;

          // all users of dot should be scf.yield
          if (!dot.hasOneUse())
            valid = false;
          if (!isa<scf::YieldOp>(*dot.getUsers().begin()))
            valid = false;

          Operation *firstUse = nullptr;
          auto depend = selfDepend(dotOp, forOp, &firstUse);
          bool selfDirectDepend = (dotOp == firstUse);
          // If another dot executes before the first use of the accumulator,
          // a synchronous wait is implied.
          for (auto tempInAll : allDots) {
            auto iter = std::find(dots.begin(), dots.end(), tempInAll);
            if (iter != dots.end())
              continue;
            auto db = getBlockNumInFor(tempInAll, forOp);
            auto fb = getBlockNumInFor(firstUse, forOp);
            if (db < fb ||
                (db == fb && db >= 0 && tempInAll->isBeforeInBlock(firstUse)))
              hasSyncDot = true;
          }
          auto CArg = dotOp.getOperand(2);
          if (!(selfDirectDepend ||
                (depend && !selfDirectDepend && hasSyncDot)) ||
              !CArg.hasOneUse())
            valid = false;

          if (valid) {
            dots.push_back(dotOp);
            resultNeedSync.push_back(
                dotOp->getUses().begin()->getOperandNumber());
          }
        }
      }
    }
  }

  // Early stop: no need to continue if there is no valid dot in the loop.
  if (dots.empty())
    return;

  OpBuilder builder(forOp);
  // 0. insert dot_wait after the last dot in the loop as we implicitly
  // pipeline wgmma ops by one stage.
  // This is needed to prevent shared memory inputs to be overridden before
  // the operation is completed.
  // TODO: merge this with the rest of the pipelining transformation and look
  // at a better representation for async dots.
  tt::DotOp lastDot = dots.back();
  auto loc = lastDot.getLoc();
  builder.setInsertionPointAfter(lastDot);
  auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
      lastDot.getLoc(), lastDot.getResult(), dots.size());

  // 1. replace Dot with DotAsync
  for (size_t idx = 0; idx < dots.size(); ++idx) {
    tt::DotOp dotOp = dots[idx];
    builder.setInsertionPoint(dotOp);
    auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
        dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
        dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
    dotOp.replaceAllUsesWith(dotAsync.getResult());
    dotOp->erase();
  }

  hasDotWait0 = hasDotWait0 || hasSyncDot;

  // 2. If there's any outstanding DotAsyncOps, we need to wait for them.
  builder.setInsertionPointAfter(forOp);
  SmallVector<Value> waitOperands;
  for (unsigned i = 0; i < resultNeedSync.size(); ++i) {
    Value result = forOp->getResult(resultNeedSync[i]);
    if (result.use_empty())
      continue;
    waitOperands.push_back(result);
  }
  if (!waitOperands.empty()) {
    auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
                                                             waitOperands, 0);
    // Re-map results using the *filtered* index: `waitOperands` skipped
    // use_empty results above, so indexing dotWait results with the raw loop
    // index (as the original code did) reads past the op's results whenever
    // some dot result is unused after the loop.
    unsigned waitIdx = 0;
    for (unsigned i = 0; i < resultNeedSync.size(); ++i) {
      Value result = forOp->getResult(resultNeedSync[i]);
      if (result.use_empty())
        continue;
      result.replaceAllUsesExcept(dotWait.getResult(waitIdx++), dotWait);
    }
  }

  // 3. potentially remove redundant dot_wait after dot_async if having
  // multiple DotOp in the loop
  removeExtraWait(dotWait, hasDotWait0);
}
|
||||
704
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Normal file
704
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp
Normal file
@@ -0,0 +1,704 @@
|
||||
//===- LoopPipelining.cpp - Code to perform loop software pipelining-------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file implements loop software pipelining
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Fork of upstream pipeliner. This will be merged upstream once things are
|
||||
// stable. Modifications so far are:
|
||||
// -Bug fix for def with a distance of 1 scheduled in stage 0.
|
||||
// -Support dynamic loops and predicate operations in the prologue.
|
||||
// -Support for non-index type for induction variable.
|
||||
// -Support source with distance of 1 used multiple stages later.
|
||||
// -Fix bug when a value yield is used outside the loop and the value def is not
|
||||
// in the last stage. If we are not peeling the epilogue we need to remap the
|
||||
// output correctly.
|
||||
|
||||
#include "mlir/Dialect/Arith/IR/Arith.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
|
||||
#include "mlir/Dialect/SCF/Utils/Utils.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/Support/MathExtras.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
#include "PipelineExpander.h"
|
||||
|
||||
#define DEBUG_TYPE "triton-loop-pipelining"
|
||||
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
|
||||
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::scf;
|
||||
using namespace mlir::triton;
|
||||
|
||||
namespace {
|
||||
|
||||
/// Helper to keep internal information during pipelining transformation.
struct LoopPipelinerInternal {
  /// Coarse liverange information for ops used across stages.
  struct LiverangeInfo {
    // Highest stage index in which the value is still consumed.
    unsigned lastUseStage = 0;
    // Stage index in which the value is produced.
    unsigned defStage = 0;
  };

protected:
  // The scf.for loop being pipelined.
  ForOp forOp;
  // Highest stage number appearing in the user-provided schedule.
  unsigned maxStage = 0;
  // Stage assigned to each loop-body operation.
  DenseMap<Operation *, unsigned> stages;
  // Loop-body operations in the order they must appear in the pipelined loop.
  std::vector<Operation *> opOrder;
  Value ub;
  Value lb;
  Value step;
  // True when the trip count is not a compile-time constant; prologue (and
  // non-peeled epilogue) operations must then be predicated.
  bool dynamicLoop;
  triton::PipeliningOption::AnnotationlFnType annotateFn = nullptr;
  bool peelEpilogue;
  triton::PipeliningOption::PredicateOpFnType predicateFn = nullptr;

  // When peeling the kernel we generate several versions of each value for
  // different stages of the prologue. This map tracks the mapping between
  // original Values in the loop and the different versions
  // peeled from the loop.
  DenseMap<Value, llvm::SmallVector<Value>> valueMapping;

  /// Assign a value to `valueMapping`, this means `el` represents the version
  /// `idx` of `key` in the epilogue.
  void setValueMapping(Value key, Value el, int64_t idx);

  /// Return the operation defining `value` together with the loop-carried
  /// distance (0 for a direct def, 1 when reached through an iter-arg).
  std::pair<Operation *, int64_t> getDefiningOpAndDistance(Value value);

public:
  /// Initialize the information for the given `op`, return true if it
  /// satisfies the pre-condition to apply pipelining.
  bool initializeLoopInfo(ForOp op, const triton::PipeliningOption &options);
  /// Emits the prologue, this creates `maxStage - 1` part which will contain
  /// operations from stages [0; i], where i is the part index.
  void emitPrologue(RewriterBase &rewriter);
  /// Gather liverange information for Values that are used in a different stage
  /// than its definition.
  llvm::MapVector<Value, LiverangeInfo> analyzeCrossStageValues();
  scf::ForOp createKernelLoop(
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      RewriterBase &rewriter,
      llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap);
  /// Emits the pipelined kernel. This clones loop operations following user
  /// order and remaps operands defined in a different stage as their use.
  LogicalResult createKernel(
      scf::ForOp newForOp,
      const llvm::MapVector<Value, LiverangeInfo> &crossStageValues,
      const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
      RewriterBase &rewriter);
  /// Emits the epilogue, this creates `maxStage - 1` part which will contain
  /// operations from stages [i; maxStage], where i is the part index.
  llvm::SmallVector<Value> emitEpilogue(RewriterBase &rewriter);
};
|
||||
|
||||
// Validate the loop and the user schedule, and cache bounds/stage information.
// Returns false (after emitting a diagnostic where appropriate) when any
// pipelining pre-condition is violated.
bool LoopPipelinerInternal::initializeLoopInfo(
    ForOp op, const triton::PipeliningOption &options) {
  LDBG("Start initializeLoopInfo");
  forOp = op;
  ub = forOp.getUpperBound();
  lb = forOp.getLowerBound();
  step = forOp.getStep();

  // Classify the loop: constant bounds allow a static trip-count check,
  // otherwise we can only proceed when dynamic loops are supported.
  dynamicLoop = true;
  auto upperBoundCst = ub.getDefiningOp<arith::ConstantIndexOp>();
  auto lowerBoundCst = lb.getDefiningOp<arith::ConstantIndexOp>();
  auto stepCst = step.getDefiningOp<arith::ConstantIndexOp>();
  if (!upperBoundCst || !lowerBoundCst || !stepCst) {
    if (!options.supportDynamicLoops) {
      LDBG("--dynamic loop not supported -> BAIL");
      return false;
    }
  } else {
    int64_t ubImm = upperBoundCst.value();
    int64_t lbImm = lowerBoundCst.value();
    int64_t stepImm = stepCst.value();
    int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
    // NOTE(review): `maxStage` is still 0 at this point (it is only computed
    // from the schedule further down), so this check effectively tests
    // `numIteration >= 1` rather than "enough iterations for all stages" —
    // confirm the intended ordering against the upstream MLIR pipeliner.
    if (numIteration > maxStage) {
      dynamicLoop = false;
    } else if (!options.supportDynamicLoops) {
      LDBG("--fewer loop iterations than pipeline stages -> BAIL");
      return false;
    }
  }
  peelEpilogue = options.peelEpilogue;
  predicateFn = options.predicateFn;
  // A predicate callback is mandatory whenever some operations must be
  // predicated (non-peeled epilogue or dynamic trip count).
  if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) {
    LDBG("--no epilogue or predicate set -> BAIL");
    return false;
  }
  std::vector<std::pair<Operation *, unsigned>> schedule;
  options.getScheduleFn(forOp, schedule);
  if (schedule.empty()) {
    LDBG("--empty schedule -> BAIL");
    return false;
  }

  // Record the stage of each op and the overall number of stages.
  opOrder.reserve(schedule.size());
  for (auto &opSchedule : schedule) {
    maxStage = std::max(maxStage, opSchedule.second);
    stages[opSchedule.first] = opSchedule.second;
    opOrder.push_back(opSchedule.first);
  }

  // All operations need to have a stage.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!stages.contains(&op)) {
      op.emitOpError("not assigned a pipeline stage");
      LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL");
      return false;
    }
  }

  // Currently, we do not support assigning stages to ops in nested regions. The
  // block of all operations assigned a stage should be the single `scf.for`
  // body block.
  for (const auto &[op, stageNum] : stages) {
    (void)stageNum;
    if (op == forOp.getBody()->getTerminator()) {
      op->emitError("terminator should not be assigned a stage");
      LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL");
      return false;
    }
    if (op->getBlock() != forOp.getBody()) {
      op->emitOpError("the owning Block of all operations assigned a stage "
                      "should be the loop body block");
      LDBG("--the owning Block of all operations assigned a stage "
           "should be the loop body block: "
           << *op << " -> BAIL");
      return false;
    }
  }

  // Only support loop carried dependency with a distance of 1. This means the
  // source of all the scf.yield operands needs to be defined by operations in
  // the loop.
  if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
                   [this](Value operand) {
                     Operation *def = operand.getDefiningOp();
                     return !def || !stages.contains(def);
                   })) {
    LDBG("--only support loop carried dependency with a distance of 1 -> BAIL");
    return false;
  }
  annotateFn = options.annotateFn;
  return true;
}
|
||||
|
||||
/// Clone `op` and call `callback` on the cloned op's oeprands as well as any
|
||||
/// operands of nested ops that:
|
||||
/// 1) aren't defined within the new op or
|
||||
/// 2) are block arguments.
|
||||
static Operation *
|
||||
cloneAndUpdateOperands(RewriterBase &rewriter, Operation *op,
|
||||
function_ref<void(OpOperand *newOperand)> callback) {
|
||||
Operation *clone = rewriter.clone(*op);
|
||||
for (OpOperand &operand : clone->getOpOperands())
|
||||
callback(&operand);
|
||||
clone->walk([&](Operation *nested) {
|
||||
for (OpOperand &operand : nested->getOpOperands()) {
|
||||
Operation *def = operand.get().getDefiningOp();
|
||||
if ((def && !clone->isAncestor(def)) || isa<BlockArgument>(operand.get()))
|
||||
callback(&operand);
|
||||
}
|
||||
});
|
||||
return clone;
|
||||
}
|
||||
|
||||
// Peel `maxStage` copies of the loop body before the kernel loop. Part `i`
// contains clones of every op whose stage is <= i, and each cloned result is
// recorded in `valueMapping` as a new version of the original value.
void LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
  // Initialize the iteration arguments to the loop initial values.
  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
    setValueMapping(arg, operand.get(), 0);
  }
  auto yield = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  Location loc = forOp.getLoc();
  for (int64_t i = 0; i < maxStage; i++) {
    // For dynamic loops the peeled iteration may not exist; build a guard
    // predicate for it.
    Value predicate;
    if (dynamicLoop) {
      Type t = ub.getType();
      // pred = ub > lb + (i * step)
      Value iv = rewriter.create<arith::AddIOp>(
          loc, lb,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, i))));
      predicate = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
                                                 iv, ub);
    }

    // special handling for induction variable as the increment is implicit.
    // iv = lb + i * step
    Type t = lb.getType();
    Value iv = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(loc,
                                               rewriter.getIntegerAttr(t, i))));
    setValueMapping(forOp.getInductionVar(), iv, i);
    for (Operation *op : opOrder) {
      // Only ops that have reached their stage by part `i` are peeled.
      if (stages[op] > i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            // Remap each operand to the version produced `i - stages[op]`
            // parts ago (i.e. the same relative iteration).
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[i - stages[op]];
              newOperand->set(replacement);
            }
          });
      if (predicate) {
        newOp = predicateFn(rewriter, newOp, predicate);
        assert(newOp && "failed to predicate op.");
      }
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Prologue, i);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        i - stages[op]);
        // If the value is a loop carried dependency update the loop argument
        // mapping.
        for (OpOperand &operand : yield->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          newOp->getResult(destId), i - stages[op] + 1);
        }
      }
    }
  }
}
|
||||
|
||||
/// Resolve `value` to its defining operation and the loop-carried distance:
/// 0 for a value defined directly in the body, 1 for a value that reaches the
/// use through an iter-arg of the loop. Returns {nullptr, 0} when the value
/// has no in-loop defining op (external value or the induction variable).
std::pair<Operation *, int64_t>
LoopPipelinerInternal::getDefiningOpAndDistance(Value value) {
  int64_t distance = 0;
  Value current = value;
  if (auto blockArg = dyn_cast<BlockArgument>(current)) {
    // Only iter-args of this loop can carry a dependency; argument 0 is the
    // induction variable and is ignored here.
    if (blockArg.getOwner() != forOp.getBody() || blockArg.getArgNumber() == 0)
      return {nullptr, 0};
    // Follow the loop-carried value through the yield: one extra iteration
    // of distance.
    ++distance;
    current = forOp.getBody()->getTerminator()->getOperand(
        blockArg.getArgNumber() - 1);
  }
  if (Operation *producer = current.getDefiningOp())
    return {producer, distance};
  return {nullptr, 0};
}
|
||||
|
||||
// Collect every value used in a later stage than the one producing it.
// For each such value we record its def stage and the last stage using it;
// this liverange determines how many loop-carried versions the kernel loop
// needs.
llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
LoopPipelinerInternal::analyzeCrossStageValues() {
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo> crossStageValues;
  for (Operation *op : opOrder) {
    unsigned stage = stages[op];

    auto analyzeOperand = [&](OpOperand &operand) {
      auto [def, distance] = getDefiningOpAndDistance(operand.get());
      if (!def)
        return;
      auto defStage = stages.find(def);
      // Same-stage uses (directly or after accounting for the loop-carried
      // distance) need no extra versions.
      if (defStage == stages.end() || defStage->second == stage ||
          defStage->second == stage + distance)
        return;
      assert(stage > defStage->second);
      LiverangeInfo &info = crossStageValues[operand.get()];
      info.defStage = defStage->second;
      // Widen the liverange to cover this (possibly later) use.
      info.lastUseStage = std::max(info.lastUseStage, stage);
    };

    // Analyze the op's own operands plus any values captured from above by
    // its nested regions.
    for (OpOperand &operand : op->getOpOperands())
      analyzeOperand(operand);
    visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand *operand) {
      analyzeOperand(*operand);
    });
  }
  return crossStageValues;
}
|
||||
|
||||
// Build the (empty) kernel scf.for: its iter-args are the original loop's
// iter-args plus one extra arg per cross-stage value version, initialized
// with the versions produced by the prologue. `loopArgMap` records, for each
// (value, version-distance) pair, the index of the corresponding iter-arg.
scf::ForOp LoopPipelinerInternal::createKernelLoop(
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    RewriterBase &rewriter,
    llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap) {
  // Creates the list of initial values associated to values used across
  // stages. The initial values come from the prologue created above.
  // Keep track of the kernel argument associated to each version of the
  // values passed to the kernel.
  llvm::SmallVector<Value> newLoopArg;
  // For existing loop argument initialize them with the right version from the
  // prologue.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    Value valueVersion = valueMapping[forOp.getRegionIterArgs()[retVal.index()]]
                                     [maxStage - defStage];
    assert(valueVersion);
    newLoopArg.push_back(valueVersion);
  }
  // One extra iter-arg per live version of each cross-stage value.
  for (auto escape : crossStageValues) {
    LiverangeInfo &info = escape.second;
    Value value = escape.first;
    for (unsigned stageIdx = 0; stageIdx < info.lastUseStage - info.defStage;
         stageIdx++) {
      Value valueVersion =
          valueMapping[value][maxStage - info.lastUseStage + stageIdx];
      assert(valueVersion);
      newLoopArg.push_back(valueVersion);
      loopArgMap[std::make_pair(value, info.lastUseStage - info.defStage -
                                           stageIdx)] = newLoopArg.size() - 1;
    }
  }

  // Create the new kernel loop. When we peel the epilogue we need to peel
  // `numStages - 1` iterations. Then we adjust the upper bound to remove those
  // iterations.
  Value newUb = forOp.getUpperBound();
  if (peelEpilogue) {
    Type t = ub.getType();
    Location loc = forOp.getLoc();
    // newUb = ub - maxStage * step
    newUb = rewriter.create<arith::AddIOp>(
        loc, ub,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::ConstantOp>(
                loc, rewriter.getIntegerAttr(t, -maxStage))));
  }
  auto newForOp =
      rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
                                  forOp.getStep(), newLoopArg);
  // When there are no iter args, the loop body terminator will be created.
  // Since we always create it below, remove the terminator if it was created.
  if (!newForOp.getBody()->empty())
    rewriter.eraseOp(newForOp.getBody()->getTerminator());
  return newForOp;
}
|
||||
|
||||
// Populate the kernel loop body: clone every scheduled op in user order into
// `newForOp`, remapping operands that cross stages to the appropriate
// iter-arg, predicating stages when the epilogue is not peeled, and building
// the terminating yield. Returns failure only if `predicateFn` cannot
// predicate an op.
LogicalResult LoopPipelinerInternal::createKernel(
    scf::ForOp newForOp,
    const llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
        &crossStageValues,
    const llvm::DenseMap<std::pair<Value, unsigned>, unsigned> &loopArgMap,
    RewriterBase &rewriter) {
  // `valueMapping` is repurposed below to track versions needed by the
  // epilogue.
  valueMapping.clear();

  // Create the kernel, we clone instruction based on the order given by
  // user and remap operands coming from a previous stages.
  rewriter.setInsertionPoint(newForOp.getBody(), newForOp.getBody()->begin());
  IRMapping mapping;
  mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs())) {
    mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
  }
  SmallVector<Value> predicates(maxStage + 1, nullptr);
  if (!peelEpilogue) {
    // Create a predicate for each stage except the last stage.
    Location loc = newForOp.getLoc();
    Type t = ub.getType();
    for (unsigned i = 0; i < maxStage; i++) {
      // c = ub - (maxStage - i) * step
      Value c = rewriter.create<arith::AddIOp>(
          loc, ub,
          rewriter.create<arith::MulIOp>(
              loc, step,
              rewriter.create<arith::ConstantOp>(
                  loc, rewriter.getIntegerAttr(t, -int64_t(maxStage - i)))));

      Value pred = rewriter.create<arith::CmpIOp>(
          newForOp.getLoc(), arith::CmpIPredicate::slt,
          newForOp.getInductionVar(), c);
      predicates[i] = pred;
    }
  }
  for (Operation *op : opOrder) {
    int64_t useStage = stages[op];
    auto *newOp = rewriter.clone(*op, mapping);
    SmallVector<OpOperand *> operands;
    // Collect all the operands for the cloned op and its nested ops.
    op->walk([&operands](Operation *nestedOp) {
      for (OpOperand &operand : nestedOp->getOpOperands()) {
        operands.push_back(&operand);
      }
    });
    for (OpOperand *operand : operands) {
      Operation *nestedNewOp = mapping.lookup(operand->getOwner());
      // Special case for the induction variable uses. We replace it with a
      // version incremented based on the stage where it is used.
      if (operand->get() == forOp.getInductionVar()) {
        rewriter.setInsertionPoint(newOp);

        // offset = (maxStage - stages[op]) * step
        Type t = step.getType();
        Value offset = rewriter.create<arith::MulIOp>(
            forOp.getLoc(), step,
            rewriter.create<arith::ConstantOp>(
                forOp.getLoc(),
                rewriter.getIntegerAttr(t, maxStage - stages[op])));
        Value iv = rewriter.create<arith::AddIOp>(
            forOp.getLoc(), newForOp.getInductionVar(), offset);
        nestedNewOp->setOperand(operand->getOperandNumber(), iv);
        rewriter.setInsertionPointAfter(newOp);
        continue;
      }
      Value source = operand->get();
      // Trace iter-arg operands back to the value yielded for them.
      auto arg = dyn_cast<BlockArgument>(source);
      if (arg && arg.getOwner() == forOp.getBody()) {
        Value ret = forOp.getBody()->getTerminator()->getOperand(
            arg.getArgNumber() - 1);
        Operation *dep = ret.getDefiningOp();
        if (!dep)
          continue;
        auto stageDep = stages.find(dep);
        if (stageDep == stages.end() || stageDep->second == useStage)
          continue;
        // If the value is a loop carried value coming from stage N + 1 remap,
        // it will become a direct use.
        if (stageDep->second == useStage + 1) {
          nestedNewOp->setOperand(operand->getOperandNumber(),
                                  mapping.lookupOrDefault(ret));
          continue;
        }
        source = ret;
      }
      // For operands defined in a previous stage we need to remap it to use
      // the correct region argument. We look for the right version of the
      // Value based on the stage where it is used.
      Operation *def = source.getDefiningOp();
      if (!def)
        continue;
      auto stageDef = stages.find(def);
      if (stageDef == stages.end() || stageDef->second == useStage)
        continue;
      auto remap = loopArgMap.find(
          std::make_pair(operand->get(), useStage - stageDef->second));
      assert(remap != loopArgMap.end());
      nestedNewOp->setOperand(operand->getOperandNumber(),
                              newForOp.getRegionIterArgs()[remap->second]);
    }

    if (predicates[useStage]) {
      newOp = predicateFn(rewriter, newOp, predicates[useStage]);
      if (!newOp)
        return failure();
      // Remap the results to the new predicated one.
      for (auto values : llvm::zip(op->getResults(), newOp->getResults()))
        mapping.map(std::get<0>(values), std::get<1>(values));
    }
    rewriter.setInsertionPointAfter(newOp);
    if (annotateFn)
      annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Kernel, 0);
  }

  // Collect the Values that need to be returned by the forOp. For each
  // value we need to have `LastUseStage - DefStage` number of versions
  // returned.
  // We create a mapping between original values and the associated loop
  // returned values that will be needed by the epilogue.
  llvm::SmallVector<Value> yieldOperands;
  for (OpOperand &yielOperand :
       forOp.getBody()->getTerminator()->getOpOperands()) {
    Value source = mapping.lookupOrDefault(yielOperand.get());
    // When we don't peel the epilogue the yield value is used outside the loop
    // we need to make sure we return the version from numStages - defStage.
    if (!peelEpilogue &&
        !forOp.getResult(yielOperand.getOperandNumber()).use_empty()) {
      auto [def, distance] = getDefiningOpAndDistance(yielOperand.get());
      if (def) {
        auto defStage = stages.find(def);
        if (defStage != stages.end()) {
          Value pred = predicates[defStage->second];
          if (pred) {
            // Guard the yielded value: when the def stage was predicated off,
            // forward the incoming iter-arg instead.
            source = rewriter.create<arith::SelectOp>(
                pred.getLoc(), pred, source,
                newForOp.getBody()
                    ->getArguments()[yielOperand.getOperandNumber() + 1]);
          }
        }
      }
    }
    yieldOperands.push_back(source);
  }

  for (auto &it : crossStageValues) {
    int64_t version = maxStage - it.second.lastUseStage + 1;
    unsigned numVersionReturned = it.second.lastUseStage - it.second.defStage;
    // add the original version to yield ops.
    // If there is a live range spanning across more than 2 stages we need to
    // add extra arg.
    for (unsigned i = 1; i < numVersionReturned; i++) {
      setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                      version++);
      // NOTE(review): the iter-arg index here adds both `+ 1` and
      // `getNumInductionVars()` — verify against the upstream pipeliner that
      // this does not skip one argument.
      yieldOperands.push_back(
          newForOp.getBody()->getArguments()[yieldOperands.size() + 1 +
                                             newForOp.getNumInductionVars()]);
    }
    setValueMapping(it.first, newForOp->getResult(yieldOperands.size()),
                    version++);
    yieldOperands.push_back(mapping.lookupOrDefault(it.first));
  }
  // Map the yield operand to the forOp returned value.
  for (const auto &retVal :
       llvm::enumerate(forOp.getBody()->getTerminator()->getOperands())) {
    Operation *def = retVal.value().getDefiningOp();
    assert(def && "Only support loop carried dependencies of distance 1");
    unsigned defStage = stages[def];
    if (defStage > 0) {
      setValueMapping(forOp.getRegionIterArgs()[retVal.index()],
                      newForOp->getResult(retVal.index()),
                      maxStage - defStage + 1);
    }
  }
  rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
  return success();
}
|
||||
|
||||
// Peel `maxStage` copies of the loop body after the kernel loop, cloning in
// part `i` every op whose stage is >= i. Returns, for each original loop
// result, the value that should replace it after the loop is erased.
llvm::SmallVector<Value>
LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter) {
  llvm::SmallVector<Value> returnValues(forOp->getNumResults());
  // Emit different versions of the induction variable. They will be
  // removed by dead code if not used.
  for (int64_t i = 0; i < maxStage; i++) {
    Location loc = forOp.getLoc();
    Type t = lb.getType();
    Value minusOne =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -1));
    // number of iterations = ((ub - 1) - lb) / step
    Value totlaNumIteration = rewriter.create<arith::DivUIOp>(
        loc,
        rewriter.create<arith::SubIOp>(
            loc, rewriter.create<arith::AddIOp>(loc, ub, minusOne), lb),
        step);
    // newLastIter = lb + step * ((((ub - 1) - lb) / step) - i)
    Value minusI =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(t, -i));
    Value newlastIter = rewriter.create<arith::AddIOp>(
        loc, lb,
        rewriter.create<arith::MulIOp>(
            loc, step,
            rewriter.create<arith::AddIOp>(loc, totlaNumIteration, minusI)));
    setValueMapping(forOp.getInductionVar(), newlastIter, maxStage - i);
  }
  // Emit `maxStage - 1` epilogue part that includes operations from stages
  // [i; maxStage].
  for (int64_t i = 1; i <= maxStage; i++) {
    for (Operation *op : opOrder) {
      if (stages[op] < i)
        continue;
      Operation *newOp =
          cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) {
            // Remap operands to the version matching this epilogue part.
            auto it = valueMapping.find(newOperand->get());
            if (it != valueMapping.end()) {
              Value replacement = it->second[maxStage - stages[op] + i];
              newOperand->set(replacement);
            }
          });
      if (annotateFn)
        annotateFn(newOp, triton::PipeliningOption::PipelinerPart::Epilogue,
                   i - 1);
      for (unsigned destId : llvm::seq(unsigned(0), op->getNumResults())) {
        setValueMapping(op->getResult(destId), newOp->getResult(destId),
                        maxStage - stages[op] + i);
        // If the value is a loop carried dependency update the loop argument
        // mapping and keep track of the last version to replace the original
        // forOp uses.
        for (OpOperand &operand :
             forOp.getBody()->getTerminator()->getOpOperands()) {
          if (operand.get() != op->getResult(destId))
            continue;
          unsigned version = maxStage - stages[op] + i + 1;
          // If the version is greater than maxStage it means it maps to the
          // original forOp returned value.
          if (version > maxStage) {
            returnValues[operand.getOperandNumber()] = newOp->getResult(destId);
            continue;
          }
          setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
                          newOp->getResult(destId), version);
        }
      }
    }
  }
  return returnValues;
}
|
||||
|
||||
// Record `el` as version `idx` of `key`, lazily creating the per-value
// version vector. The vector must be able to hold one slot per stage plus
// the original value, i.e. `maxStage + 1` entries.
void LoopPipelinerInternal::setValueMapping(Value key, Value el, int64_t idx) {
  auto [slot, wasInserted] =
      valueMapping.try_emplace(key, llvm::SmallVector<Value>(maxStage + 1));
  (void)wasInserted;
  slot->second[idx] = el;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Entry point: apply the mechanical pipelining transformation to `forOp`
// following the schedule provided through `options`. Returns the new kernel
// loop, or failure; `modifiedIR`, when provided, tells the caller whether the
// IR was already touched before a failure.
FailureOr<ForOp>
mlir::triton::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
                              const triton::PipeliningOption &options,
                              bool *modifiedIR) {
  if (modifiedIR)
    *modifiedIR = false;
  LoopPipelinerInternal pipeliner;
  // Bailing out here leaves the IR untouched.
  if (!pipeliner.initializeLoopInfo(forOp, options))
    return failure();

  if (modifiedIR)
    *modifiedIR = true;

  // 1. Emit prologue.
  pipeliner.emitPrologue(rewriter);

  // 2. Track values used across stages. When a value cross stages it will
  // need to be passed as loop iteration arguments.
  // We first collect the values that are used in a different stage than where
  // they are defined.
  llvm::MapVector<Value, LoopPipelinerInternal::LiverangeInfo>
      crossStageValues = pipeliner.analyzeCrossStageValues();

  // Mapping between original loop values used cross stage and the block
  // arguments associated after pipelining. A Value may map to several
  // arguments if its liverange spans across more than 2 stages.
  llvm::DenseMap<std::pair<Value, unsigned>, unsigned> loopArgMap;
  // 3. Create the new kernel loop and return the block arguments mapping.
  ForOp newForOp =
      pipeliner.createKernelLoop(crossStageValues, rewriter, loopArgMap);
  // Create the kernel block, order ops based on user choice and remap
  // operands.
  if (failed(pipeliner.createKernel(newForOp, crossStageValues, loopArgMap,
                                    rewriter)))
    return failure();

  // By default the original results map 1:1 onto the kernel loop's leading
  // results; a peeled epilogue overrides them below.
  llvm::SmallVector<Value> returnValues =
      newForOp.getResults().take_front(forOp->getNumResults());
  if (options.peelEpilogue) {
    // 4. Emit the epilogue after the new forOp.
    rewriter.setInsertionPointAfter(newForOp);
    returnValues = pipeliner.emitEpilogue(rewriter);
  }
  // 5. Erase the original loop and replace the uses with the epilogue output.
  if (forOp->getNumResults() > 0)
    rewriter.replaceOp(forOp, returnValues);
  else
    rewriter.eraseOp(forOp);

  return newForOp;
}
|
||||
101
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h
Normal file
101
lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.h
Normal file
@@ -0,0 +1,101 @@
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
#ifndef TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
#define TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
|
||||
// This is a fork of upstream pipeline transformation. This will be merged back
|
||||
// upstream once we have a stable solution.
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
class RewriterBase;
|
||||
class Operation;
|
||||
class Value;
|
||||
|
||||
namespace scf {
|
||||
class ForOp;
|
||||
}
|
||||
|
||||
namespace triton {
|
||||
|
||||
/// Options to dictate how loops should be pipelined.
struct PipeliningOption {
  /// Lambda returning all the operations in the forOp, with their stage, in
  /// the order picked for the pipelined loop.
  using GetScheduleFnType = std::function<void(
      scf::ForOp, std::vector<std::pair<Operation *, unsigned>> &)>;
  GetScheduleFnType getScheduleFn = nullptr;
  /// Which part of the pipelined loop an op was emitted into.
  enum class PipelinerPart {
    Prologue,
    Kernel,
    Epilogue,
  };
  /// Lambda called by the pipeliner to allow the user to annotate the IR while
  /// it is generated.
  /// The callback passes the operation created along with the part of the
  /// pipeline and the iteration index. The iteration index is always 0 for the
  /// kernel. For the prologue and epilogue, it corresponds to the iteration
  /// peeled out of the loop in the range [0, maxStage[.
  using AnnotationlFnType =
      std::function<void(Operation *, PipelinerPart, unsigned)>;
  AnnotationlFnType annotateFn = nullptr;

  /// Control whether the epilogue should be peeled out of the loop or
  /// operations should be predicated to skip the early stages in the last loop
  /// iterations. If the epilogue is predicated; the user needs to provide a
  /// lambda to generate the predicated version of operations.
  bool peelEpilogue = true;

  /// Control whether the transformation checks that the number of iterations
  /// is greater or equal to the number of stages and skips the transformation
  /// if this is not the case. If the loop is dynamic and this is set to true
  /// the pipeliner will have to predicate operations in the prologue/epilogue.
  bool supportDynamicLoops = false;

  // Callback to predicate operations when the prologue or epilogue are not
  // peeled. This takes the original operation, an i1 predicate value and the
  // pattern rewriter. It is expected to replace the given operation with
  // the predicated equivalent and return it, or return nullptr if the
  // predication is impossible. In the latter case, pipelining will fail and
  // may leave IR in a partially transformed state.
  using PredicateOpFnType =
      std::function<Operation *(RewriterBase &, Operation *, Value)>;
  PredicateOpFnType predicateFn = nullptr;

  // TODO: add option to decide if the prologue should be peeled.
};
|
||||
|
||||
/// Generate a pipelined version of the scf.for loop based on the schedule given
|
||||
/// as option. This applies the mechanical transformation of changing the loop
|
||||
/// and generating the prologue/epilogue for the pipelining and doesn't make any
|
||||
/// decision regarding the schedule.
|
||||
/// Based on the options the loop is split into several stages.
|
||||
/// The transformation assumes that the scheduling given by user is valid.
|
||||
/// For example if we break a loop into 3 stages named S0, S1, S2 we would
|
||||
/// generate the following code with the number in parenthesis as the iteration
|
||||
/// index:
|
||||
///
|
||||
/// S0(0) // Prologue
|
||||
/// S0(1) S1(0) // Prologue
|
||||
/// scf.for %I = %C0 to %N - 2 {
|
||||
/// S0(I+2) S1(I+1) S2(I) // Pipelined kernel
|
||||
/// }
|
||||
/// S1(N) S2(N-1) // Epilogue
|
||||
/// S2(N) // Epilogue
|
||||
///
|
||||
/// If `modifiedIR` is provided, it will be set to a value that indicates
|
||||
/// whether pipelining modified the IR before failing, signaling to the caller
|
||||
/// whether they can proceed with different transformations.
|
||||
FailureOr<scf::ForOp> pipelineForLoop(RewriterBase &rewriter, scf::ForOp forOp,
|
||||
const PipeliningOption &options,
|
||||
bool *modifiedIR = nullptr);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_PIPELINE_H_
|
||||
27
lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h
Normal file
27
lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.h
Normal file
@@ -0,0 +1,27 @@
|
||||
#ifndef TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
#define TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
|
||||
#include "PipelineExpander.h"
|
||||
#include "mlir/Dialect/SCF/IR/SCF.h"
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include <vector>
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
/// This fill out the pipelining options including schedule and annotations for
|
||||
/// wait ops. This also does pre-processing by converting some of the loads into
|
||||
/// async loads so that the IR is ready to be pipelined.
|
||||
bool preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
|
||||
mlir::triton::PipeliningOption &options);
|
||||
|
||||
/// This does post-processing on the pipelined loop to try to pipeline wgmma
|
||||
/// ops.
|
||||
// TODO: this should be included as part of the pipeline but currently the wgmma
|
||||
// wait modeling is problematic.
|
||||
void asyncLaunchDots(scf::ForOp forOp);
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
#endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
|
||||
@@ -0,0 +1,88 @@
|
||||
#include "PipelineExpander.h"
|
||||
#include "Schedule.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/Support/Debug.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// This file will create a schedule that will be handed over to the pipeline
|
||||
// expander.
|
||||
// Software pipeliners are usually separated into two pieces, one that create a
|
||||
// modulo schedule and an expander that rewrites the loop and emits a prologue
|
||||
// and epilogue. This pass first calls a helper that will pre-process the IR
|
||||
// to create async operations and create a modulo schedule. Then we call the
|
||||
// expander to generate the prologue and new loop.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
|
||||
|
||||
static void pipelineLoop(scf::ForOp forOp, int numStages) {
|
||||
mlir::triton::PipeliningOption options;
|
||||
// Skip loop with distance > 1 for now.
|
||||
// TODO: relax the constraint in the expander.
|
||||
if (llvm::any_of(forOp.getBody()->getTerminator()->getOperands(),
|
||||
[](Value operand) {
|
||||
Operation *def = operand.getDefiningOp();
|
||||
return !def;
|
||||
}))
|
||||
return;
|
||||
|
||||
bool foundSchedule = false;
|
||||
foundSchedule = preProcessLoopAndGetSchedule(forOp, numStages, options);
|
||||
|
||||
// TODO: add more pipelines strategy.
|
||||
if (!foundSchedule)
|
||||
return;
|
||||
|
||||
IRRewriter rewriter(forOp->getContext());
|
||||
rewriter.setInsertionPoint(forOp);
|
||||
FailureOr<scf::ForOp> newForOp =
|
||||
mlir::triton::pipelineForLoop(rewriter, forOp, options);
|
||||
|
||||
if (succeeded(newForOp))
|
||||
mlir::triton::asyncLaunchDots(newForOp.value());
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
PipelinePass() = default;
|
||||
PipelinePass(int numStages, int numWarps, int numCTAs,
|
||||
int computeCapability) {
|
||||
this->numStages = numStages;
|
||||
this->numWarps = numWarps;
|
||||
this->numCTAs = numCTAs;
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
if (this->numStages <= 1)
|
||||
return;
|
||||
SmallVector<scf::ForOp> loops;
|
||||
getOperation()->walk([&](scf::ForOp forOp) { loops.push_back(forOp); });
|
||||
for (scf::ForOp forOp : loops) {
|
||||
pipelineLoop(forOp, numStages);
|
||||
}
|
||||
}
|
||||
};
|
||||
} // anonymous namespace
|
||||
|
||||
std::unique_ptr<Pass> mlir::createTritonGPUPipelinePass(int numStages,
|
||||
int numWarps,
|
||||
int numCTAs,
|
||||
int computeCapability) {
|
||||
return std::make_unique<PipelinePass>(numStages, numWarps, numCTAs,
|
||||
computeCapability);
|
||||
}
|
||||
@@ -332,9 +332,6 @@ SmallVector<Value> LayoutPropagation::propagateToUsers(Value value,
|
||||
setEncoding({afterArg, result}, info, changed, user);
|
||||
continue;
|
||||
}
|
||||
// Workaround: don't propagate through truncI
|
||||
if (isa<arith::TruncIOp>(user))
|
||||
continue;
|
||||
if (user->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() ||
|
||||
user->hasTrait<mlir::OpTrait::Elementwise>() ||
|
||||
isa<triton::ReduceOp, triton::ExpandDimsOp,
|
||||
@@ -755,7 +752,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) {
|
||||
map(oldResult, newResult);
|
||||
return newOp;
|
||||
}
|
||||
assert(0 && "unexpected op in rewrite");
|
||||
llvm::report_fatal_error("unexpected op in rewrite");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -772,34 +769,6 @@ static bool canBeRemat(Operation *op) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Replace ForOp with a new ForOp with extra operands. The YieldOp is not
|
||||
// updated and needs to be updated separatly for the loop to be correct.
|
||||
static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter,
|
||||
scf::ForOp loop,
|
||||
ValueRange newIterOperands) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(loop);
|
||||
|
||||
// Create a new loop before the existing one, with the extra operands.
|
||||
rewriter.setInsertionPoint(loop);
|
||||
auto operands = llvm::to_vector<4>(loop.getInitArgs());
|
||||
operands.append(newIterOperands.begin(), newIterOperands.end());
|
||||
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
|
||||
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
|
||||
operands);
|
||||
newLoop.getBody()->erase();
|
||||
|
||||
newLoop.getRegion().getBlocks().splice(
|
||||
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
|
||||
for (Value operand : newIterOperands)
|
||||
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
|
||||
|
||||
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
|
||||
loop.getNumResults())))
|
||||
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
||||
return newLoop;
|
||||
}
|
||||
|
||||
static void rewriteSlice(SetVector<Value> &slice,
|
||||
DenseMap<Value, Attribute> &layout,
|
||||
ConvertLayoutOp convertOp, IRMapping &mapping) {
|
||||
|
||||
@@ -98,8 +98,6 @@ TritonGPUConversionTarget::TritonGPUConversionTarget(
|
||||
// Some ops from SCF are illegal
|
||||
addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
|
||||
scf::ReduceReturnOp>();
|
||||
// We have custom versions of some arith operators
|
||||
addIllegalOp<arith::CmpIOp, arith::CmpFOp>();
|
||||
|
||||
addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
|
||||
triton::TritonDialect, cf::ControlFlowDialect,
|
||||
|
||||
@@ -232,8 +232,10 @@ std::string GraphLayoutMarker::getColor(const Type &type) const {
|
||||
return "orange";
|
||||
else if (layout.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return "orangered";
|
||||
else
|
||||
assert(0 && "Unrecognized layout");
|
||||
else {
|
||||
llvm::report_fatal_error("Unrecognized layout");
|
||||
return "unknown";
|
||||
}
|
||||
} else {
|
||||
return "white";
|
||||
}
|
||||
@@ -342,11 +344,39 @@ bool canFoldIntoConversion(Operation *op, Attribute targetEncoding) {
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (auto view = dyn_cast<triton::ViewOp>(op)) {
|
||||
auto viewDstType = view.getType().cast<RankedTensorType>();
|
||||
RankedTensorType newDstType = RankedTensorType::get(
|
||||
viewDstType.getShape(), viewDstType.getElementType(), targetEncoding);
|
||||
return !triton::gpu::isExpensiveView(view.getOperand().getType(),
|
||||
newDstType);
|
||||
}
|
||||
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp>(op);
|
||||
triton::MakeRangeOp, triton::SplatOp>(op);
|
||||
}
|
||||
|
||||
//
|
||||
scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, scf::ForOp loop,
|
||||
ValueRange newIterOperands) {
|
||||
OpBuilder::InsertionGuard g(rewriter);
|
||||
rewriter.setInsertionPoint(loop);
|
||||
|
||||
// Create a new loop before the existing one, with the extra operands.
|
||||
auto operands = llvm::to_vector<4>(loop.getInitArgs());
|
||||
operands.append(newIterOperands.begin(), newIterOperands.end());
|
||||
scf::ForOp newLoop = rewriter.create<scf::ForOp>(
|
||||
loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
|
||||
operands);
|
||||
newLoop.getBody()->erase();
|
||||
newLoop.getRegion().getBlocks().splice(
|
||||
newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
|
||||
for (Value operand : newIterOperands)
|
||||
newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
|
||||
|
||||
for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
|
||||
loop.getNumResults())))
|
||||
std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
|
||||
return newLoop;
|
||||
}
|
||||
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping) {
|
||||
|
||||
@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
|
||||
build(builder, state, MutexType::get(builder.getContext()));
|
||||
}
|
||||
|
||||
///--- DotWaitOp ---
|
||||
LogicalResult DotWaitOp::inferReturnTypes(
|
||||
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
for (Value operand : operands)
|
||||
inferredReturnTypes.push_back(operand.getType());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace nvidia_gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -68,7 +68,8 @@ Attribute replaceCTALayout(Attribute layout, llvm::ArrayRef<int64_t> shape,
|
||||
replaceCTALayout(sliceLayout.getParent(), shape, newCTALayout));
|
||||
} else {
|
||||
// Other layouts are generated by passes after PlanCTAPass
|
||||
assert(0 && "replaceCTALayout not implemented");
|
||||
llvm::report_fatal_error("replaceCTALayout not implemented");
|
||||
return layout;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -393,7 +394,8 @@ bool CTAPlanner::propagateBackward(CastOp cast) {
|
||||
Value output = cast.getResult(0);
|
||||
unsigned numUsers = getNumUsers(input);
|
||||
if (numUsers == 0) {
|
||||
assert(0 && "Unreachable branch");
|
||||
llvm::report_fatal_error("Unreachable branch");
|
||||
return false;
|
||||
} else if (numUsers == 1) {
|
||||
Type outTy = output.getType();
|
||||
if (auto ptrTy = outTy.dyn_cast<triton::PointerType>())
|
||||
@@ -649,7 +651,7 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
return true;
|
||||
if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
|
||||
return externElementwiseOp.getPure();
|
||||
if (llvm::isa<ttg::CmpIOp, ttg::CmpFOp, ttg::SelectOp>(op))
|
||||
if (llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op))
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
@@ -711,7 +713,7 @@ bool CTAPlanner::processExpandDimsBackward(triton::ExpandDimsOp expandDims,
|
||||
|
||||
bool CTAPlanner::processExpandDimsForward(triton::ExpandDimsOp expandDims,
|
||||
Attribute newSrcLayout) {
|
||||
assert(0 && "processExpandDimsForward not implemented yet");
|
||||
llvm::report_fatal_error("processExpandDimsForward not implemented yet");
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -827,7 +829,7 @@ int findResultIndex(Operation *op, Value result) {
|
||||
for (int i = 0; i < op->getNumResults(); ++i)
|
||||
if (op->getResult(i) == result)
|
||||
return i;
|
||||
assert(0 && "Invalid index of op result");
|
||||
llvm::report_fatal_error("Invalid index of op result");
|
||||
return -1;
|
||||
}
|
||||
|
||||
@@ -849,7 +851,7 @@ bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
|
||||
auto newType = cast.getResult(0).getType();
|
||||
return processForOp(forOp, index, newType);
|
||||
} else {
|
||||
assert(0 && "Unexpected parent op of block argument");
|
||||
llvm::report_fatal_error("Unexpected parent op of block argument");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
@@ -869,7 +871,7 @@ bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
|
||||
else if (auto forOp = llvm::dyn_cast<scf::ForOp>(yieldOp->getParentOp()))
|
||||
return processForOp(forOp, index, newType);
|
||||
else
|
||||
assert(0 && "Unexpected parent op of YieldOp");
|
||||
llvm::report_fatal_error("Unexpected parent op of YieldOp");
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -936,7 +938,8 @@ bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
|
||||
Operation *clonedOp = builder.clone(*defOp);
|
||||
newInput = clonedOp->getResult(0);
|
||||
} else {
|
||||
assert(0 && "Layout conflict for block arg"); // TODO
|
||||
llvm::report_fatal_error("Layout conflict for block arg"); // TODO
|
||||
return false;
|
||||
}
|
||||
}
|
||||
first = false;
|
||||
|
||||
@@ -55,7 +55,7 @@ bool isDivisible(Value v, unsigned divisor) {
|
||||
auto func = dyn_cast<tt::FuncOp>(parentOp);
|
||||
assert(func);
|
||||
if (auto attr = func.getArgAttrOfType<IntegerAttr>(blockArg.getArgNumber(),
|
||||
"tt.max_divisibility"))
|
||||
"tt.divisibility"))
|
||||
return attr.getValue().getZExtValue() % divisor == 0;
|
||||
return false;
|
||||
} else if (v.getParentBlock()->isEntryBlock() && (!v.isa<BlockArgument>())) {
|
||||
@@ -98,13 +98,8 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, int computeCapability) {
|
||||
return !(boxDimSwizzle && strideDivisible && enableTMA);
|
||||
}
|
||||
|
||||
// TODO: When encoding exists use triton::gpu::CmpIOp as arith::CmpIOp doesn't
|
||||
// play well with encoding attributes. Move back to arith::CmpIOp when this pass
|
||||
// moves back to triton IR level.
|
||||
Value createCmpOp(OpBuilder &builder, Location loc, RankedTensorType type,
|
||||
arith::CmpIPredicate pred, Value lhs, Value rhs) {
|
||||
if (type.getEncoding())
|
||||
return builder.create<ttg::CmpIOp>(loc, type, pred, lhs, rhs);
|
||||
return builder.create<arith::CmpIOp>(loc, type, pred, lhs, rhs);
|
||||
}
|
||||
|
||||
@@ -358,12 +353,17 @@ class TritonGPURewriteTensorPointerPass
|
||||
: public TritonGPURewriteTensorPointerBase<
|
||||
TritonGPURewriteTensorPointerPass> {
|
||||
private:
|
||||
int computeCapability;
|
||||
// int computeCapability;
|
||||
DenseMap<Value, RewritedInfo> rewritedInfo;
|
||||
|
||||
public:
|
||||
explicit TritonGPURewriteTensorPointerPass(int computeCapability)
|
||||
: computeCapability(computeCapability) {}
|
||||
// explicit TritonGPURewriteTensorPointerPass(int computeCapability)
|
||||
// : computeCapability(computeCapability) {}
|
||||
|
||||
TritonGPURewriteTensorPointerPass() = default;
|
||||
TritonGPURewriteTensorPointerPass(int computeCapability) {
|
||||
this->computeCapability = computeCapability;
|
||||
}
|
||||
|
||||
static bool needRewrite(Operation *op, const DenseSet<Value> &valueToRemove) {
|
||||
if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
|
||||
@@ -763,17 +763,16 @@ public:
|
||||
ModuleOp mod = getOperation();
|
||||
|
||||
DenseSet<Value> valueToRemove;
|
||||
mod.walk([&valueToRemove,
|
||||
computeCapability = this->computeCapability](Operation *op) {
|
||||
mod.walk([&valueToRemove, this](Operation *op) {
|
||||
if (auto makeTensorPtrOp = dyn_cast<tt::MakeTensorPtrOp>(op)) {
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(op->getResult(0));
|
||||
}
|
||||
if (llvm::isa<tt::AdvanceOp>(op)) {
|
||||
auto src = op->getOperand(0);
|
||||
if (tt::isTensorPointerType(src.getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability)) {
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability)) {
|
||||
valueToRemove.insert(op->getResult(0));
|
||||
}
|
||||
}
|
||||
@@ -782,7 +781,7 @@ public:
|
||||
auto src = op->getOperand(0);
|
||||
if (tt::isTensorPointerType(src.getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(src);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(src);
|
||||
}
|
||||
}
|
||||
@@ -791,7 +790,7 @@ public:
|
||||
for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) {
|
||||
if (tt::isTensorPointerType(iterOperands[i].getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(iterOperands[i]);
|
||||
}
|
||||
}
|
||||
@@ -800,7 +799,7 @@ public:
|
||||
for (unsigned i = 0, size = yieldOp.getNumOperands(); i < size; ++i) {
|
||||
if (tt::isTensorPointerType(operands[i].getType())) {
|
||||
auto makeTensorPtrOp = getMakeTensorPtrOp(operands[i]);
|
||||
if (shouldRemove(makeTensorPtrOp, computeCapability))
|
||||
if (shouldRemove(makeTensorPtrOp, this->computeCapability))
|
||||
valueToRemove.insert(operands[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
|
||||
Reference in New Issue
Block a user