[MLIR] Added tritongpu-stream-pipeline pass (#305)

* [MLIR] Added tritongpu-stream-pipeline pass
     - Prologue: Hoist the pipelinable load operations and shared memory store
       for the ramp up stage
     - Pipelined Loop: Assemble the loop body minus last iteration
       - Prefetch next tile from global into regs (while computing from previous)
       - Non-load loop body
       - Store next tile into shared mem
     - Epilogue: Peeled non-load loop body for last iteration

* updated comment
This commit is contained in:
SJW
2023-09-07 15:24:59 -05:00
committed by GitHub
parent 83a0958566
commit 491eb9ddfe
6 changed files with 856 additions and 2 deletions

View File

@@ -6,6 +6,8 @@
namespace mlir {
std::unique_ptr<Pass> createTritonGPUPipelinePass(int numStages = 2);
std::unique_ptr<Pass> createTritonGPUStreamPipelinePass();
std::unique_ptr<Pass>
createTritonGPUAccelerateMatmulPass(int computeCapability = 80);

View File

@@ -24,6 +24,21 @@ def TritonGPUPipeline : Pass<"tritongpu-pipeline", "mlir::ModuleOp"> {
];
}
// Software pipelining tailored for targets that stage global loads through
// registers into shared memory: the next tile is prefetched while the
// previous tile is being computed on.
def TritonGPUStreamPipeline : Pass<"tritongpu-stream-pipeline", "mlir::ModuleOp"> {
  let summary = "pipeline";

  let description = [{
    Pipeline global loads through registers to shared memory while computing on previous
    tile
  }];

  let constructor = "mlir::createTritonGPUStreamPipelinePass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::scf::SCFDialect",
                           "mlir::arith::ArithDialect"];
}
def TritonGPUPrefetch : Pass<"tritongpu-prefetch", "mlir::ModuleOp"> {
let summary = "prefetch";

View File

@@ -7,6 +7,7 @@ add_mlir_dialect_library(TritonGPUTransforms
Prefetch.cpp
RemoveLayoutConversions.cpp
ReorderInstructions.cpp
StreamPipeline.cpp
TritonGPUConversion.cpp
Utility.cpp

View File

@@ -0,0 +1,827 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/MapVector.h"
//===----------------------------------------------------------------------===//
// This file implements stream software pipelining for loops. The implementation
// here is inspired by the pipeline pass in Triton and the rocMLIR pipeliner.
//
// We divide the loop body into the following phases:
// a. Pre-load operations: for instance, index computation.
// b. Load operations: loading from global memory to shared memory.
// c. Compute operations: for instance, Triton dot.
// d. Post-load operations: for instance, index computation.
//
// To pipeline the loop, we need to:
// - Find all the dependencies of the load operations.
// - Prologue: Hoist the pipelinable load operations and shared memory store
// for the ramp up stage
// - Pipelined Loop: Assemble the loop body minus last iteration
// - Prefetch next tile from global into regs (while computing from previous)
// - Non-load loop body
// - Store next tile into shared mem
// - Epilogue: Peeled non-load loop body for last iteration
//
//===----------------------------------------------------------------------===//
using llvm::MapVector;
using namespace mlir;
namespace ttg = triton::gpu;
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
namespace {
/// Software-pipelines a single scf.ForOp: pipelinable global loads are
/// prefetched into registers one iteration ahead, then stored to shared
/// memory at the bottom of the body, while compute consumes the previous
/// tile. The last iteration is (by default) peeled into an epilogue.
class LoopPipeliner {
  /// Cache of ForOp and YieldOp related to this pipeliner.
  scf::ForOp forOp;
  scf::YieldOp yieldOp;

  /// When true, the last iteration is peeled into an epilogue after the
  /// pipelined loop instead of being masked inside it.
  bool peelLastIter = true;

  /// The new pipelined ForOp.
  scf::ForOp pplForOp;

  /// Loads to be pipelined
  SetVector<Value> validLoads;
  /// The value that each load will be mapped to (after layout conversion)
  DenseMap<Value, Value> convertMapping;
  /// load => buffer (shared-memory value produced by the prologue)
  DenseMap<Value, Value> loadsBuffer;
  /// load => buffer type (with shared layout after swizzling)
  DenseMap<Value, RankedTensorType> loadsBufferType;

  /// Loop condition for the *next* iteration's prefetch (next IV < ub).
  Value nextLoopCond;

  /// Shared-memory values to be yielded for the next iteration.
  SmallVector<Value> nextBuffers;
  // NOTE(review): `yieldValues` appears unused in this file (finalizeYield
  // uses a local of the same name) — confirm before relying on it.
  SmallVector<Value> yieldValues;

  /// The number of stages in the pipeline is fixed to '2' for
  /// analysis since there will be a current buffer stored in
  /// shared mem and a next buffer stored in regs.
  int numStages = 2;

  /// Arg indices: first shared buffer arg, first dependency arg.
  size_t bufferIdx, depArgsBeginIdx;
  DenseMap<BlockArgument, size_t> depArgsIdx;

  /// value (in loop) => value at stage N
  DenseMap<Value, SmallVector<Value>> valueMapping;
  /// loop iter arg => value
  DenseMap<BlockArgument, Value> depArgsMapping;

  /// forOp value => pplForOp value (current iteration)
  IRMapping curMapping;
  /// forOp value => prefetch value (next iteration, IV advanced by a step)
  IRMapping nextMapping;

  /// Dependency ops by program order
  SmallVector<Operation *> orderedDeps;

  /// block arguments that loads depend on
  SetVector<BlockArgument> depArgs;

  /// operation => source operand defined stages
  DenseMap<Operation *, DenseSet<int>> immediateOpStages;

  /// operations that loads depend on
  SetVector<Operation *> depOps;

  /// Collect all pipelinable ops
  LogicalResult collectOps(SetVector<Operation *> &ops);
  /// Collect values that `v` depends on and are defined inside the loop
  void collectValueDep(Value v, int stage, SetVector<Value> &opDeps);
  /// Collect all op dependencies
  void collectDeps(SetVector<Operation *> &ops,
                   MapVector<Operation *, SetVector<Value>> &opDeps);
  /// Check if none of the ops has valid uses
  LogicalResult checkOpUses(SetVector<Operation *> &ops);
  /// Check if ops have dependencies that are not pipelinable
  void checkOpDeps(SetVector<Operation *> &ops);
  /// Compute the shared-memory buffer type for each pipelined load.
  void createBufferTypes();
  /// Record dependency ops and loads in original program order.
  void createOrderedDeps();
  /// Return the stage at which `v` is defined prior to `stage`
  int getValueDefStage(Value v, int stage);
  /// Map `origin` to `newValue` at `stage`
  void setValueMapping(Value origin, Value newValue, int stage);
  /// Map `origin` to `newValue` at `stage` according to the association between
  /// yieldOp and forOp
  void setValueMappingYield(Value origin, Value newValue, int stage);
  /// Map `origin` to `newValue` at the next stage according to the association
  /// between yieldOp and forOp
  void setValueMappingYield(Value origin, Value newValue);
  /// Return the value mapped to `origin` at `stage`, if it exists.
  Value lookupOrDefault(Value origin, int stage);
  /// Build the (possibly loop-condition-guarded) mask for a prefetch load.
  Value getLoadMask(triton::LoadOp loadOp, Value mappedMask, Value loopCond,
                    OpBuilder &builder);
  /// Collect all args of the new loop
  SmallVector<Value> collectNewLoopArgs();
  /// Clone the forOp and return the new forOp
  scf::ForOp cloneForOp(ArrayRef<Value> newLoopArgs, OpBuilder &builder);
  /// Replace (or insert) the mask operand of a load.
  void updateLoadMask(triton::LoadOp loadOp, Value newMask);
  /// Prefetch the next iteration for `pplForOp`
  void prefetchNextBuffer(OpBuilder &builder);
  /// Clone the compute portion of the original body into `pplForOp`.
  void cloneCurrentBody(OpBuilder &builder);
  /// Store the prefetched tile to shared memory for the next iteration.
  void storeNextBuffer(OpBuilder &builder);
  /// Assemble `pplForOp`'s yield op
  void finalizeYield(OpBuilder &builder);

public:
  LoopPipeliner(scf::ForOp forOp)
      : forOp(forOp) {
    yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  }

  /// Collect loads to pipeline. Return success if we can pipeline this loop
  LogicalResult initialize();

  /// Emit pipelined loads (before loop body)
  void emitPrologue();

  /// emit pipelined loads (after loop body)
  void emitEpilogue(DenseMap<Value, Value> &newResults);

  /// create the new ForOp (add new args & insert prefetched ops)
  scf::ForOp createNewForOp();

  friend struct PipelinePass;
};
/// Collect pipelinable load operations in the loop body.
/// Every triton::LoadOp directly inside the body is a candidate at this
/// stage; filtering happens later in checkOpUses/checkOpDeps.
/// Returns success iff at least one load was found.
LogicalResult LoopPipeliner::collectOps(SetVector<Operation *> &ops) {
  // The previous version constructed a ModuleAxisInfoAnalysis here that was
  // never read; drop the dead (and expensive) whole-module analysis.
  // We cannot use forOp.walk(...) here because we only want to visit the
  // operations in the loop body block. Nested blocks are handled separately.
  for (Operation &op : forOp)
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
      // pipeline all loads
      ops.insert(loadOp);
    }

  return success(!ops.empty());
}
/// Transitively collect in-loop values that `v` depends on, following
/// loop-carried block arguments back through the yield up to `stage` hops.
/// `v` itself is always inserted before its producers, so the first element
/// of `deps` is the original operand.
void LoopPipeliner::collectValueDep(Value v, int stage,
                                    SetVector<Value> &deps) {
  // Values defined outside the loop body are loop-invariant; skip them.
  if (v.getParentRegion() != &forOp.getLoopBody())
    return;
  // Stop when already visited, or once we have peeled numStages-1 times —
  // dependencies further back cannot matter.
  if (deps.contains(v) || stage < 0)
    return;

  auto arg = v.dyn_cast<BlockArgument>();
  if (!arg) {
    // Plain SSA value: record it, then recurse into its producer's operands.
    deps.insert(v);
    for (Value operand : v.getDefiningOp()->getOperands())
      collectValueDep(operand, stage, deps);
    return;
  }
  // Arg 0 is the induction variable; only iter args carry values across
  // iterations.
  if (arg.getArgNumber() == 0)
    return;
  deps.insert(v);
  collectValueDep(yieldOp->getOperand(arg.getArgNumber() - 1), stage - 1,
                  deps);
}
/// For every op in `ops`, collect the set of in-loop values it transitively
/// depends on (looking back up to numStages-1 iterations through the yield).
void LoopPipeliner::collectDeps(
    SetVector<Operation *> &ops,
    MapVector<Operation *, SetVector<Value>> &valueDeps) {
  for (auto op : ops) {
    // Accumulate dependencies across *all* operands. The previous code
    // rebuilt the set and reassigned valueDeps[op] inside the operand loop,
    // silently discarding every operand's dependencies except the last.
    SetVector<Value> deps;
    for (Value v : op->getOperands())
      collectValueDep(v, numStages - 1, deps);
    valueDeps[op] = deps;
  }
}
/// Filter `ops` down to loads that can legally be pipelined.
/// A load qualifies when (a) it does not depend on another candidate load
/// (such a load would have to wait in the prologue, defeating the point of
/// pipelining) and (b) its single use chain ends in a convert_layout to a
/// dot-operand encoding. Survivors go to `validLoads`; the rest are removed
/// from `ops`. Returns failure if nothing survives.
LogicalResult LoopPipeliner::checkOpUses(SetVector<Operation *> &ops) {
  DenseSet<Operation *> invalidOps;
  // Collect all ops' dependencies
  MapVector<Operation *, SetVector<Value>> opDeps;
  collectDeps(ops, opDeps);

  for (Operation *op : ops) {
    if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
      // Don't pipeline valid loads that depend on other valid loads
      // (Because if a valid load depends on another valid load, this load
      // needs to wait on the other load in the prologue, which is against
      // the point of the pipeline pass)
      bool isCandidate = true;
      for (Operation *other : ops)
        if (isa<triton::LoadOp>(other))
          if (opDeps[op].contains(other->getResult(0))) {
            isCandidate = false;
            break;
          }
      // We only pipeline loads that have one convert_layout (to dot_op) use
      // TODO: lift this constraint in the future
      if (isCandidate && loadOp.getResult().hasOneUse()) {
        isCandidate = false;
        Operation *use = *loadOp.getResult().getUsers().begin();

        // Advance to the first conversion as long as the use resides in
        // shared memory and it has a single use itself
        while (use) {
          if (use->getNumResults() != 1 || !use->getResult(0).hasOneUse())
            break;
          auto tensorType =
              use->getResult(0).getType().dyn_cast<RankedTensorType>();
          // dyn_cast yields a null Type for non-ranked-tensor results;
          // the previous code called getEncoding() on it unchecked.
          if (!tensorType ||
              !tensorType.getEncoding().isa<ttg::SharedEncodingAttr>())
            break;
          use = *use->getResult(0).getUsers().begin();
        }

        if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use))
          if (auto tensorType = convertLayout.getResult()
                                    .getType()
                                    .dyn_cast<RankedTensorType>())
            if (auto dotOpEnc = tensorType.getEncoding()
                                    .dyn_cast<ttg::DotOperandEncodingAttr>()) {
              isCandidate = true;
              convertMapping[loadOp] = convertLayout;
            }
      } else
        isCandidate = false;

      if (!isCandidate)
        invalidOps.insert(loadOp);
      else
        validLoads.insert(loadOp);
    }
  }

  for (Operation *op : invalidOps)
    ops.remove(op);

  return success(!ops.empty());
}
/// Classify everything the pipelinable ops depend on.
/// Populates depArgs/depOps with the transitive in-loop dependencies of the
/// loads, and records at which stage each dependency defines its value when
/// the root operand chain starts at a block argument ("immediate").
/// Asserts that immediate and non-immediate dependencies do not conflict,
/// since values cannot currently be rematerialized inside the loop.
void LoopPipeliner::checkOpDeps(SetVector<Operation *> &ops) {
  /// arg => source operand defined stages
  DenseMap<BlockArgument, DenseSet<int>> immediateArgStages;
  SetVector<BlockArgument> nonImmediateDepArgs;
  SetVector<Operation *> nonImmediateOps;
  for (Operation *op : ops) {
    for (Value v : op->getOperands()) {
      SetVector<Value> deps;
      collectValueDep(v, numStages - 1, deps);
      int defStage = getValueDefStage(v, numStages - 1);
      assert(defStage >= 0 &&
             "newLoopArgs has null args without a define op. Consider either "
             "rewrite the loop to reduce cross iteration dependencies or "
             "increase the num_stages value.");
      for (auto dep : deps) {
        // NOTE(review): `deps.front()` is the operand `v` itself (it is
        // inserted first by collectValueDep), so `immediate` is a property
        // of the whole operand chain, not of each individual `dep`.
        // Confirm this is intended rather than `dep.isa<BlockArgument>()`.
        auto immediate = deps.front().isa<BlockArgument>();
        if (auto arg = dyn_cast<BlockArgument>(dep)) {
          depArgs.insert(arg);
          if (immediate)
            immediateArgStages[arg].insert(defStage);
          else
            nonImmediateDepArgs.insert(arg);
        } else {
          depOps.insert(dep.getDefiningOp());
          if (immediate)
            immediateOpStages[dep.getDefiningOp()].insert(defStage);
          else
            nonImmediateOps.insert(dep.getDefiningOp());
        }
      }
    }
  }

  // XXX: We could remove the following constraints if we can rematerialize in
  // the loop.
  // Check if immediateDepArgs and nonImmediateDepArgs are disjoint.
  for (auto &[arg, stages] : immediateArgStages) {
    assert(stages.size() == 1 &&
           "Triton doesn't support an argument provides values for "
           "immediate operands of loads from multiple stages. Consider "
           "removing post load instructions dependency on this argument.");
    assert(!(nonImmediateDepArgs.contains(arg) &&
             stages.contains(numStages - 2)) &&
           "Loop-carried arguments provide values for both immediate and "
           "non-immediate operands of loads. Please consider removing "
           "pre/post load instructions dependency on this argument.");
  }

  // Check if immediateOps and nonImmediateOps are disjoint.
  for (auto &[op, stages] : immediateOpStages) {
    assert(stages.size() == 1 &&
           "Triton doesn't support an operation provides values for "
           "immediate operands of loads from multiple stages. Consider "
           "removing post load instructions dependency on this argument.");
    assert(!(nonImmediateOps.contains(op) && stages.contains(numStages - 2)) &&
           "Operations provide values for both immediate and "
           "non-immediate operands of loads. Please consider "
           "removing pre/post load instructions dependency on this "
           "operation.");
  }
}
// helpers

