Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108

Conflicts:
	bin/triton-translate.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	python/triton/compiler/compiler.py
	python/triton/runtime/jit.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
Jason Furmanek
2023-11-08 18:51:23 +00:00
72 changed files with 1623 additions and 838 deletions

View File

@@ -1,9 +1,9 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/FunctionImplementation.h"
#include "mlir/IR/FunctionInterfaces.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"

View File

@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)

View File

@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
result.operands))
return failure();
// Deduce operand_segment_sizes from the number of the operands.
// Deduce operandSegmentSizes from the number of the operands.
auto operandSegmentSizesAttrName =
OpT::getOperandSegmentSizesAttrName(result.name);
result.addAttribute(
@@ -1616,7 +1616,7 @@ template <class OpT>
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
printer << " ";
printer << insertSliceOp.getOperation()->getOperands();
// "operand_segment_sizes" can be deduced, so we don't print it.
// "operandSegmentSizes" can be deduced, so we don't print it.
printer.printOptionalAttrDict(
insertSliceOp->getAttrs(),
{insertSliceOp.getOperandSegmentSizesAttrName()});

View File

@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = bwdFilter;
getBackwardSlice(x, &slice, opt);
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
@@ -235,8 +238,11 @@ public:
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
opt.filter = isCvt;
getBackwardSlice(a, &aBwdSlices, opt);
getBackwardSlice(b, &bBwdSlices, opt);
// get the source of the first conversion found in slices
auto getCvtArgOrder = [](Operation *op) {
return cast<ConvertLayoutOp>(op)

View File

@@ -98,7 +98,9 @@ public:
// and all operations between the load and the conversion
// should be layout preserving
SetVector<Operation *> slice;
getBackwardSlice(op, &slice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op, &slice, opt);
int loadIdx = -1;
bool checkOp = false;
for (int i = 0; i < slice.size(); i++) {

View File

@@ -160,6 +160,8 @@ class LoopPipeliner {
void checkOpShareBarriers(SetVector<Operation *> &ops);
int numLoadsRequireAsyncWait = 0;
int numLoadsRequireMBarrier = 0;
// Number of buffers to allocate for each input.
int numSharedMemorySlices = 0;
/// Iterator values
Value nextIV;
@@ -280,9 +282,12 @@ class LoopPipeliner {
public:
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
bool mode, ConsumerReleaseMap &consumerReleaseMap)
bool mode, int numSharedMemorySlices,
ConsumerReleaseMap &consumerReleaseMap)
: forOp(forOp), numStages(numStages), numWarps(numWarps),
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
numCTAs(numCTAs), mode(mode),
numSharedMemorySlices(numSharedMemorySlices),
consumerReleaseMap(consumerReleaseMap) {
// cache yieldOp
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
}
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
Attribute sharedEnc;
if (auto dotOpEnc = cvt.getType()
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
pipelineIterIdx = builder.create<arith::AddIOp>(
iv.getLoc(), pipelineIterIdx,
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
Value numSlices = builder.create<arith::ConstantIntOp>(
iv.getLoc(), numSharedMemorySlices, 32);
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
numSlices, pipelineIterIdx, _0);
// Some values have not been used by any ops in the loop body
for (BlockArgument arg : forOp.getRegionIterArgs())
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
Value numStagesVal =
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
Value numSlices =
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
// nextWaitIdx
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
Value nextWaitIdx = getBoundedIterationValue(
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
numSlices, waitIdxPlusOne, _0);
// Indices of InsertSliceAsyncOp and ExtractSliceOp
Value insertSliceIndex = pipelineIterIdx;
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
// Bump pipelineIterIdx
Value pipelineIterIdxPlusOne =
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
pipelineIterIdx =
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
pipelineIterIdxPlusOne, _0);
pipelineIterIdx = getBoundedIterationValue(
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
// Bump curWaitIdx
curWaitIdx = nextWaitIdx;
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
llvm::SmallVector<scf::ForOp> newForOps;
// Currently we schedule stage 0 after stage `numStages - 1` during
// pipelining, therefore we only need `numStages - 1` slices of memory.
// On Hopper we have a separate post-processing that pipelines wgmma so we
// need an extra buffer for each input.
// Note that an alternative would be to keep allocating `numStages` buffers
// and remove the barrier between the loads from shared memory and the
// copies from global to shared. This would require improving existing
// membar analysis.
int numSharedMemorySlices =
computeCapability < 90 ? numStages - 1 : numStages;
// Do the pipelining
getOperation()->walk([&](scf::ForOp forOp) -> void {
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
this->numCTAs, mode, consumerReleaseMap);
this->numCTAs, mode, numSharedMemorySlices,
consumerReleaseMap);
if (pipeliner.initialize().failed())
return;
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
/// dots to be pipelined
SetVector<Value> dots;
SmallVector<tt::DotOp> dots;
SmallVector<unsigned> resultNeedSync;
for (Operation &op : *loop) {
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
if (!CArg || !CArg.hasOneUse())
valid = false;
if (valid)
dots.insert(dotOp);
if (valid) {
dots.push_back(dotOp);
resultNeedSync.push_back(
dotOp->getUses().begin()->getOperandNumber());
}
}
}
}
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
return;
OpBuilder builder(forOp);
// 0. insert dot_wait after the last dot in the loop
Value dot = dots.back();
auto loc = dot.getLoc();
builder.setInsertionPointAfter(dot.getDefiningOp());
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
// wgmma ops by one stage.
// This is needed to prevent shared memory inputs from being overridden before
// the operation is completed.
// TODO: merge this with the rest of the pipelining transformation and look at
// a better representation for async dots.
tt::DotOp lastDot = dots.back();
builder.setInsertionPointAfter(lastDot);
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
lastDot.getLoc(), lastDot.getResult(), dots.size());
// 1. replace Dot with DotAsync
for (size_t idx = 0; idx < dots.size(); ++idx) {
Value dot = dots[idx];
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
builder.setInsertionPoint(dot.getDefiningOp());
tt::DotOp dotOp = dots[idx];
builder.setInsertionPoint(dotOp);
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
dot.getDefiningOp()->erase();
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dotOp.replaceAllUsesWith(dotAsync.getResult());
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
dotOp->erase();
}
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
// TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for
// a bug in ptxas which mistakenly analyzes the control flow and turns the GMMA
// into a synchronous implementation for safety.
// Remove this If once the bug is fixed.
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
/*hasElse*/ false);
builder.setInsertionPointToStart(ifOp.thenBlock());
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
for (unsigned resultIndex : resultNeedSync) {
Value result = forOp->getResult(resultIndex);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
}
}
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,

