Mirror of https://github.com/ROCm/ROCm.git, synced 2026-04-05 03:01:17 -04:00.
[BACKEND] Add back dot.wait when generating async_dot (#2478)
Based on the discussion, this wait is needed to ensure there is no race condition when reading shared memory.
This commit is contained in:
@@ -1631,6 +1631,16 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
return;
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
// 0. insert dot_wait after the last dot in the loop as we implicitly pipeline
|
||||
// wgmma ops by one stage.
|
||||
// This is needed to prevent shared memory inputs from being overridden before the
|
||||
// operation is completed.
|
||||
// TODO: merge this with the rest of the pipelining transformation and look at
|
||||
// a better representation for async dots.
|
||||
tt::DotOp lastDot = dots.back();
|
||||
builder.setInsertionPointAfter(lastDot);
|
||||
auto dotWait = builder.create<tt::nvidia_gpu::DotWaitOp>(
|
||||
lastDot.getLoc(), lastDot.getResult(), dots.size());
|
||||
|
||||
// 1. replace Dot with DotAsync
|
||||
for (size_t idx = 0; idx < dots.size(); ++idx) {
|
||||
@@ -1640,7 +1650,7 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) {
|
||||
dotOp.getLoc(), dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dotOp.replaceAllUsesWith(dotAsync.getResult());
|
||||
updateConsumerReleaseInfo(dotOp, dotAsync, /*stage=*/1);
|
||||
updateConsumerReleaseInfo(dotOp, dotWait, /*stage=*/1);
|
||||
dotOp->erase();
|
||||
}
|
||||
|
||||
@@ -1749,7 +1759,7 @@ void PipelinePass::emitConsumerRelease(Value mbarTensor,
|
||||
std::accumulate(CTASplitNum.begin(), CTASplitNum.end(),
|
||||
1, std::multiplies{});
|
||||
auto numConsumerThreads =
|
||||
isa<ttng::DotAsyncOp>(lastUserWithLargestStage) ? 128 : 32;
|
||||
isa<ttng::DotWaitOp>(lastUserWithLargestStage) ? 128 : 32;
|
||||
Value _0 = b.create<arith::ConstantIntOp>(loc, 0, 32);
|
||||
Value numArrives = b.create<arith::ConstantIntOp>(
|
||||
loc, numConsumerThreads / numRemoteCTAs, 32);
|
||||
|
||||
@@ -859,6 +859,8 @@ void buildAsyncComm(const DenseMap<Operation *, SmallVector<Channel *>> &map,
|
||||
loc, dotOp.getA(), dotOp.getB(), dotOp.getC(),
|
||||
dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc());
|
||||
dot.replaceAllUsesWith(dotAsync.getResult());
|
||||
builder.createWithAgentIds<triton::nvidia_gpu::DotWaitOp>(
|
||||
loc, dotAsync.getResult(), 1);
|
||||
|
||||
// 1. insert ConsumerReleaseOp for DotAsyncOps
|
||||
Value cond = builder.createWithAgentIds<arith::CmpIOp>(
|
||||
|
||||
@@ -22,9 +22,10 @@
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_gpu.extract_slice
|
||||
// CHECK: triton_nvidia_gpu.dot_async
|
||||
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 1
|
||||
// CHECK: triton_nvidia_gpu.consumer_release
|
||||
// CHECK: scf.yield
|
||||
// CHECK: triton_nvidia_gpu.dot_wait
|
||||
// CHECK: triton_nvidia_gpu.dot_wait {{.*}} pendings = 0
|
||||
// CHECK: async_agent = dense<1> : vector<1xi32>
|
||||
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
|
||||
|
||||
Reference in New Issue
Block a user