mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Merge commit '721897fcc4f942aa97d2e9ba3787a5e213758177' into ifu-231108
Conflicts: bin/triton-translate.cpp lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp python/triton/compiler/compiler.py python/triton/runtime/jit.py python/tutorials/06-fused-attention.py test/Conversion/tritongpu_to_llvm.mlir
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinAttributes.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/FunctionImplementation.h"
|
||||
#include "mlir/IR/FunctionInterfaces.h"
|
||||
#include "mlir/IR/OperationSupport.h"
|
||||
#include "mlir/Interfaces/FunctionImplementation.h"
|
||||
#include "mlir/Interfaces/FunctionInterfaces.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Types.h"
|
||||
|
||||
|
||||
@@ -8,6 +8,6 @@ add_mlir_dialect_library(TritonGPUIR
|
||||
TritonGPUAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRGPUOps
|
||||
MLIRGPUDialect
|
||||
TritonIR
|
||||
)
|
||||
|
||||
@@ -1603,7 +1603,7 @@ ParseResult parseInsertSliceOp(OpAsmParser &parser, OperationState &result) {
|
||||
result.operands))
|
||||
return failure();
|
||||
|
||||
// Deduce operand_segment_sizes from the number of the operands.
|
||||
// Deduce operandSegmentSizes from the number of the operands.
|
||||
auto operandSegmentSizesAttrName =
|
||||
OpT::getOperandSegmentSizesAttrName(result.name);
|
||||
result.addAttribute(
|
||||
@@ -1616,7 +1616,7 @@ template <class OpT>
|
||||
void printInsertSliceOp(OpAsmPrinter &printer, OpT insertSliceOp) {
|
||||
printer << " ";
|
||||
printer << insertSliceOp.getOperation()->getOperands();
|
||||
// "operand_segment_sizes" can be deduced, so we don't print it.
|
||||
// "operandSegmentSizes" can be deduced, so we don't print it.
|
||||
printer.printOptionalAttrDict(
|
||||
insertSliceOp->getAttrs(),
|
||||
{insertSliceOp.getOperandSegmentSizesAttrName()});
|
||||
|
||||
@@ -139,7 +139,10 @@ class BlockedToMMA : public mlir::RewritePattern {
|
||||
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
|
||||
int origBitWidth = finalBitWidth;
|
||||
SetVector<Operation *> slice;
|
||||
mlir::getBackwardSlice(x, &slice, bwdFilter);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = bwdFilter;
|
||||
getBackwardSlice(x, &slice, opt);
|
||||
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
|
||||
if (firstOp)
|
||||
if (Value arg = firstOp->getOperand(0))
|
||||
@@ -235,8 +238,11 @@ public:
|
||||
if (versionMajor == 1) {
|
||||
SetVector<Operation *> aBwdSlices, bBwdSlices;
|
||||
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
|
||||
getBackwardSlice(a, &aBwdSlices, {isCvt});
|
||||
getBackwardSlice(b, &bBwdSlices, {isCvt});
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
opt.filter = isCvt;
|
||||
getBackwardSlice(a, &aBwdSlices, opt);
|
||||
getBackwardSlice(b, &bBwdSlices, opt);
|
||||
// get the source of the first conversion found in slices
|
||||
auto getCvtArgOrder = [](Operation *op) {
|
||||
return cast<ConvertLayoutOp>(op)
|
||||
|
||||
@@ -98,7 +98,9 @@ public:
|
||||
// and all operations between the load and the conversion
|
||||
// should be layout preserving
|
||||
SetVector<Operation *> slice;
|
||||
getBackwardSlice(op, &slice);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
getBackwardSlice(op, &slice, opt);
|
||||
int loadIdx = -1;
|
||||
bool checkOp = false;
|
||||
for (int i = 0; i < slice.size(); i++) {
|
||||
|
||||
@@ -160,6 +160,8 @@ class LoopPipeliner {
|
||||
void checkOpShareBarriers(SetVector<Operation *> &ops);
|
||||
int numLoadsRequireAsyncWait = 0;
|
||||
int numLoadsRequireMBarrier = 0;
|
||||
// Number of buffers to allocate for each input.
|
||||
int numSharedMemorySlices = 0;
|
||||
|
||||
/// Iterator values
|
||||
Value nextIV;
|
||||
@@ -280,9 +282,12 @@ class LoopPipeliner {
|
||||
|
||||
public:
|
||||
LoopPipeliner(scf::ForOp forOp, int numStages, int numWarps, int numCTAs,
|
||||
bool mode, ConsumerReleaseMap &consumerReleaseMap)
|
||||
bool mode, int numSharedMemorySlices,
|
||||
ConsumerReleaseMap &consumerReleaseMap)
|
||||
: forOp(forOp), numStages(numStages), numWarps(numWarps),
|
||||
numCTAs(numCTAs), mode(mode), consumerReleaseMap(consumerReleaseMap) {
|
||||
numCTAs(numCTAs), mode(mode),
|
||||
numSharedMemorySlices(numSharedMemorySlices),
|
||||
consumerReleaseMap(consumerReleaseMap) {
|
||||
// cache yieldOp
|
||||
yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
|
||||
}
|
||||
@@ -644,7 +649,7 @@ void LoopPipeliner::createBufferTypes() {
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
bufferShape.insert(bufferShape.begin(), numSharedMemorySlices);
|
||||
auto CTALayout = ttg::getCTALayout(ty.getEncoding());
|
||||
Attribute sharedEnc;
|
||||
if (auto dotOpEnc = cvt.getType()
|
||||
@@ -946,6 +951,11 @@ void LoopPipeliner::emitPrologue() {
|
||||
pipelineIterIdx = builder.create<arith::AddIOp>(
|
||||
iv.getLoc(), pipelineIterIdx,
|
||||
builder.create<arith::ConstantIntOp>(iv.getLoc(), 1, 32));
|
||||
Value numSlices = builder.create<arith::ConstantIntOp>(
|
||||
iv.getLoc(), numSharedMemorySlices, 32);
|
||||
Value _0 = builder.create<arith::ConstantIntOp>(iv.getLoc(), 0, 32);
|
||||
pipelineIterIdx = getBoundedIterationValue(builder, pipelineIterIdx,
|
||||
numSlices, pipelineIterIdx, _0);
|
||||
// Some values have not been used by any ops in the loop body
|
||||
for (BlockArgument arg : forOp.getRegionIterArgs())
|
||||
setValueMappingYield(arg, valueMapping[arg][stage], stage + 1);
|
||||
@@ -1220,11 +1230,13 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(idxLoc, 1, 32);
|
||||
Value numStagesVal =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numStages, 32);
|
||||
Value numSlices =
|
||||
builder.create<arith::ConstantIntOp>(idxLoc, numSharedMemorySlices, 32);
|
||||
|
||||
// nextWaitIdx
|
||||
Value waitIdxPlusOne = builder.create<arith::AddIOp>(idxLoc, curWaitIdx, _1);
|
||||
Value nextWaitIdx = getBoundedIterationValue(
|
||||
builder, waitIdxPlusOne, numStagesVal, waitIdxPlusOne, _0);
|
||||
Value nextWaitIdx = getBoundedIterationValue(builder, waitIdxPlusOne,
|
||||
numSlices, waitIdxPlusOne, _0);
|
||||
|
||||
// Indices of InsertSliceAsyncOp and ExtractSliceOp
|
||||
Value insertSliceIndex = pipelineIterIdx;
|
||||
@@ -1417,9 +1429,8 @@ void LoopPipeliner::prefetchNextIteration(scf::ForOp newForOp,
|
||||
// Bump pipelineIterIdx
|
||||
Value pipelineIterIdxPlusOne =
|
||||
builder.create<arith::AddIOp>(idxLoc, pipelineIterIdx, _1);
|
||||
pipelineIterIdx =
|
||||
getBoundedIterationValue(builder, pipelineIterIdxPlusOne, numStagesVal,
|
||||
pipelineIterIdxPlusOne, _0);
|
||||
pipelineIterIdx = getBoundedIterationValue(
|
||||
builder, pipelineIterIdxPlusOne, numSlices, pipelineIterIdxPlusOne, _0);
|
||||
|
||||
// Bump curWaitIdx
|
||||
curWaitIdx = nextWaitIdx;
|
||||
@@ -1516,10 +1527,23 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
|
||||
// applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
|
||||
|
||||
llvm::SmallVector<scf::ForOp> newForOps;
|
||||
|
||||
// Currently we schedule stage 0 after stage `numStages - 1` during
|
||||
// pipelining therefore we only need `numStages - 1` slice of memory.
|
||||
// On Hopper we have a separate post-processing that pipelines wgmma so we
|
||||
// need an extra buffer for each input.
|
||||
// Note that an alternative would be to keep allocating `numStages` buffers
|
||||
// and remove the barrier between the loads from shared memory and the
|
||||
// copies from global to shared. This would require improving existing
|
||||
// membar analysis.
|
||||
int numSharedMemorySlices =
|
||||
computeCapability < 90 ? numStages - 1 : numStages;
|
||||
|
||||
// Do the pipelining
|
||||
getOperation()->walk([&](scf::ForOp forOp) -> void {
|
||||
LoopPipeliner pipeliner(forOp, this->numStages, this->numWarps,
|
||||
this->numCTAs, mode, consumerReleaseMap);
|
||||
this->numCTAs, mode, numSharedMemorySlices,
|
||||
consumerReleaseMap);
|
||||
if (pipeliner.initialize().failed())
|
||||
return;
|
||||
|
||||
@@ -1593,7 +1617,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
|
||||
/// XXX(Keren): Clean up the following duplicate code with checkDotOp
|
||||
/// dots to be pipelined
|
||||
SetVector<Value> dots;
|
||||
SmallVector<tt::DotOp> dots;
|
||||
SmallVector<unsigned> resultNeedSync;
|
||||
for (Operation &op : *loop) {
|
||||
if (auto dotOp = dyn_cast<tt::DotOp>(&op)) {
|
||||
auto resTy = dotOp.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
@@ -1615,8 +1640,11 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
if (!CArg || !CArg.hasOneUse())
|
||||
valid = false;
|
||||
|
||||
if (valid)
|
||||
dots.insert(dotOp);
|
||||
if (valid) {
|
||||
dots.push_back(dotOp);
|
||||
resultNeedSync.push_back(
|
||||
dotOp->getUses().begin()->getOperandNumber());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1627,39 +1655,39 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
return;
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
|
||||
// 0. insert dot_wait after the last dot in the loop
|
||||
Value dot = dots.back();
|
||||
auto loc = dot.getLoc();
|
||||
builder.setInsertionPointAfter(dot.getDefiningOp());
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(loc, dots.size());
|
||||
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
|
||||
// wgmma ops by one stage.
|
||||
// This is needed to prevent shared memory inputs to be overriden before the
|
||||
// operation is completed.
|
||||
// TODO: merge this with the rest of the pipelining transformation and look at
|
||||
// a better representation for async dots.
|
||||
tt::DotOp lastDot = dots.back();
|
||||
builder.setInsertionPointAfter(lastDot);
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
|
||||
lastDot.getLoc(), lastDot.getResult(), dots.size());
|
||||
|
||||
// 1. replace Dot with DotAsync
|
||||
for (size_t idx = 0; idx < dots.size(); ++idx) {
|
||||
Value dot = dots[idx];
|
||||
auto dotOp = cast<tt::DotOp>(dot.getDefiningOp());
|
||||
builder.setInsertionPoint(dot.getDefiningOp());
|
||||
tt::DotOp dotOp = dots[idx];
|
||||
builder.setInsertionPoint(dotOp);
|
||||
auto dotAsync = builder.create<tt::nvidia_gpu::DotAsyncOp>(
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(),
|
||||
dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1);
|
||||
dot.getDefiningOp()->erase();
|
||||
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dotOp.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
|
||||
dotOp->erase();
|
||||
}
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
Value loopNotEmpty = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
|
||||
forOp.getUpperBound());
|
||||
// TODO[goostavz]: it's a workaround to put the DotWaitOp in an IfOp for
|
||||
// a bug in ptxas which mistakenly analysis the control flow and turn the GMMA
|
||||
// into synchronuous implementation for safety.
|
||||
// Remove this If once the bug is fixed.
|
||||
auto ifOp = builder.create<scf::IfOp>(loc, ArrayRef<Type>{}, loopNotEmpty,
|
||||
/*hasElse*/ false);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), 0);
|
||||
for (unsigned resultIndex : resultNeedSync) {
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
auto dotWait =
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
}
|
||||
}
|
||||
|
||||
Value PipelinePass::getRemoteCTAId(OpBuilder &b, Location loc,
|
||||
|
||||
@@ -31,6 +31,7 @@ using triton::gpu::SliceEncodingAttr;
|
||||
//
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
<<<<<<< HEAD
|
||||
// convert(blocked, dot_operand) ->
|
||||
// convert(blocked, mma) + convert(mma, dot_operand)
|
||||
// if this value is itself the result of a dot operation
|
||||
@@ -102,6 +103,9 @@ public:
|
||||
};
|
||||
|
||||
//
|
||||
=======
|
||||
// dot(a, b, load(ptr)) -> add(load(ptr), dot(a, b, 0))
|
||||
>>>>>>> 721897fcc4f942aa97d2e9ba3787a5e213758177
|
||||
class ConvertDotConvert : public mlir::RewritePattern {
|
||||
public:
|
||||
ConvertDotConvert(mlir::MLIRContext *context)
|
||||
@@ -233,12 +237,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
|
||||
getForwardSlice(currentValue, &forwardSlice);
|
||||
for (Operation *op : forwardSlice) {
|
||||
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
if (convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.isa<triton::gpu::MmaEncodingAttr>())
|
||||
return true;
|
||||
Attribute dstEncoding = convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding();
|
||||
if (auto mmaLayout =
|
||||
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
|
||||
return (mmaLayout.getVersionMajor() > 1) ? true
|
||||
: mmaLayout == encoding;
|
||||
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return encoding.cast<triton::gpu::MmaEncodingAttr>()
|
||||
.getVersionMajor() > 1;
|
||||
}
|
||||
auto yield = dyn_cast<scf::YieldOp>(op);
|
||||
if (!yield)
|
||||
@@ -560,6 +569,15 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) {
|
||||
return rewrittenValue;
|
||||
OpBuilder rewriter(value.getContext());
|
||||
rewriter.setInsertionPointAfterValue(rewrittenValue);
|
||||
// Workaround: The pipeliner will insert async.wait after a pipelined loop
|
||||
// to ensure that there is no pending copies and it is safe to re-use shared
|
||||
// memory. We shouldn't insert ops that may use shared memory in between the
|
||||
// loop and the async.wait. This is a hack until we fix the IR
|
||||
// representation of async wait.
|
||||
if (Operation *op = rewrittenValue.getDefiningOp()) {
|
||||
if (isa<triton::gpu::AsyncWaitOp>(op->getNextNode()))
|
||||
rewriter.setInsertionPointAfter(op->getNextNode());
|
||||
}
|
||||
auto tmpType = RankedTensorType::get(tensorType.getShape(),
|
||||
tensorType.getElementType(), encoding);
|
||||
Value converted = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -1122,7 +1140,6 @@ public:
|
||||
hoistConvert(m);
|
||||
|
||||
mlir::RewritePatternSet decomposePatterns(context);
|
||||
decomposePatterns.add<DecomposeDotOperand>(context);
|
||||
decomposePatterns.add<ConvertDotConvert>(context);
|
||||
if (mlir::applyPatternsAndFoldGreedily(m, std::move(decomposePatterns))
|
||||
.failed()) {
|
||||
|
||||
@@ -91,7 +91,7 @@ private:
|
||||
// suport ForOp only
|
||||
if (auto forOp = dyn_cast<scf::ForOp>(argOwner)) {
|
||||
// prologue
|
||||
auto iterOperands = forOp.getIterOperands();
|
||||
auto iterOperands = forOp.getInitArgs();
|
||||
if (argNum == 0)
|
||||
return false;
|
||||
if (dependOnSharedEncOperand(iterOperands[argNum - 1]))
|
||||
|
||||
@@ -628,12 +628,13 @@ bool CTAPlanner::isElementwiseOp(Operation *op) const {
|
||||
arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
|
||||
arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
|
||||
arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
|
||||
arith::MaxFOp, arith::MaxSIOp, arith::MaxUIOp, arith::MinFOp,
|
||||
arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
|
||||
arith::NegFOp, arith::OrIOp, arith::RemFOp, arith::RemSIOp,
|
||||
arith::RemUIOp, arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp,
|
||||
arith::SIToFPOp, arith::SubFOp, arith::SubIOp, arith::TruncFOp,
|
||||
arith::TruncIOp, arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
arith::MaximumFOp, arith::MaxSIOp, arith::MaxUIOp,
|
||||
arith::MinimumFOp, arith::MinSIOp, arith::MinUIOp,
|
||||
arith::MulFOp, arith::MulIOp, arith::NegFOp, arith::OrIOp,
|
||||
arith::RemFOp, arith::RemSIOp, arith::RemUIOp, arith::ShLIOp,
|
||||
arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp, arith::SubFOp,
|
||||
arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
|
||||
arith::UIToFPOp, arith::XOrIOp>(op))
|
||||
return true;
|
||||
if (llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
|
||||
math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
|
||||
|
||||
@@ -220,7 +220,9 @@ public:
|
||||
SetVector<Operation *> backwardSlice;
|
||||
mod.walk([&](triton::MakeTensorPtrOp op) -> void {
|
||||
assert(isa<triton::FuncOp>(op->getParentOp()));
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice);
|
||||
mlir::BackwardSliceOptions opt;
|
||||
opt.omitBlockArguments = true;
|
||||
getBackwardSlice(op.getOperation(), &backwardSlice, opt);
|
||||
op->removeAttr("async_agent");
|
||||
});
|
||||
for (auto op : backwardSlice) {
|
||||
|
||||
@@ -79,6 +79,7 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
builder.setInsertionPoint(agentIdOp);
|
||||
Value globalRoleId = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
int globalNumWarps = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (auto cmpOp : agentIdOp->getUsers()) {
|
||||
assert(isa<arith::CmpIOp>(cmpOp));
|
||||
for (auto u : cmpOp->getUsers()) {
|
||||
@@ -111,11 +112,14 @@ void materializeGetAgentIdOp(Operation *parentOp) {
|
||||
Value cond =
|
||||
builder.create<arith::AndIOp>(loc, lowerBound, upperBound);
|
||||
cmpOp->getResult(0).replaceAllUsesWith(cond);
|
||||
cmpOp->erase();
|
||||
deprecatedOps.push_back(cmpOp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
for (Operation *cmpOp : deprecatedOps) {
|
||||
cmpOp->erase();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -145,39 +149,24 @@ LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) {
|
||||
}
|
||||
|
||||
Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op,
|
||||
bool skipFirstWait) {
|
||||
bool emptyBarrier) {
|
||||
// TODO: currently we only support one loop, no nested loop, while or
|
||||
// condition.
|
||||
auto loc = op->getLoc();
|
||||
auto forOp = op->getParentOfType<scf::ForOp>();
|
||||
if (!forOp) {
|
||||
return builder.create<arith::ConstantIntOp>(loc, skipFirstWait, 1);
|
||||
return builder.create<arith::ConstantIntOp>(loc, emptyBarrier, 1);
|
||||
}
|
||||
|
||||
auto defOp = op->getOperand(0).getDefiningOp();
|
||||
assert(isa<ttng::CreateTokenOp>(defOp) &&
|
||||
"mbarrier's definingOp is not createTokenOp");
|
||||
ttng::CreateTokenOp createTokenOp = dyn_cast<ttng::CreateTokenOp>(defOp);
|
||||
Value numStage =
|
||||
builder.create<arith::ConstantIntOp>(loc, createTokenOp.getNum(), 32);
|
||||
Value curStep = forOp.getBody()->getArguments().back();
|
||||
if (curStep.getType() == builder.getIndexType()) {
|
||||
curStep =
|
||||
builder.create<arith::IndexCastOp>(loc, numStage.getType(), curStep);
|
||||
// for (..., phase, pipelineIdx)
|
||||
unsigned numArgs = forOp.getBody()->getNumArguments();
|
||||
assert(numArgs > 2 && "Unexpected number of arguments");
|
||||
Value curPhase = forOp.getBody()->getArgument(numArgs - 2);
|
||||
if (emptyBarrier) {
|
||||
Value _1_1b = builder.create<arith::ConstantIntOp>(loc, 1, 1);
|
||||
curPhase = builder.create<mlir::arith::XOrIOp>(loc, curPhase, _1_1b);
|
||||
}
|
||||
Value curPhase = builder.create<arith::DivUIOp>(loc, curStep, numStage);
|
||||
if (skipFirstWait) {
|
||||
// If skipFirstWait, it waits for phaseBit 1
|
||||
Value _1 = builder.create<arith::ConstantIntOp>(loc, 1, 32);
|
||||
curPhase = builder.create<arith::AddIOp>(loc, curPhase, _1);
|
||||
}
|
||||
Value _2 = builder.create<arith::ConstantIntOp>(loc, 2, 32);
|
||||
// TODO: May use alternative methods of phaseBit calculation to avoid high
|
||||
// overhead of RemOp
|
||||
Value phaseBit = builder.create<arith::RemUIOp>(loc, curPhase, _2);
|
||||
Value _0 = builder.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
return builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, phaseBit,
|
||||
_0);
|
||||
return curPhase;
|
||||
}
|
||||
|
||||
int getTxBytes(ttng::InsertSliceAsyncV2Op load) {
|
||||
@@ -260,7 +249,7 @@ void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op,
|
||||
auto loc = op.getLoc();
|
||||
// The first producer_aquire should be met immediately, so initailly producer
|
||||
// skips the fisrt wait
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 1);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, true);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferEmpty, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -296,7 +285,7 @@ void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op,
|
||||
void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op,
|
||||
Value bufferFull) {
|
||||
auto loc = op.getLoc();
|
||||
Value phase = getMBarrierPhaseBit(builder, op, 0);
|
||||
Value phase = getMBarrierPhaseBit(builder, op, false);
|
||||
auto waitOp = builder.create<ttng::MBarrierWaitOp>(loc, bufferFull, phase);
|
||||
assert(op.getOperation()->hasAttr("async_agent"));
|
||||
setAgentIds(waitOp, getAgentIds(op.getOperation()));
|
||||
@@ -530,6 +519,7 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
builder.create<arith::ConstantIntOp>(loc, nameBarrierId - 1, 32);
|
||||
// Process mutex users
|
||||
int numUsers = 0;
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
numUsers++;
|
||||
assert(numUsers <= 2);
|
||||
@@ -543,14 +533,20 @@ void mutexSyncPingPang(Operation *parentOp, int numAgents, int &nameBarrierId,
|
||||
Value barLeave = builder.create<arith::SelectOp>(
|
||||
loc, isRole0, namedBarrierId1, namedBarrierId0);
|
||||
builder.create<ttng::NamedBarrierArriveOp>(loc, barLeave, numThreads);
|
||||
} else
|
||||
} else {
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
}
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
nameBarrierId -= 2;
|
||||
nameBarrierIdEnd -= 2;
|
||||
createMutexOp.erase();
|
||||
});
|
||||
|
||||
parentOp->walk(
|
||||
[](ttng::CreateMutexOp createMutexOp) { createMutexOp.erase(); });
|
||||
}
|
||||
|
||||
void processLockOp(OpBuilder &builder, ttng::LockOp op) {
|
||||
@@ -587,6 +583,7 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
OpBuilder builder(createMutexOp);
|
||||
|
||||
// Process mutex users
|
||||
SmallVector<Operation *> deprecatedOps;
|
||||
for (Operation *user : createMutexOp.getResult().getUsers()) {
|
||||
auto loc = user->getLoc();
|
||||
builder.setInsertionPoint(user);
|
||||
@@ -596,6 +593,9 @@ void materializeMutexOperationsOthers(ModuleOp parentOp) {
|
||||
processUnlockOp(builder, op);
|
||||
else
|
||||
llvm_unreachable("Unexpected user of mutex");
|
||||
deprecatedOps.push_back(user);
|
||||
}
|
||||
for (Operation *user : deprecatedOps) {
|
||||
user->erase();
|
||||
}
|
||||
|
||||
|
||||
@@ -156,14 +156,20 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
persistentForOp.getInitArgsMutable()
|
||||
.slice(persistentForOp.getInitArgs().size() - 1, 1)
|
||||
.assign(newIdx);
|
||||
auto yield =
|
||||
llvm::cast<scf::YieldOp>(persistentForOp.getBody()->getTerminator());
|
||||
auto idxPlusOneOp =
|
||||
yield->getOperand(yield->getNumOperands() - 1).getDefiningOp();
|
||||
assert(isa<arith::AddIOp>(idxPlusOneOp));
|
||||
assert(idxPlusOneOp->getOperand(0) ==
|
||||
persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1));
|
||||
|
||||
pipelineIdx = persistentForOp.getBody()->getArgument(
|
||||
persistentForOp.getBody()->getNumArguments() - 1);
|
||||
Operation *idxPlusOneOp = nullptr;
|
||||
for (OpOperand &v : pipelineIdx.getUses()) {
|
||||
if (isa<arith::AddIOp>(v.getOwner())) {
|
||||
idxPlusOneOp = v.getOwner();
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert(idxPlusOneOp && "idxPlusOneOp should be arith::AddIOp");
|
||||
Operation *use = *idxPlusOneOp->getUsers().begin();
|
||||
assert(isa<scf::YieldOp>(use) || isa<arith::SelectOp>(use) ||
|
||||
isa<arith::CmpIOp>(use));
|
||||
idxPlusOneOp->setOperand(1, numRolesValue);
|
||||
|
||||
// Add operations at the start of persistentForOp
|
||||
@@ -213,45 +219,6 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
unlockLocs[i] = op;
|
||||
}
|
||||
|
||||
// Update unlockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * %9 = arith.subi %0#1, c1
|
||||
// * triton_nvidia_gpu.consumer_release %9
|
||||
// * =======================================================================
|
||||
// after async launch dots, there will be outstanding consumerReleaseOp after
|
||||
// ForOp. we should expend the unlockLocs from ForOp to the outstanding
|
||||
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *unlockOp = unlockLocs[i];
|
||||
auto filter = [&](Operation *op) {
|
||||
return op->getBlock() == unlockOp->getBlock();
|
||||
};
|
||||
if (isa<scf::ForOp>(unlockOp)) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter});
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
unlockLocs[i] = *iter;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Only cases where all lock/unlock locations are in same level make sense.
|
||||
for (int i = 1; i < numRoles; ++i) {
|
||||
if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() ||
|
||||
@@ -281,6 +248,54 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp,
|
||||
else
|
||||
lockLocs[i] = unlockLocs[prevTypeIds[i]];
|
||||
}
|
||||
|
||||
// Update lockLocs
|
||||
// ====================== IR after async launch dots ======================
|
||||
// * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 =
|
||||
// %3) {
|
||||
// * triton_nvidia_gpu.producer_wait arg2
|
||||
// * %5 = triton_nvidia_gpu.dot_async %4, %5
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 1}
|
||||
// * %6 = arith.cmpi sgt, arg0, %c0
|
||||
// * scf.if %6 {
|
||||
// * %7 = arith.subi arg2, c1
|
||||
// * triton_nvidia_gpu.consumer_release %7
|
||||
// * }
|
||||
// * %8 = arith.addi arg2, c1
|
||||
// * scf.yield %5, %8
|
||||
// * }
|
||||
// * triton_nvidia_gpu.dot_wait {pendings = 0}
|
||||
// * ...
|
||||
// * triton_nvidia_gpu.consumer_release ..
|
||||
// * =======================================================================
|
||||
// after async launch dots, there will be outstanding consumerReleaseOp after
|
||||
// ForOp. we should set the epilogue lockLocs after the outstanding
|
||||
// consumerReleaseOp.
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
Operation *lockOp = lockLocs[i];
|
||||
if (isa<scf::ForOp>(lockOp)) {
|
||||
Operation *loc = nullptr;
|
||||
unsigned numOutstandingConsumerRelease = 0;
|
||||
for (auto v : lockOp->getResults()) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(v, &slices);
|
||||
auto iter = llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::nvidia_gpu::ConsumerReleaseOp>(op);
|
||||
});
|
||||
if (iter != slices.end()) {
|
||||
numOutstandingConsumerRelease++;
|
||||
loc = *iter;
|
||||
}
|
||||
}
|
||||
assert(numOutstandingConsumerRelease <= 1 &&
|
||||
"should have only one outstanding "
|
||||
"consumerReleaseOp after "
|
||||
"async launch dots");
|
||||
if (loc)
|
||||
lockLocs[i] = loc;
|
||||
}
|
||||
}
|
||||
|
||||
// lock
|
||||
for (int i = 0; i < numRoles; ++i) {
|
||||
builder.setInsertionPointAfter(lockLocs[i]);
|
||||
|
||||
@@ -129,11 +129,12 @@ DenseMap<Channel *, Value> createBuffer(const SmallVector<Channel *> &channels,
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxToLoopArgs
|
||||
// createNewLoops
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
scf::ForOp createNewPersistentLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
@@ -200,6 +201,117 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages,
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
scf::ForOp createNewMathLoop(scf::ForOp forOp, int numStages,
|
||||
scf::ForOp &parentForOp) {
|
||||
auto loc = forOp.getLoc();
|
||||
Block *body = forOp.getBody();
|
||||
|
||||
// The agentId set of pipelineIdx is the union of agentId sets of all ops in
|
||||
// the for loop
|
||||
OpBuilderWithAgentIds builder(forOp.getContext());
|
||||
builder.setAgentIdsFromArray(collectAgentIds(forOp));
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value numStagesVal =
|
||||
builder.createWithAgentIds<arith::ConstantIntOp>(loc, numStages, 32);
|
||||
|
||||
// 0. Append pipelineIdx to block arguments
|
||||
Value phase =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc);
|
||||
Value pipelineIdx =
|
||||
body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc);
|
||||
|
||||
// 1. prepare index and phase for next iteration
|
||||
// nextIdx = curIdx + 1
|
||||
// nextPhase = ((nextIdx < numStages && curPhase) || (nextIdx >= numStages &&
|
||||
// curPhase^1))
|
||||
// nextIdx = nextIdx >= numStages ? 0 : nextIdx
|
||||
auto yieldOp = llvm::cast<scf::YieldOp>(body->getTerminator());
|
||||
builder.setInsertionPoint(yieldOp);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value _1_1b = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 1);
|
||||
// generate index for next iter
|
||||
Value nextPipelineIdx =
|
||||
builder.createWithAgentIds<arith::AddIOp>(loc, pipelineIdx, one);
|
||||
Value pipelineGECond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::uge, nextPipelineIdx, numStagesVal);
|
||||
Value pipelineLTCond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::ult, nextPipelineIdx, numStagesVal);
|
||||
Value cyclePipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, nextPipelineIdx, numStagesVal);
|
||||
nextPipelineIdx = builder.createWithAgentIds<mlir::arith::SelectOp>(
|
||||
loc, pipelineGECond, cyclePipelineIdx, nextPipelineIdx);
|
||||
// generate phase for next iter
|
||||
Value flipPhase =
|
||||
builder.createWithAgentIds<mlir::arith::XOrIOp>(loc, phase, _1_1b);
|
||||
Value cond0 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineGECond, flipPhase);
|
||||
Value cond1 = builder.createWithAgentIds<mlir::arith::AndIOp>(
|
||||
loc, pipelineLTCond, phase);
|
||||
Value nextPhase =
|
||||
builder.createWithAgentIds<mlir::arith::OrIOp>(loc, cond0, cond1);
|
||||
|
||||
// 2. Append pipelineIdx to yield operands
|
||||
yieldOp->insertOperands(yieldOp.getNumOperands(),
|
||||
{nextPhase, nextPipelineIdx});
|
||||
|
||||
// 3. create newLoopArgs
|
||||
SmallVector<Value> newLoopArgs;
|
||||
for (auto operand : forOp.getInitArgs())
|
||||
newLoopArgs.push_back(operand);
|
||||
|
||||
builder.setInsertionPoint(forOp);
|
||||
Value initPipelineIdx, initEmptyIdx, initPhase;
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
if (parentForOp) {
|
||||
// Make sure prior pipelineIdx is inserted in the end of parentForOp
|
||||
initPipelineIdx = parentForOp.getBody()->getArguments().back();
|
||||
Value numSteps = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, forOp.getUpperBound(), forOp.getLowerBound());
|
||||
numSteps = builder.createWithAgentIds<arith::AddIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value two = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 2, 32);
|
||||
numSteps = builder.createWithAgentIds<arith::SubIOp>(loc, numSteps, one);
|
||||
numSteps = builder.createWithAgentIds<arith::DivUIOp>(loc, numSteps,
|
||||
forOp.getStep());
|
||||
// initPipelineIdx = (parentForOp.pipelineIdx * numSteps) % numStages
|
||||
// initPhase = ((parentForOp.pipelineIdx * numSteps) / numStages) & 1
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::MulIOp>(
|
||||
loc, initPipelineIdx, numSteps);
|
||||
Value pipelineIdx = builder.createWithAgentIds<arith::DivUIOp>(
|
||||
loc, initPipelineIdx, numStagesVal);
|
||||
initPipelineIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, initPipelineIdx,
|
||||
builder.createWithAgentIds<arith::MulIOp>(loc, pipelineIdx,
|
||||
numStagesVal));
|
||||
pipelineIdx =
|
||||
builder.createWithAgentIds<arith::AndIOp>(loc, pipelineIdx, one);
|
||||
initPhase = builder.createWithAgentIds<arith::TruncIOp>(
|
||||
loc, builder.getI1Type(), pipelineIdx);
|
||||
} else {
|
||||
// phase init to false and pipelineIdx init to 0
|
||||
initPipelineIdx = zero;
|
||||
initPhase = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 1);
|
||||
}
|
||||
newLoopArgs.append({initPhase, initPipelineIdx});
|
||||
|
||||
// 4. Create newForOp and take the region of forOp
|
||||
auto newForOp = builder.createWithAgentIds<scf::ForOp>(
|
||||
loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(),
|
||||
newLoopArgs);
|
||||
newForOp.getRegion().takeBody(forOp.getRegion());
|
||||
|
||||
// 5. Replace forOp with newForOp
|
||||
for (unsigned i = 0; i < forOp.getNumResults(); ++i)
|
||||
forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i));
|
||||
forOp.erase();
|
||||
|
||||
return newForOp;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// appendPipelineIdxArgs
|
||||
//===----------------------------------------------------------------------===//
|
||||
@@ -217,7 +329,22 @@ void appendPipelineIdxArgs(SmallVector<Operation *> &backbone, int numStages) {
|
||||
|
||||
for (auto &op : orderedForOps) {
|
||||
scf::ForOp parentForOp = op->getParentOfType<scf::ForOp>();
|
||||
auto newForOp = appendPipelineIdxToLoopArgs(op, numStages, parentForOp);
|
||||
scf::ForOp newForOp;
|
||||
bool hasDotOp = false;
|
||||
for (Operation &subOp : *op.getBody()) {
|
||||
if (isa<triton::DotOp>(&subOp)) {
|
||||
hasDotOp = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (hasDotOp) {
|
||||
// for(...) -> for(..., phase, pipelineIdx)
|
||||
newForOp = createNewMathLoop(op, numStages, parentForOp);
|
||||
} else {
|
||||
// for(...) -> for(..., pipelineIdx)
|
||||
newForOp = createNewPersistentLoop(op, numStages, parentForOp);
|
||||
}
|
||||
auto backboneForItr =
|
||||
std::find(backbone.begin(), backbone.end(), op.getOperation());
|
||||
if (backboneForItr != backbone.end()) {
|
||||
@@ -688,8 +815,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentsPC);
|
||||
Value pipelineIdx;
|
||||
Value numStagesVal = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headProducer->getLoc(), numStages, 32);
|
||||
if (auto forOp = headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = forOp.getBody()->getArguments().back();
|
||||
} else {
|
||||
@@ -700,10 +825,6 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
|
||||
// insert ProducerAcquireOp
|
||||
builder.setInsertionPoint(headProducer);
|
||||
if (headProducer->getParentOfType<scf::ForOp>()) {
|
||||
pipelineIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
headProducer->getLoc(), pipelineIdx, numStagesVal);
|
||||
}
|
||||
builder.setAgentIdsFromArray(agentP);
|
||||
builder.createWithAgentIds<ttng::ProducerAcquireOp>(headProducer->getLoc(),
|
||||
token, pipelineIdx);
|
||||
@@ -738,7 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(loc, 1);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
loc, dotAsync.getResult(), 1);
|
||||
|
||||
// 1. insert ConsumerReleaseOp for DotAsyncOps
|
||||
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
@@ -747,31 +869,43 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
auto ifOp =
|
||||
builder.createWithAgentIds<scf::IfOp>(loc, ArrayRef<Type>{}, cond,
|
||||
/*hasElse*/ false);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
auto oriIdx = forOp.getBody()->getArguments().back();
|
||||
Value consumerReleaseIdx =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, oriIdx, one);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
Value consumerReleaseIdx = forOp.getBody()->getArguments().back();
|
||||
Value zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
Value lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
Value consumerReleaseIdxMinusOne =
|
||||
builder.createWithAgentIds<arith::SubIOp>(loc, consumerReleaseIdx,
|
||||
one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
setAgentIds(ifOp.thenYield().getOperation(), agentIds);
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
|
||||
0);
|
||||
unsigned resultIndex = dotAsync->getUses().begin()->getOperandNumber();
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
Value one_ = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
headConsumer->getLoc(), 1, 32);
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
one = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 1, 32);
|
||||
lastStage = builder.createWithAgentIds<arith::ConstantIntOp>(
|
||||
loc, numStages - 1, 32);
|
||||
consumerReleaseIdx = forOp.getResults().back();
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one_);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::RemSIOp>(
|
||||
loc, consumerReleaseIdx, numStagesVal);
|
||||
consumerReleaseIdxMinusOne = builder.createWithAgentIds<arith::SubIOp>(
|
||||
loc, consumerReleaseIdx, one);
|
||||
cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::eq, consumerReleaseIdx, zero);
|
||||
consumerReleaseIdx = builder.createWithAgentIds<arith::SelectOp>(
|
||||
loc, cond, lastStage, consumerReleaseIdxMinusOne);
|
||||
builder.createWithAgentIds<ttng::ConsumerReleaseOp>(loc, token,
|
||||
consumerReleaseIdx);
|
||||
dotOp->erase();
|
||||
|
||||
Reference in New Issue
Block a user