View File

@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
//
// -----------------------------------------------------------------------------
<<<<<<< HEAD
// convert(blocked, dot_operand) ->
// convert(blocked, mma) + convert(mma, dot_operand)
// if this value is itself the result of a dot operation
@@ -102,6 +103,9 @@ public:
};
//
=======
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
class ConvertDotConvert : public mlir::RewritePattern {
public:
ConvertDotConvert(mlir::MLIRContext *context)
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
if (convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.isa<triton::gpu::MmaEncodingAttr>())
return true;
Attribute dstEncoding = convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding();
if (auto mmaLayout =
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return encoding.cast<triton::gpu::MmaEncodingAttr>()
.getVersionMajor() > 1;
}
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
return rewrittenValue;
OpBuilder rewriter(value.getContext());
rewriter.setInsertionPointAfterValue(rewrittenValue);
// Workaround: The pipeliner will insert async.wait after a pipelined loop
// to ensure that there are no pending copies and it is safe to re-use shared
// memory. We shouldn't insert ops that may use shared memory in between the
// loop and the async.wait. This is a hack until we fix the IR
// representation of async wait.
if (Operation *op = rewrittenValue.getDefiningOp()) {
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
rewriter.setInsertionPointAfter(op->getNextNode());
}
auto tmpType = RankedTensorType::get(tensorType.getShape(),
tensorType.getElementType(), encoding);
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1122,7 +1140,6 @@ public:
hoistConvert(m);
mlir::RewritePatternSet decomposePatterns(context);
decomposePatterns.add<DecomposeDotOperand>(context);
decomposePatterns.add<ConvertDotConvert>(context);
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
.failed()) {

View File

@@ -91,7 +91,7 @@ private:
// support ForOp only
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
// prologue
auto iterOperands = forOp.getIterOperands();
auto iterOperands = forOp.getInitArgs();
if (argNum == 0)
return false;
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))