/// Record that `origin` takes the value `newValue` at pipeline `stage`,
/// lazily creating the per-stage vector on first use.
void LoopPipeliner::setValueMapping(Value origin, Value newValue, int stage) {
  auto &stageValues = valueMapping[origin];
  if (stageValues.empty())
    stageValues.resize(numStages);
  stageValues[stage] = newValue;
}
/// If `origin` is yielded by the loop, record `newValue` at `stage` for the
/// region iter arg that receives it on the next iteration.
void LoopPipeliner::setValueMappingYield(Value origin, Value newValue,
                                         int stage) {
  for (OpOperand &use : origin.getUses()) {
    if (use.getOwner() != yieldOp)
      continue;
    BlockArgument iterArg = forOp.getRegionIterArgs()[use.getOperandNumber()];
    setValueMapping(iterArg, newValue, stage);
  }
}
/// If `origin` is yielded by the loop, wire `newValue` into the next
/// iteration: the original iter arg is remapped in nextMapping, and the
/// corresponding dep arg of the pipelined loop records its yield value.
void LoopPipeliner::setValueMappingYield(Value origin, Value newValue) {
  for (OpOperand &use : origin.getUses()) {
    if (use.getOwner() != yieldOp)
      continue;
    unsigned yieldIdx = use.getOperandNumber();
    BlockArgument originArg = forOp.getRegionIterArgs()[yieldIdx];
    nextMapping.map(originArg, newValue);
    BlockArgument newArg = pplForOp.getRegionIterArgs()[depArgsIdx[originArg]];
    // First writer wins: keep an existing mapping for this dep arg.
    if (!depArgsMapping.contains(newArg))
      depArgsMapping[newArg] = newValue;
  }
}
/// Return the value mapped to `origin` at `stage`, or `origin` itself when
/// no mapping exists.
Value LoopPipeliner::lookupOrDefault(Value origin, int stage) {
  // Single hash lookup instead of find-then-operator[].
  auto it = valueMapping.find(origin);
  return it == valueMapping.end() ? origin : it->second[stage];
}
/// For every pipelined load, compute the shared-memory buffer type: same
/// shape and element type as the load result, with a swizzled shared
/// encoding derived from the dot-operand layout of its consumer.
void LoopPipeliner::createBufferTypes() {
  for (auto loadCvt : convertMapping) {
    auto loadOp = loadCvt.first;
    Value cvt = loadCvt.second;
    // The consumer conversion fixes the dot-operand encoding the buffer
    // must ultimately feed.
    auto dotOpEnc = cvt.getType()
                        .cast<RankedTensorType>()
                        .getEncoding()
                        .cast<ttg::DotOperandEncodingAttr>();
    auto ty = loadOp.getType().cast<RankedTensorType>();
    SmallVector<int64_t> bufferShape(ty.getShape().begin(),
                                     ty.getShape().end());
    Type eType = ty.getElementType();
    // The result is unused, but the cast asserts (in debug builds) that the
    // load result carries a blocked layout.
    auto blockedEnc = ty.getEncoding().cast<ttg::BlockedEncodingAttr>();
    // unsigned bitWidth = dotOpEnc.getMMAv2kWidth()
    //                         ? 32 / dotOpEnc.getMMAv2kWidth()
    //                         : ty.getElementType().getIntOrFloatBitWidth();
    auto sharedEnc =
        ttg::SharedEncodingAttr::get(ty.getContext(), dotOpEnc, ty.getShape(),
                                     ttg::getOrder(ty.getEncoding()), eType);
    loadsBufferType[loadOp] =
        RankedTensorType::get(bufferShape, eType, sharedEnc);
  }
}
/// Record every dependency op and pipelined load in original program order,
/// so clones preserve def-before-use.
void LoopPipeliner::createOrderedDeps() {
  for (Operation &op : forOp.getLoopBody().front()) { // @@@ front necessary?
    bool isDep = depOps.contains(&op);
    bool isPipelinedLoad =
        op.getNumResults() > 0 && validLoads.contains(op.getResult(0));
    if (isDep || isPipelinedLoad)
      orderedDeps.push_back(&op);
  }
  assert(depOps.size() + validLoads.size() == orderedDeps.size() &&
         "depOps contains invalid values");
}
/// Return the stage at which `v` is defined, searching backwards from
/// `stage` through loop-carried block arguments; -1 when out of range.
int LoopPipeliner::getValueDefStage(Value v, int stage) {
  if (stage < 0)
    return -1;
  auto arg = v.dyn_cast<BlockArgument>();
  if (!arg)
    return stage; // defined by an op inside the loop at this stage
  if (arg.getArgNumber() == 0)
    llvm_unreachable("Loop induction variable should not be a dependency");
  // Loop-carried value: chase it through the yield into the prior stage.
  return getValueDefStage(yieldOp->getOperand(arg.getArgNumber() - 1),
                          stage - 1);
}
/// Analyze the loop and decide whether it can be pipelined: gather candidate
/// loads, validate their uses and dependencies, then precompute buffer types
/// and the ordered dependency list.
LogicalResult LoopPipeliner::initialize() {
  SetVector<Operation *> ops;
  if (failed(collectOps(ops)) || failed(checkOpUses(ops)))
    return failure();
  checkOpDeps(ops);
  createBufferTypes();
  createOrderedDeps();
  return success();
}
/// Build the mask for a prefetch load. When the last iteration is peeled,
/// the original (remapped) mask suffices; otherwise the mask is AND-ed with
/// the loop condition so the final iteration's prefetch is suppressed.
Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
                                 Value loopCond, OpBuilder &builder) {
  if (peelLastIter) {
    // The peeled epilogue handles the tail, so the loop never issues
    // out-of-range prefetches.
    return mappedMask;
  }

  // Not peeling: guard the last iteration's prefetch with the loop condition.
  Value mask = loadOp.getMask();
  Type maskType = triton::getI1SameShape(loadOp.getType());
  bool tensorMask = isa<RankedTensorType>(maskType);

  if (!mask) {
    if (!tensorMask)
      return loopCond;
    return builder.create<triton::SplatOp>(loopCond.getLoc(), maskType,
                                           loopCond);
  }

  Value cond = loopCond;
  if (tensorMask)
    cond = builder.create<triton::SplatOp>(mask.getLoc(), maskType, loopCond);
  return builder.create<arith::AndIOp>(mask.getLoc(), mappedMask, cond);
}
/// Emit iteration 0 ahead of the loop: clone every dependency op and load,
/// mapping iter args to their init operands and the IV to the lower bound.
/// Loaded tiles are converted into their shared-memory buffer type and
/// recorded in loadsBuffer; loop-carried results are captured at stage 1.
void LoopPipeliner::emitPrologue() {
  /// forOp block args => forOp operands
  /// forOp iterator => lower bound
  IRMapping prologueMap;
  OpBuilder builder(forOp);
  // Get init operands for loop carried values
  for (BlockArgument &arg : forOp.getRegionIterArgs()) {
    OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
    prologueMap.map(arg, operand.get());
  }

  // Emit prologue
  // Map IV to lower bound
  prologueMap.map(forOp.getInductionVar(), forOp.getLowerBound());

  // Emit Iteration 0 loads, etc
  for (Operation *op : orderedDeps) {
    Operation *newOp = nullptr;
    if (validLoads.contains(op->getResult(0))) {
      auto loadOp = cast<triton::LoadOp>(op);
      // Load from global -> regs
      auto newLoadOp = cloneWithInferType(builder, op, prologueMap);
      Value loadVal = newLoadOp->getResult(0);
      // Convert from regs to shared mem
      newOp = builder.create<ttg::ConvertLayoutOp>(
          loadOp.getLoc(), loadsBufferType[loadOp], loadVal);
      Value cvtVal = newOp->getResult(0);
      // Later uses of the load (inside the loop skeleton) see the shared
      // buffer, and the buffer becomes an init operand of the new loop.
      prologueMap.map(loadOp->getResult(0), cvtVal);
      loadsBuffer[loadOp] = cvtVal;
    } else {
      newOp = cloneWithInferType(builder, op, prologueMap);
    }
    // Capture loop carried results for pipelined for input
    for (unsigned idx : llvm::seq(unsigned(0), op->getNumResults()))
      setValueMappingYield(op->getResult(idx), newOp->getResult(idx), 1);
  } // for (Operation *op : orderedDeps)
}
/// Emit the peeled last iteration after the pipelined loop. Iter args (and
/// the shared buffers) are taken from the pipelined loop's results, and the
/// IV maps to the original upper bound. Values the original loop yielded are
/// recorded in `newResults` so the caller can rewire external uses.
void LoopPipeliner::emitEpilogue(DenseMap<Value, Value> &newResults) {
  if (!peelLastIter)
    return;
  OpBuilder builder(forOp);
  builder.setInsertionPointAfter(forOp);

  IRMapping epilogueMap;
  // Map 'for' iteration args to pipelined-for results
  auto args = forOp.getRegionIterArgs();
  for (uint32_t i = 0; i < args.size(); ++i)
    epilogueMap.map(args[i], pplForOp.getResult(i));
  for (uint32_t i = 0; i < validLoads.size(); ++i)
    epilogueMap.map(validLoads[i], pplForOp.getResult(bufferIdx + i));

  // Map IV to original upper bound (ie. last iteration)
  // NOTE(review): this assumes the IV value of the last iteration equals
  // ub, which holds only when (ub - lb) is a multiple of step — confirm
  // against the loops this pass accepts.
  epilogueMap.map(forOp.getInductionVar(), forOp.getUpperBound());

  const auto &yieldOprs = yieldOp.getOperands();
  // Clone the loop body after the new ForOp
  // , replace original args with results of the new ForOp.
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!llvm::is_contained(orderedDeps, &op)) {
      Operation *newOp = nullptr;
      auto cvtOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
      if (cvtOp && validLoads.contains(cvtOp.getSrc())) {
        auto cvtDstTy = cvtOp.getResult().getType().cast<RankedTensorType>();
        // Conversions of pipelined loads to dot operands read the shared
        // buffer directly; clone without re-inferring the result type.
        if (cvtDstTy.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
          newOp = builder.clone(op, epilogueMap);
        }
      }
      if (newOp == nullptr)
        newOp = cloneWithInferType(builder, &op, epilogueMap);
      // substitute for these results for the results of the new for loop
      for (const auto &pair :
           llvm::zip(op.getResults(), newOp->getResults())) {
        auto val = std::get<0>(pair);
        auto it = llvm::find(yieldOprs, val);
        if (it != yieldOprs.end()) {
          uint32_t idx = std::distance(yieldOprs.begin(), it);
          newResults[forOp->getResult(idx)] = std::get<1>(pair);
        }
      }
    }
  }
}
/// Assemble the init operands of the pipelined loop. Layout:
///   [0, bufferIdx)            original loop args (at stage numStages-1)
///   [bufferIdx, depArgsBegin) one shared-memory buffer per pipelined load
///   [depArgsBeginIdx, end)    loop-carried dependency args
/// Also records each dep arg's position in depArgsIdx for yield wiring.
SmallVector<Value> LoopPipeliner::collectNewLoopArgs() {
  SmallVector<Value> newLoopArgs;
  for (auto v : forOp.getIterOperands())
    newLoopArgs.push_back(lookupOrDefault(v, numStages - 1));

  // Shared-memory locations produced by the prologue (iteration 0).
  bufferIdx = newLoopArgs.size();
  for (auto loadOp : validLoads)
    newLoopArgs.push_back(loadsBuffer[loadOp]);

  // Loop-carried dependency values, as computed by the prologue.
  depArgsBeginIdx = newLoopArgs.size();
  for (auto depArg : depArgs) {
    depArgsIdx[depArg] = newLoopArgs.size();
    newLoopArgs.push_back(valueMapping[depArg][numStages - 1]);
  }

  return newLoopArgs;
}
/// Build the pipelined ForOp skeleton: the upper bound is shortened by one
/// step when peeling, and the iter args are extended with the shared buffers
/// and the dep args. Sets up curMapping (current iteration), nextMapping
/// (next iteration, with the IV advanced by one step) and nextLoopCond,
/// which guards the next iteration's prefetch.
scf::ForOp LoopPipeliner::cloneForOp(ArrayRef<Value> newLoopArgs,
                                     OpBuilder &builder) {
  auto loc = forOp.getLoc();
  // Peel off the last iteration
  auto pplUpperBound = forOp.getUpperBound();
  if (peelLastIter)
    pplUpperBound =
        builder.create<arith::SubIOp>(loc, pplUpperBound, forOp.getStep());

  // Clone the original ForOp
  pplForOp = builder.create<scf::ForOp>(loc, forOp.getLowerBound(),
                                        pplUpperBound, forOp.getStep(),
                                        newLoopArgs);

  // Set mapping on body of the new ForOp
  builder.setInsertionPointToStart(pplForOp.getBody());
  for (const auto &arg : llvm::enumerate(forOp.getRegionIterArgs()))
    curMapping.map(arg.value(), pplForOp.getRegionIterArgs()[arg.index()]);
  // Pipelined loads now read their tile from the shared-memory iter arg.
  uint32_t bufIdx = bufferIdx;
  for (auto loadOp : validLoads)
    curMapping.map(loadOp, pplForOp.getRegionIterArgs()[bufIdx++]);
  curMapping.map(forOp.getInductionVar(), pplForOp.getInductionVar());

  nextMapping = curMapping;
  // Map the dep args of the next iteration to the dep args of the current
  auto iterArgs = pplForOp.getRegionIterArgs();
  size_t argIdx = 0;
  for (auto depArg : depArgs) {
    BlockArgument nextArg = iterArgs[argIdx + depArgsBeginIdx];
    nextMapping.map(depArg, nextArg);
    ++argIdx;
  }

  // Compute next IV for pre-loads
  Value iv = pplForOp.getInductionVar();
  // NOTE(review): this IV mapping repeats the one set a few lines above;
  // it is harmless but redundant.
  curMapping.map(forOp.getInductionVar(), iv);
  Value nextIV =
      builder.create<arith::AddIOp>(iv.getLoc(), iv, pplForOp.getStep());
  nextMapping.map(forOp.getInductionVar(), nextIV);
  // Prefetch is live only while the *next* iteration is still in range.
  nextLoopCond = builder.create<arith::CmpIOp>(
      nextIV.getLoc(), arith::CmpIPredicate::slt, nextIV,
      pplForOp.getUpperBound());

  return pplForOp;
}
/// Install `newMask` on `loadOp`: overwrite the existing mask operand when
/// one is present, otherwise insert it via the mutable operand range.
void LoopPipeliner::updateLoadMask(triton::LoadOp loadOp, Value newMask) {
  if (!newMask)
    return;
  if (loadOp->getNumOperands() > 1) {
    // Operand 1 is the existing mask; replace it in place.
    loadOp->setOperand(1, newMask);
    return;
  }
  loadOp.getMaskMutable().assign(newMask);
}
/// Emit the prefetch (global -> regs) of the *next* iteration's tile at the
/// top of the pipelined loop body, before the current tile's compute.
/// Non-load dependency ops that do not feed an immediate operand at the
/// final stage are cloned here as well; the rest are emitted later by
/// storeNextBuffer.
void LoopPipeliner::prefetchNextBuffer(OpBuilder &builder) {
  // Emit prefetch loads of next buffer before compute of current buffer
  for (Operation *op : orderedDeps) {
    // (Removed a dead `Operation *nextOp = nullptr;` here that was shadowed
    // by the declaration in the else-branch below and never read.)
    if (validLoads.contains(op->getResult(0))) {
      // Update loading mask
      auto loadOp = llvm::cast<triton::LoadOp>(op);
      auto mask = loadOp.getMask();
      // pre-load global -> regs
      Value newMask = getLoadMask(loadOp, nextMapping.lookupOrDefault(mask),
                                  nextLoopCond, builder);
      if (mask) {
        // If mask is defined outside the loop, don't update the map more than
        // once
        if (!(forOp.isDefinedOutsideOfLoop(mask) && nextMapping.contains(mask)))
          nextMapping.map(loadOp.getMask(), newMask);
        newMask = nextMapping.lookupOrDefault(mask);
      }
      auto newOp = builder.clone(*op, nextMapping);
      updateLoadMask(cast<triton::LoadOp>(newOp), newMask);
    } else if (!immediateOpStages[op].contains(numStages - 2)) {
      Operation *nextOp = builder.clone(*op, nextMapping);
      if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
        if (auto newMask = getLoadMask(
                loadOp, nextMapping.lookupOrDefault(loadOp.getMask()),
                nextLoopCond, builder)) {
          updateLoadMask(cast<triton::LoadOp>(nextOp), newMask);
        }
      }
      // Map original results to the cloned results, then record the yield
      // wiring for loop-carried values.
      for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
        nextMapping.map(op->getResult(dstIdx), nextOp->getResult(dstIdx));
      for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
        setValueMappingYield(op->getResult(dstIdx), nextOp->getResult(dstIdx));
    }
  }
}
/// Clone the non-restructured portion of the original loop body (the compute
/// phase) into the pipelined loop, reading the current tile from the shared
/// buffer iter args.
void LoopPipeliner::cloneCurrentBody(OpBuilder &builder) {
  // (Removed an unused local `auto loc = forOp.getLoc();`.)
  // only add instructions that are not part of the restructuring
  for (Operation &op : forOp.getBody()->without_terminator()) {
    if (!llvm::is_contained(orderedDeps, &op)) {
      Operation *newOp = nullptr;
      auto cvtOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op);
      if (cvtOp && validLoads.contains(cvtOp.getSrc())) {
        auto cvtDstTy = cvtOp.getResult().getType().cast<RankedTensorType>();
        // A conversion of a pipelined load into a dot operand reads the
        // shared-memory iter arg; clone it without re-inferring the type.
        if (cvtDstTy.getEncoding().isa<ttg::DotOperandEncodingAttr>())
          newOp = builder.clone(op, curMapping);
      }
      if (newOp == nullptr)
        newOp = cloneWithInferType(builder, &op, curMapping);
    } else {
      // hack for yield operands: the "current" value of a pipelined addptr
      // is the pointer it was advanced from.
      if (auto ttadd = dyn_cast<triton::AddPtrOp>(op)) {
        curMapping.map(ttadd.getResult(), curMapping.lookup(ttadd.getPtr()));
      }
    }
  }
}
/// At the bottom of the pipelined loop body: clone the dependency ops that
/// feed immediate operands at the final stage, then convert each prefetched
/// tile (regs) into its shared-memory buffer for the next iteration and
/// record the yields for the loop-carried args.
void LoopPipeliner::storeNextBuffer(OpBuilder &builder) {
  // Store the next buffer at the end of the loop body for the next iteration
  for (Operation *op : orderedDeps) {
    if (!validLoads.contains(op->getResult(0))) {
      if (immediateOpStages[op].contains(numStages - 2)) {
        Operation *nextOp = builder.clone(*op, nextMapping);
        if (auto loadOp = dyn_cast<triton::LoadOp>(op)) {
          auto newMask =
              getLoadMask(loadOp, nextMapping.lookupOrDefault(loadOp.getMask()),
                          nextLoopCond, builder);
          updateLoadMask(cast<triton::LoadOp>(nextOp), newMask);
        }

        for (unsigned dstIdx : llvm::seq(unsigned(0), op->getNumResults()))
          setValueMappingYield(op->getResult(dstIdx),
                               nextOp->getResult(dstIdx));
      }
    }
  }

  // PL loads -> store next to shared
  for (auto loadOp : validLoads) {
    Value loadVal = nextMapping.lookup(loadOp);
    // then store regs -> shared
    Value storeBuf =
        pplForOp.getRegionIterArgs()[bufferIdx + nextBuffers.size()];
    auto cvt = builder.create<ttg::ConvertLayoutOp>(
        loadOp.getLoc(), storeBuf.getType(), loadVal);
    nextBuffers.push_back(cvt);
  }

  // Some values have not been used by any ops in the loop body
  // NOTE(review): `depArgsIdx[arg]` default-inserts index 0 for iter args
  // that are not dep args; setValueMappingYield only acts on args that feed
  // the yield, but confirm non-dep args can never take this path.
  for (BlockArgument arg : forOp.getRegionIterArgs())
    setValueMappingYield(arg, pplForOp.getRegionIterArgs()[depArgsIdx[arg]]);
}
/// Build the scf.yield of the pipelined loop, mirroring the iter-arg layout
/// from collectNewLoopArgs: remapped original yield operands, then the
/// freshly stored shared buffers, then the loop-carried dependency values.
void LoopPipeliner::finalizeYield(OpBuilder &builder) {
  // Local renamed from `yieldValues` to avoid shadowing the member.
  SmallVector<Value> yields;
  for (Value v : yieldOp->getOperands())
    yields.push_back(curMapping.lookup(v));
  for (Value nextBuffer : nextBuffers)
    yields.push_back(nextBuffer);
  for (size_t i = 0; i < depArgsMapping.size(); ++i) {
    auto arg = pplForOp.getRegionIterArgs()[depArgsBeginIdx + i];
    assert(depArgsMapping.count(arg) && "Missing loop-carried value");
    yields.push_back(depArgsMapping[arg]);
  }

  builder.setInsertionPointToEnd(pplForOp.getBody());
  builder.create<scf::YieldOp>(yieldOp->getLoc(), yields);
}
/// Assemble the pipelined loop in four phases: clone the loop skeleton,
/// prefetch the next tile, clone the compute body, then store the tile to
/// shared memory and wire up the yield.
scf::ForOp LoopPipeliner::createNewForOp() {
  OpBuilder builder(forOp);
  cloneForOp(collectNewLoopArgs(), builder);
  prefetchNextBuffer(builder);
  cloneCurrentBody(builder);
  storeNextBuffer(builder);
  finalizeYield(builder);
  return pplForOp;
}
// Stream Pipeline pass driver: applies the LoopPipeliner to every scf.for
// in the module, replacing each pipelinable loop with prologue + pipelined
// loop + peeled epilogue.
struct PipelinePass : public TritonGPUStreamPipelineBase<PipelinePass> {
  PipelinePass() = default;

