mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Prevent emitting multiple dot_wait after pipelined loop (#2598)
Patch based on @donproc's findings and suggested optimization. Emitting multiple wait ops may confuse ptxas and cause it to fall back to a conservative mode.
This commit is contained in:
@@ -792,32 +792,21 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
SmallVector<Type> resultTypes(resultNeedSync.size());
|
||||
SmallVector<Value> yieldThenValues(resultNeedSync.size());
|
||||
SmallVector<Value> yieldElseValues(resultNeedSync.size());
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
|
||||
yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
|
||||
yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
|
||||
}
|
||||
Value loopNotEmpty = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
|
||||
forOp.getUpperBound());
|
||||
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
|
||||
/*hasElse*/ true);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
SmallVector<Value> waitOperands;
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
Value result = forOp->getResult(resultNeedSync[i]);
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
auto dotWait =
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
|
||||
yieldThenValues[i] = dotWait.getResult();
|
||||
waitOperands.push_back(result);
|
||||
}
|
||||
if (!waitOperands.empty()) {
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
|
||||
waitOperands, 0);
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
Value result = forOp->getResult(resultNeedSync[i]);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(i), dotWait);
|
||||
}
|
||||
}
|
||||
auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
|
||||
builder.setInsertionPointToEnd(ifOp.elseBlock());
|
||||
auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);
|
||||
|
||||
// 3. potentially remove redundant dot_wait after dot_async if having multiple
|
||||
// DotOp in the loop
|
||||
|
||||
@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
|
||||
build(builder, state, MutexType::get(builder.getContext()));
|
||||
}
|
||||
|
||||
///--- DotWaitOp ---
|
||||
LogicalResult DotWaitOp::inferReturnTypes(
|
||||
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
for (Value operand : operands)
|
||||
inferredReturnTypes.push_back(operand.getType());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace nvidia_gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
|
||||
Reference in New Issue
Block a user