View File

@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
arith::UIToFPOp, arith::XOrIOp>(op))
return true;
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,

View File

@@ -220,7 +220,9 @@ public:
SetVector<Operation *> backwardSlice;
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
assert(isa<triton::FuncOp>(op->getParentOp()));
getBackwardSlice(op.getOperation(), &backwardSlice);
mlir::BackwardSliceOptions opt;
opt.omitBlockArguments = true;
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
op->removeAttr("async_agent");
});
for (auto op : backwardSlice) {

View File

@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
builder.setInsertionPoint(agentIdOp);
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
int globalNumWarps = 0;
SmallVector<Operation *> deprecatedOps;
for (auto cmpOp : agentIdOp->getUsers()) {
assert(isa<arith::CmpIOp>(cmpOp));
for (auto u : cmpOp->getUsers()) {
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
Value cond =
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
cmpOp->getResult(0).replaceAllUsesWith(cond);
cmpOp->erase();
deprecatedOps.push_back(cmpOp);
break;
}
}
}
for (Operation *cmpOp : deprecatedOps) {
cmpOp->erase();
}
});
}
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
}
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
bool skipFirstWait) {
bool emptyBarrier) {
// TODO: currently we only support one loop, no nested loop, while or
// condition.
auto loc = op->getLoc();
auto forOp = op->getParentOfType<scf::ForOp>();
if (!forOp) {
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
}
auto defOp = op->getOperand(0).getDefiningOp();
assert(isa<ttng::CreateTokenOp>(defOp) &&
"mbarrier's definingOp is not createTokenOp");
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
Value numStage =
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
Value curStep = forOp.getBody()->getArguments().back();
if (curStep.getType() == builder.getIndexType()) {
curStep =
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
// for (..., phase, pipelineIdx)
unsigned numArgs = forOp.getBody()->getNumArguments();
assert(numArgs > 2 && "Unexpected number of arguments");
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
if (emptyBarrier) {
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
}
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
if (skipFirstWait) {
// If skipFirstWait, it waits for phaseBit 1
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
}
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
// TODO: May use alternative methods of phaseBit calculation to avoid high
// overhead of RemOp
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
_0);
return curPhase;
}
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
auto loc = op.getLoc();
// The first producer_acquire should be met immediately, so initially the
// producer skips the first wait
Value phase = getMBarrierPhaseBit(builder, op, 1);
Value phase = getMBarrierPhaseBit(builder, op, true);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
Value bufferFull) {
auto loc = op.getLoc();
Value phase = getMBarrierPhaseBit(builder, op, 0);
Value phase = getMBarrierPhaseBit(builder, op, false);
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
assert(op.getOperation()->hasAttr("async_agent"));
setAgentIds(waitOp, getAgentIds(op.getOperation()));
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
// Process mutex users
int numUsers = 0;
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
numUsers++;
assert(numUsers <= 2);
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
Value barLeave = builder.create<arith::SelectOp>(
loc, isRole0, namedBarrierId1, namedBarrierId0);
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
} else
} else {
llvm_unreachable("Unexpected user of mutex");
}
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}
nameBarrierId -= 2;
nameBarrierIdEnd -= 2;
createMutexOp.erase();
});
parentOp->walk(
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
}
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
OpBuilder builder(createMutexOp);
// Process mutex users
SmallVector<Operation *> deprecatedOps;
for (Operation *user : createMutexOp.getResult().getUsers()) {
auto loc = user->getLoc();
builder.setInsertionPoint(user);
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
processUnlockOp(builder, op);
else
llvm_unreachable("Unexpected user of mutex");
deprecatedOps.push_back(user);
}
for (Operation *user : deprecatedOps) {
user->erase();
}

View File