  void runOnOperation() override {
    // Pre-processing
    // we make sure element-wise ops are done *after* the conversion
    // to dot operands
    // we can achieve this with simple recursive pattern matching
    // MLIRContext *context = &getContext();
    // mlir::RewritePatternSet patterns(context);
    // patterns.add<MoveOpAfterLayoutConversion>(context);
    // auto didPreprocess =
    //     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

    // Do the pipelining
    getOperation()->walk([&](scf::ForOp forOp) -> void {
      LoopPipeliner pipeliner(forOp);

      // Loops that cannot be pipelined are left untouched.
      if (pipeliner.initialize().failed())
        return;

      pipeliner.emitPrologue();
      scf::ForOp pplForOp = pipeliner.createNewForOp();
      // By default every original result maps to the corresponding result
      // of the pipelined loop; emitEpilogue overrides entries for values
      // recomputed in the peeled last iteration.
      DenseMap<Value, Value> newResults;
      for (unsigned i = 0; i < forOp->getNumResults(); ++i)
        newResults[forOp->getResult(i)] = pplForOp->getResult(i);
      pipeliner.emitEpilogue(newResults);

      // Replace the original loop
      for (const auto &pair : newResults)
        std::get<0>(pair).replaceAllUsesWith(std::get<1>(pair));
      // NOTE(review): erasing the op currently being visited relies on the
      // walk being post-order (the MLIR default); revisit if the walk order
      // ever changes.
      forOp->erase();
    });
  }
};
} // anonymous namespace
/// Factory for the stream-pipeline pass, referenced by the TableGen
/// `constructor` of TritonGPUStreamPipeline and the Python bindings.
std::unique_ptr<Pass> mlir::createTritonGPUStreamPipelinePass() {
  return std::make_unique<PipelinePass>();
}

View File

@@ -1698,6 +1698,10 @@ void init_triton_ir(py::module &&m) {
[](mlir::PassManager &self, int numStages) {
self.addPass(mlir::createTritonGPUPipelinePass(numStages));
})
.def("add_tritongpu_stream_pipeline_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUStreamPipelinePass());
})
.def("add_tritongpu_prefetch_pass",
[](mlir::PassManager &self) {
self.addPass(mlir::createTritonGPUPrefetchPass());

View File

@@ -90,12 +90,17 @@ def optimize_ttgir(mod, num_stages, arch):
pm.add_tritongpu_accelerate_matmul_pass(80)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
if num_stages == 0 and is_hip() and gpu_has_mfma():
pm.add_tritongpu_stream_pipeline_pass()
pm.add_canonicalizer_pass()
else:
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
pm.add_tritongpu_reorder_instructions_pass()
if num_stages != 0:
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
pm.run(mod)