mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Prevent emitting multiple dot_wait after pipelinied loop (#2598)
Patch based on @donproc findings and suggested optimization. Emitting multiple wait op may confuse ptxas and cause it to fallback to a conservative mode.
This commit is contained in:
@@ -270,16 +270,16 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure,
|
||||
}
|
||||
|
||||
def TTNG_DotWaitOp : TTNG_Op<"dot_wait", [DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
AllTypesMatch<["input", "output"]>]> {
|
||||
AllTypesMatch<["inputs", "outputs"]>]> {
|
||||
let summary = "dot wait";
|
||||
let arguments = (ins TT_FpIntTensor:$input, I32Attr:$pendings);
|
||||
let results = (outs TT_FpIntTensor:$output);
|
||||
let arguments = (ins Variadic<TT_FpIntTensor>:$inputs, I32Attr:$pendings);
|
||||
let results = (outs Variadic<TT_FpIntTensor>:$outputs);
|
||||
let description = [{
|
||||
This operation defining the waiting action for a async dot, MMAv3 .e.g.
|
||||
The subsequent operations should not execute until this operation completes waiting.
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$input attr-dict `:` type($input)";
|
||||
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
|
||||
}
|
||||
|
||||
def TTNG_StoreAsyncOp : TTNG_Op<"store_async",
|
||||
|
||||
@@ -131,8 +131,55 @@ struct DotWaitOpConversion
|
||||
matchAndRewrite(triton::nvidia_gpu::DotWaitOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto pendings = op.getPendings();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(
|
||||
op, adaptor.getInput(), pendings);
|
||||
Location loc = op.getLoc();
|
||||
if (adaptor.getInputs().size() <= 1) {
|
||||
Value intput =
|
||||
adaptor.getInputs().size() == 1 ? adaptor.getInputs()[0] : Value();
|
||||
rewriter.replaceOpWithNewOp<triton::nvgpu::WGMMAWaitGroupOp>(op, intput,
|
||||
pendings);
|
||||
return success();
|
||||
}
|
||||
std::vector<Type> types;
|
||||
// Pack the inputs into a single struct.
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
if (!structType)
|
||||
return failure();
|
||||
for (Type type : structType.getBody())
|
||||
types.push_back(type);
|
||||
}
|
||||
auto packedType =
|
||||
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
|
||||
Value packed = rewriter.create<LLVM::UndefOp>(loc, packedType);
|
||||
unsigned outputStructIndex = 0;
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
|
||||
Value value = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, structType.getBody()[i], input, i);
|
||||
packed = rewriter.create<LLVM::InsertValueOp>(
|
||||
loc, packedType, packed, value, outputStructIndex++);
|
||||
}
|
||||
}
|
||||
Value packedOutput =
|
||||
rewriter.create<triton::nvgpu::WGMMAWaitGroupOp>(loc, packed, pendings);
|
||||
// Unpack the output into the original struct types.
|
||||
SmallVector<Value> outputs;
|
||||
outputStructIndex = 0;
|
||||
for (Value input : adaptor.getInputs()) {
|
||||
auto structType = input.getType().cast<LLVM::LLVMStructType>();
|
||||
Value unpacked = rewriter.create<LLVM::UndefOp>(loc, structType);
|
||||
for (unsigned i = 0; i < structType.getBody().size(); ++i) {
|
||||
Value value = rewriter.create<LLVM::ExtractValueOp>(
|
||||
loc, packedType.getBody()[outputStructIndex], packedOutput,
|
||||
outputStructIndex);
|
||||
outputStructIndex++;
|
||||
unpacked = rewriter.create<LLVM::InsertValueOp>(loc, structType,
|
||||
unpacked, value, i);
|
||||
}
|
||||
outputs.push_back(unpacked);
|
||||
}
|
||||
rewriter.replaceOp(op, outputs);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -792,32 +792,21 @@ void mlir::triton::asyncLaunchDots(scf::ForOp forOp) {
|
||||
|
||||
// 2. If there's any outstanding DotAsyncOps, we need to wait for them.
|
||||
builder.setInsertionPointAfter(forOp);
|
||||
SmallVector<Type> resultTypes(resultNeedSync.size());
|
||||
SmallVector<Value> yieldThenValues(resultNeedSync.size());
|
||||
SmallVector<Value> yieldElseValues(resultNeedSync.size());
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
resultTypes[i] = forOp->getResult(resultNeedSync[i]).getType();
|
||||
yieldThenValues[i] = forOp->getResult(resultNeedSync[i]);
|
||||
yieldElseValues[i] = forOp->getResult(resultNeedSync[i]);
|
||||
}
|
||||
Value loopNotEmpty = builder.create<arith::CmpIOp>(
|
||||
loc, arith::CmpIPredicate::slt, forOp.getLowerBound(),
|
||||
forOp.getUpperBound());
|
||||
auto ifOp = builder.create<scf::IfOp>(loc, resultTypes, loopNotEmpty,
|
||||
/*hasElse*/ true);
|
||||
builder.setInsertionPointToStart(ifOp.thenBlock());
|
||||
SmallVector<Value> waitOperands;
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
Value result = forOp->getResult(resultNeedSync[i]);
|
||||
if (result.use_empty())
|
||||
continue;
|
||||
auto dotWait =
|
||||
builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(ifOp.getResult(i), dotWait);
|
||||
yieldThenValues[i] = dotWait.getResult();
|
||||
waitOperands.push_back(result);
|
||||
}
|
||||
if (!waitOperands.empty()) {
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(forOp.getLoc(),
|
||||
waitOperands, 0);
|
||||
for (int i = 0; i < resultNeedSync.size(); ++i) {
|
||||
Value result = forOp->getResult(resultNeedSync[i]);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(i), dotWait);
|
||||
}
|
||||
}
|
||||
auto yieldOpThen = builder.create<scf::YieldOp>(loc, yieldThenValues);
|
||||
builder.setInsertionPointToEnd(ifOp.elseBlock());
|
||||
auto yieldOpElse = builder.create<scf::YieldOp>(loc, yieldElseValues);
|
||||
|
||||
// 3. potentially remove redundant dot_wait after dot_async if having mutiple
|
||||
// DotOp in the loop
|
||||
|
||||
@@ -79,6 +79,17 @@ void CreateMutexOp::build(::mlir::OpBuilder &builder,
|
||||
build(builder, state, MutexType::get(builder.getContext()));
|
||||
}
|
||||
|
||||
///--- DotWaitOp ---
|
||||
LogicalResult DotWaitOp::inferReturnTypes(
|
||||
::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location,
|
||||
::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
|
||||
::mlir::OpaqueProperties properties, ::mlir::RegionRange regions,
|
||||
::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
|
||||
for (Value operand : operands)
|
||||
inferredReturnTypes.push_back(operand.getType());
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace nvidia_gpu
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
|
||||
@@ -892,7 +892,7 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
Value result = forOp->getResult(resultIndex);
|
||||
auto dotWait = builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
forOp.getLoc(), result, 0);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(), dotWait);
|
||||
result.replaceAllUsesExcept(dotWait.getResult(0), dotWait);
|
||||
|
||||
// 3. insert ConsumerReleaseOp for outstanding DotAsyncOps
|
||||
zero = builder.createWithAgentIds<arith::ConstantIntOp>(loc, 0, 32);
|
||||
|
||||
@@ -379,3 +379,64 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
|
||||
tt.return %17#0 : tensor<128x64xf32, #mma>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}>
|
||||
#mma1 = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}>
|
||||
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}>
|
||||
module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
|
||||
// CHECK-LABEL: two_accumulator_escape
|
||||
tt.func @two_accumulator_escape(%arg0: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16, 1> {tt.divisibility = 16 : i32}) -> (tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>) {
|
||||
%cst = arith.constant dense<0> : tensor<64x16xi32, #blocked>
|
||||
%c0_i32 = arith.constant 0 : i32
|
||||
%cst_0 = arith.constant dense<0> : tensor<1x16xi32, #blocked>
|
||||
%cst_1 = arith.constant dense<0> : tensor<128x1xi32, #blocked1>
|
||||
%c0_i64 = arith.constant 0 : i64
|
||||
%cst_2 = arith.constant dense<0.000000e+00> : tensor<128x16xf32, #mma1>
|
||||
%cst_3 = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma>
|
||||
%cst_4 = arith.constant dense<1.000000e+00> : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>>
|
||||
%c1_i32 = arith.constant 1 : i32
|
||||
%c8_i32 = arith.constant 8 : i32
|
||||
%0 = tt.addptr %arg0, %c0_i64 : !tt.ptr<f16, 1>, i64
|
||||
%1 = tt.addptr %arg1, %c0_i64 : !tt.ptr<f16, 1>, i64
|
||||
%2 = tt.splat %1 : (!tt.ptr<f16, 1>) -> tensor<128x1x!tt.ptr<f16, 1>, #blocked1>
|
||||
%3 = tt.addptr %2, %cst_1 : tensor<128x1x!tt.ptr<f16, 1>, #blocked1>, tensor<128x1xi32, #blocked1>
|
||||
%4 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>
|
||||
%5 = tt.expand_dims %4 {axis = 0 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>) -> tensor<1x64xi32, #blocked1>
|
||||
%6 = tt.broadcast %3 : (tensor<128x1x!tt.ptr<f16, 1>, #blocked1>) -> tensor<128x64x!tt.ptr<f16, 1>, #blocked1>
|
||||
%7 = tt.broadcast %5 : (tensor<1x64xi32, #blocked1>) -> tensor<128x64xi32, #blocked1>
|
||||
%8 = tt.addptr %6, %7 : tensor<128x64x!tt.ptr<f16, 1>, #blocked1>, tensor<128x64xi32, #blocked1>
|
||||
%9 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x64xf16, #blocked1>
|
||||
%10 = tt.splat %0 : (!tt.ptr<f16, 1>) -> tensor<1x16x!tt.ptr<f16, 1>, #blocked>
|
||||
%11 = tt.addptr %10, %cst_0 : tensor<1x16x!tt.ptr<f16, 1>, #blocked>, tensor<1x16xi32, #blocked>
|
||||
%12 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>
|
||||
%13 = tt.expand_dims %12 {axis = 1 : i32} : (tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<64x1xi32, #blocked>
|
||||
%14 = tt.broadcast %11 : (tensor<1x16x!tt.ptr<f16, 1>, #blocked>) -> tensor<64x16x!tt.ptr<f16, 1>, #blocked>
|
||||
%15 = tt.broadcast %13 : (tensor<64x1xi32, #blocked>) -> tensor<64x16xi32, #blocked>
|
||||
%16 = tt.addptr %14, %15 : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
|
||||
// CHECK: %[[R:.+]]:{{.+}} = scf.for
|
||||
// CHECK: triton_gpu.async_wait {num = 2 : i32}
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.dot_wait %35 {pendings = 2 : i32}
|
||||
// CHECK: scf.yield
|
||||
// CHECK: %{{.*}}:2 = triton_nvidia_gpu.dot_wait %[[R]]#{{.+}}, %[[R]]#{{.+}} {pendings = 0 : i32} : tensor<128x16xf32, #{{.*}}>, tensor<128x64xf32, #{{.*}}>
|
||||
%17:3 = scf.for %arg3 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg4 = %cst_3, %arg5 = %16, %arg6 = %cst_2) -> (tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<128x16xf32, #mma1>) : i32 {
|
||||
%18 = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked>
|
||||
%19 = triton_gpu.convert_layout %9 : (tensor<128x64xf16, #blocked1>) -> tensor<128x64xf16, #shared>
|
||||
%20 = triton_gpu.convert_layout %18 : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1>
|
||||
%21 = tt.dot %19, %20, %arg6 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared> * tensor<64x16xf16, #shared1> -> tensor<128x16xf32, #mma1>
|
||||
%l = tt.load %arg5 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x16xf16, #blocked>
|
||||
%c = triton_gpu.convert_layout %l : (tensor<64x16xf16, #blocked>) -> tensor<64x16xf16, #shared1>
|
||||
%23 = tt.trans %c : (tensor<64x16xf16, #shared1>) -> tensor<16x64xf16, #shared>
|
||||
%25 = tt.dot %cst_4, %23, %arg4 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma1}>> * tensor<16x64xf16, #shared> -> tensor<128x64xf32, #mma>
|
||||
%26 = tt.addptr %arg5, %cst : tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<64x16xi32, #blocked>
|
||||
scf.yield %25, %26, %21 : tensor<128x64xf32, #mma>, tensor<64x16x!tt.ptr<f16, 1>, #blocked>, tensor<128x16xf32, #mma1>
|
||||
}
|
||||
tt.return %17#0, %17#2 : tensor<128x64xf32, #mma>, tensor<128x16xf32, #mma1>
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user