@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
persistentForOp.getInitArgsMutable()
.slice(persistentForOp.getInitArgs().size() - 1, 1)
.assign(newIdx);
auto yield =
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
auto idxPlusOneOp =
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
assert(isa<arith::AddIOp>(idxPlusOneOp));
assert(idxPlusOneOp->getOperand(0) ==
persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1));
pipelineIdx = persistentForOp.getBody()->getArgument(
persistentForOp.getBody()->getNumArguments() - 1);
Operation *idxPlusOneOp = nullptr;
for (OpOperand &v : pipelineIdx.getUses()) {
if (isa<arith::AddIOp>(v.getOwner())) {
idxPlusOneOp = v.getOwner();
break;
}
}
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
Operation *use = *idxPlusOneOp->getUsers().begin();
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
isa<arith::CmpIOp>(use));
idxPlusOneOp->setOperand(1, numRolesValue);
// Add operations at the start of persistentForOp
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
unlockLocs[i] = op;
}
// Update unlockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * %9 = arith.subi %0#1, c1
// * triton_nvidia_gpu.consumer_release %9
// * =======================================================================
// after async launch dots, there will be an outstanding consumerReleaseOp
// after the ForOp. We should extend the unlockLocs from the ForOp to the
// outstanding consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *unlockOp = unlockLocs[i];
auto filter = [&](Operation *op) {
return op->getBlock() == unlockOp->getBlock();
};
if (isa<scf::ForOp>(unlockOp)) {
SetVector<Operation *> slices;
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
unlockLocs[i] = *iter;
}
}
}
// Only cases where all lock/unlock locations are in same level make sense.
for (int i = 1; i < numRoles; ++i) {
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
else
lockLocs[i] = unlockLocs[prevTypeIds[i]];
}
// Update lockLocs
// ====================== IR after async launch dots ======================
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
// %3) {
// * triton_nvidia_gpu.producer_wait arg2
// * %5 = triton_nvidia_gpu.dot_async %4, %5
// * triton_nvidia_gpu.dot_wait {pendings = 1}
// * %6 = arith.cmpi sgt, arg0, %c0
// * scf.if %6 {
// * %7 = arith.subi arg2, c1
// * triton_nvidia_gpu.consumer_release %7
// * }
// * %8 = arith.addi arg2, c1
// * scf.yield %5, %8
// * }
// * triton_nvidia_gpu.dot_wait {pendings = 0}
// * ...
// * triton_nvidia_gpu.consumer_release ..
// * =======================================================================
// after async launch dots, there will be outstanding consumerReleaseOp after
// ForOp. we should set the epilogue lockLocs after the outstanding
// consumerReleaseOp.
for (int i = 0; i < numRoles; ++i) {
Operation *lockOp = lockLocs[i];
if (isa<scf::ForOp>(lockOp)) {
Operation *loc = nullptr;
unsigned numOutstandingConsumerRelease = 0;
for (auto v : lockOp->getResults()) {
SetVector<Operation *> slices;
mlir::getForwardSlice(v, &slices);
auto iter = llvm::find_if(slices, [](Operation *op) {
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
});
if (iter != slices.end()) {
numOutstandingConsumerRelease++;
loc = *iter;
}
}
assert(numOutstandingConsumerRelease <= 1 &&
"should have only one outstanding "
"consumerReleaseOp after "
"async launch dots");
if (loc)
lockLocs[i] = loc;
}
}
// lock
for (int i = 0; i < numRoles; ++i) {
builder.setInsertionPointAfter(lockLocs[i]);

View File

@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
}
//===----------------------------------------------------------------------===//
// appendPipelineIdxToLoopArgs
// createNewLoops
//===----------------------------------------------------------------------===//
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
// for(...) -> for(..., pipelineIdx)
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
return newForOp;
}
// for(...) -> for(..., phase, pipelineIdx)
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
scf::ForOp &parentForOp) {
auto loc = forOp.getLoc();
Block *body = forOp.getBody();
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
// the for loop
OpBuilderWithAgentIds builder(forOp.getContext());
builder.setAgentIdsFromArray(collectAgentIds(forOp));
builder.setInsertionPoint(forOp);
Value numStagesVal =
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
// 0. Append pipelineIdx to block arguments
Value phase =
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
Value pipelineIdx =
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
// 1. prepare index and phase for next iteration
// nextIdx = curIdx + 1
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
// curPhase^1))
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
builder.setInsertionPoint(yieldOp);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
// generate index for next iter
Value nextPipelineIdx =
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, nextPipelineIdx, numStagesVal);
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
// generate phase for next iter
Value flipPhase =
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineGECond, flipPhase);
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
loc, pipelineLTCond, phase);
Value nextPhase =
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
// 2. Append pipelineIdx to yield operands
yieldOp->insertOperands(yieldOp.getNumOperands(),
{nextPhase, nextPipelineIdx});
// 3. create newLoopArgs
SmallVector<Value> newLoopArgs;
for (auto operand : forOp.getInitArgs())
newLoopArgs.push_back(operand);
builder.setInsertionPoint(forOp);
Value initPipelineIdx, initEmptyIdx, initPhase;
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
if (parentForOp) {
// Make sure prior pipelineIdx is inserted in the end of parentForOp
initPipelineIdx = parentForOp.getBody()->getArguments().back();
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
loc, forOp.getUpperBound(), forOp.getLowerBound());
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
forOp.getStep());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
forOp.getStep());
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
loc, initPipelineIdx, numSteps);
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
loc, initPipelineIdx, numStagesVal);
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, initPipelineIdx,
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
numStagesVal));
pipelineIdx =
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
loc, builder.getI1Type(), pipelineIdx);
} else {
// phase init to false and pipelineIdx init to 0
initPipelineIdx = zero;
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
}
newLoopArgs.append({initPhase, initPipelineIdx});
// 4. Create newForOp and take the region of forOp
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
newLoopArgs);
newForOp.getRegion().takeBody(forOp.getRegion());
// 5. Replace forOp with newForOp
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
forOp.erase();
return newForOp;
}
//===----------------------------------------------------------------------===//
// appendPipelineIdxArgs
//===----------------------------------------------------------------------===//
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
for (auto &op : orderedForOps) {
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
scf::ForOp newForOp;
bool hasDotOp = false;
for (Operation &subOp : *op.getBody()) {
if (isa<triton::DotOp>(&subOp)) {
hasDotOp = true;
break;
}
}
if (hasDotOp) {
// for(...) -> for(..., phase, pipelineIdx)
newForOp = createNewMathLoop(op, numStages, parentForOp);
} else {
// for(...) -> for(..., pipelineIdx)
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
}
auto backboneForItr =
std::find(backbone.begin(), backbone.end(), op.getOperation());
if (backboneForItr != backbone.end()) {
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
}
builder.setAgentIdsFromArray(agentsPC);
Value pipelineIdx;
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
headProducer->getLoc(), numStages, 32);
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = forOp.getBody()->getArguments().back();
} else {
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
// insert ProducerAcquireOp
builder.setInsertionPoint(headProducer);
if (headProducer->getParentOfType<scf::ForOp>()) {
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
headProducer->getLoc(), pipelineIdx, numStagesVal);
}
builder.setAgentIdsFromArray(agentP);
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
token, pipelineIdx);
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
dot.replaceAllUsesWith(dotAsync.getResult());
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
loc, dotAsync.getResult(), 1);
// 1. insert ConsumerReleaseOp for DotAsyncOps
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
auto ifOp =
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
/*hasElse*/ false);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
builder.setInsertionPointToStart(ifOp.thenBlock());
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
auto oriIdx = forOp.getBody()->getArguments().back();
Value consumerReleaseIdx =
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
Value consumerReleaseIdxMinusOne =
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
0);
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
Value result = forOp->getResult(resultIndex);
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
headConsumer->getLoc(), 1, 32);
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
loc, numStages - 1, 32);
consumerReleaseIdx = forOp.getResults().back();
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one_);
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
loc, consumerReleaseIdx, numStagesVal);
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
loc, consumerReleaseIdx, one);
cond = builder.createWithAgentIds<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
loc, cond, lastStage, consumerReleaseIdxMinusOne);
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
consumerReleaseIdx);
dotOp->erase();