[BACKEND] Optimize performance for f16 epilogue with TMA store (#2135)

1. Optimize the conversion and packing for 2xf32 -> 2xf16 (see the sketch after this list).
2. Split the TMA store block into multiple slices of size 64x64.
3. Distribute the TMA store across all the warps.
4. Fix some naming issues.
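
As an illustration of items 1-3, here is a minimal standalone C++ sketch. It is not code from this commit: pack_2xf32_to_2xf16 and distribute_tma_slices are hypothetical names, and the scalar host code only mirrors the idea, while the 64x64 slice size and the per-warp distribution follow the commit message. It assumes a compiler with the _Float16 extension (recent GCC/Clang).

#include <cstdint>
#include <cstdio>
#include <cstring>

// Item 1 (illustrative only): convert a pair of f32 values to f16 and pack
// them into one 32-bit word, so two elements move as a single register. On
// the GPU the same idea maps to a packed convert instruction rather than two
// scalar conversions.
static uint32_t pack_2xf32_to_2xf16(float a, float b) {
  _Float16 ha = static_cast<_Float16>(a); // round-to-nearest-even
  _Float16 hb = static_cast<_Float16>(b);
  uint16_t ra, rb;
  std::memcpy(&ra, &ha, sizeof(ra));
  std::memcpy(&rb, &hb, sizeof(rb));
  return static_cast<uint32_t>(ra) | (static_cast<uint32_t>(rb) << 16);
}

// Items 2 and 3 (illustrative only): cut a blockM x blockN epilogue tile
// into 64x64 slices and deal the slices round-robin to the warps, so every
// warp issues a share of the TMA stores instead of one warp issuing them all.
static void distribute_tma_slices(int blockM, int blockN, int numWarps) {
  constexpr int kSliceM = 64, kSliceN = 64;
  const int slicesM = blockM / kSliceM, slicesN = blockN / kSliceN;
  for (int s = 0; s < slicesM * slicesN; ++s) {
    const int warp = s % numWarps;            // round-robin over warps
    const int offM = (s / slicesN) * kSliceM; // slice origin within the block
    const int offN = (s % slicesN) * kSliceN;
    std::printf("warp %d stores slice (%d, %d)\n", warp, offM, offN);
  }
}

int main() {
  // 1.0f -> 0x3C00 and 2.0f -> 0x4000; packed little-endian: 0x40003C00
  std::printf("packed = 0x%08X\n", pack_2xf32_to_2xf16(1.0f, 2.0f));
  // A 128x256 block with 8 warps: each warp gets exactly one 64x64 slice.
  distribute_tma_slices(128, 256, 8);
}

Packing a pair of f16 results into one 32-bit word halves the number of register values the epilogue has to move, and spreading the 64x64 slices over all warps keeps the TMA store path from serializing on a single warp.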
ivanyinwz
2023-08-22 03:44:11 +08:00
committed by GitHub
parent ea8416164f
commit ec801ce18e
6 changed files with 353 additions and 25 deletions


@@ -109,6 +109,12 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec,
   return paddedRepShape;
 }
 
+SmallVector<unsigned>
+getScratchConfigForStoreAsync(triton::nvidia_gpu::StoreAsyncOp op) {
+  auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
+  return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
+}
+
 // TODO: extend beyond scalars
 SmallVector<unsigned> getScratchConfigForAtomicRMW(triton::AtomicRMWOp op) {
   SmallVector<unsigned> smemShape;
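
For intuition (a hypothetical example, not taken from the commit): with a single CTA over a 128x128 f16 tensor, getShapePerCTA(srcTy) yields {128, 128}, so the scratch request in the next hunk covers the whole tile that the async store stages through shared memory.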
@@ -244,6 +250,18 @@ private:
                : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
+    } else if (auto storeAsyncOp =
+                   dyn_cast<triton::nvidia_gpu::StoreAsyncOp>(op)) {
+      auto srcTy = storeAsyncOp.getSrc().getType().cast<RankedTensorType>();
+      auto srcEncoding = srcTy.getEncoding();
+      if (!srcEncoding.isa<MmaEncodingAttr>()) {
+        return;
+      }
+      auto smemShape = getScratchConfigForStoreAsync(storeAsyncOp);
+      unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1,
+                                       std::multiplies{});
+      auto bytes = elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes, 1024);
     } else if (auto atomicRMWOp = dyn_cast<triton::AtomicRMWOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
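
As a standalone check of the size computation in the hunk above (a sketch assuming a 128x128 f16 tile per CTA; none of this is code from the commit):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  std::vector<unsigned> smemShape{128, 128}; // shape per CTA, as in the hunk
  unsigned bitWidth = 16;                    // f16 elements
  unsigned elems = std::accumulate(smemShape.begin(), smemShape.end(), 1u,
                                   std::multiplies<unsigned>{});
  // max(8, bitWidth) rounds sub-byte element types up to a full byte.
  unsigned bytes = elems * std::max(8u, bitWidth) / 8;
  std::printf("%u elems -> %u bytes, requested at 1024-byte alignment\n",
              elems, bytes);
}

For this shape the request works out to 16384 elements, i.e. 32768 bytes of scratch shared memory, requested with the 1024-byte alignment passed to maybeAddScratchBuffer.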