[BACKEND] Prevent emitting multiple dot_wait after pipelined loop (#2598)

Patch based on @donproc's findings and suggested optimization.

Emitting multiple wait ops may confuse ptxas and cause it to fall back to
a conservative mode.
This commit is contained in:
Thomas Raoux
2023-11-03 14:29:50 -07:00
committed by GitHub
parent 34b89a1173
commit cb3d79a185
6 changed files with 136 additions and 28 deletions

View File

@@ -131,8 +131,55 @@ struct DotWaitOpConversion
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto pendings = op.getPendings();
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
op, adaptor.getInput(), pendings);
Location loc = op.getLoc();
if (adaptor.getInputs().size() <= 1) {
Value intput =
adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value();
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(op, intput,
pendings);
return success();
}
std::vector<Type> types;
// Pack the inputs into a single struct.
for (Value input : adaptor.getInputs()) {
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
if (!structType)
return failure();
for (Type type : structType.getBody())
types.push_back(type);
}
auto packedType =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
unsigned outputStructIndex = 0;
for (Value input : adaptor.getInputs()) {
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
Value value = rewriter.create<LLVM::ExtractValueOp>(
loc, structType.getBody()[i], input, i);
packed = rewriter.create<LLVM::InsertValueOp>(
loc, packedType, packed, value, outputStructIndex++);
}
}
Value packedOutput =
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, packed, pendings);
// Unpack the output into the original struct types.
SmallVector<Value> outputs;
outputStructIndex = 0;
for (Value input : adaptor.getInputs()) {
auto structType = input.getType().cast<LLVM::LLVMStructType>();
Value unpacked = rewriter.create<LLVM::UndefOp>(loc, structType);
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
Value value = rewriter.create<LLVM::ExtractValueOp>(
loc, packedType.getBody()[outputStructIndex], packedOutput,
outputStructIndex);
outputStructIndex++;
unpacked = rewriter.create<LLVM::InsertValueOp>(loc, structType,
unpacked, value, i);
}
outputs.push_back(unpacked);
}
rewriter.replaceOp(op, outputs);
return success();
}
};

View File

@@ -792,32 +792,21 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
builder.setInsertionPointAfter(forOp);
SmallVector<Type> resultTypes(resultNeedSync.size());
SmallVector<Value> yieldThenValues(resultNeedSync.size());
SmallVector<Value> yieldElseValues(resultNeedSync.size());
for (int i = 0; i < resultNeedSync.size(); ++i) {
resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
}
Value loopNotEmpty = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
forOp.getUpperBound());
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
/*hasElse*/ true);
builder.setInsertionPointToStart(ifOp.thenBlock());
SmallVector<Value> waitOperands;
for (int i = 0; i < resultNeedSync.size(); ++i) {
Value result = forOp->getResult(resultNeedSync[i]);
if (result.use_empty())
continue;
auto dotWait =
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
yieldThenValues[i] = dotWait.getResult();
waitOperands.push_back(result);
}
if (!waitOperands.empty()) {
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
waitOperands, 0);
for (int i = 0; i < resultNeedSync.size(); ++i) {
Value result = forOp->getResult(resultNeedSync[i]);
result.replaceAllUsesExcept(dotWait.getResult(i), dotWait);
}
}
auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
builder.setInsertionPointToEnd(ifOp.elseBlock());
auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);
// 3. potentially remove redundant dot_wait after dot_async if having multiple
// DotOp in the loop

View File

@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
build(builder, state, MutexType::get(builder.getContext()));
}
///--- DotWaitOp ---
LogicalResult DotWaitOp::inferReturnTypes(
    ::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
    ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
    ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
    ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
  // The wait op forwards its inputs unchanged: result i has exactly the
  // type of operand i, so the inferred result types are the operand types.
  auto operandTypes = operands.getTypes();
  inferredReturnTypes.append(operandTypes.begin(), operandTypes.end());
  return mlir::success();
}
} // namespace nvidia_gpu
} // namespace triton
} // namespace mlir

View File

@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
Value result = forOp->getResult(resultIndex);
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
forOp.getLoc(), result, 0);
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);