mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Optimize performance for f16 epilogue with TMA store (#2135)
1. Optimize the conversion and packing for 2xf32 -> 2xf16. 2. Split the TMA store block into multiple slices of size 64x64. 3. Distribute the TMA store across all the warps. 4. Fix some naming issues.
This commit is contained in:
@@ -3,6 +3,10 @@
|
||||
|
||||
#include "ConvertLayoutOpToLLVM.h"
|
||||
#include "LoadStoreOpToLLVM.h"
|
||||
#include "Utility.h"
|
||||
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
|
||||
|
||||
#include <numeric>
|
||||
|
||||
@@ -12,6 +16,7 @@ using namespace mlir::triton;
|
||||
using ::mlir::LLVM::delinearize;
|
||||
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
|
||||
using ::mlir::LLVM::linearize;
|
||||
using ::mlir::triton::gpu::getCTALayout;
|
||||
using ::mlir::triton::gpu::getShapePerCTA;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
@@ -404,6 +409,18 @@ struct StoreAsyncOpConversion
|
||||
/// Entry point of the StoreAsyncOp lowering: picks the lowering routine based
/// on the encoding attached to the source tensor type.
LogicalResult
matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor,
                ConversionPatternRewriter &rewriter) const override {
  auto sourceType = op.getSrc().getType().cast<RankedTensorType>();

  // Sources with an MMA encoding go through the sliced lowering; every other
  // encoding uses the generic one.
  const bool hasMmaEncoding = sourceType.getEncoding().isa<MmaEncodingAttr>();
  if (hasMmaEncoding)
    return lowerStoreAsyncWithSlice(op, adaptor, rewriter);

  return lowerStoreAsync(op, adaptor, rewriter);
}
|
||||
|
||||
LogicalResult lowerStoreAsync(triton::nvidia_gpu::StoreAsyncOp op,
|
||||
OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto loc = op.getLoc();
|
||||
MLIRContext *ctx = rewriter.getContext();
|
||||
|
||||
@@ -413,6 +430,9 @@ struct StoreAsyncOpConversion
|
||||
auto elemTy = srcTy.getElementType();
|
||||
|
||||
auto rank = srcTy.getRank();
|
||||
// The store async op only supports tensors with rank <= 5.
|
||||
// Reference:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
|
||||
assert(rank > 0 && rank <= 5);
|
||||
|
||||
auto moduleOp = op->getParentOfType<ModuleOp>();
|
||||
@@ -475,14 +495,14 @@ struct StoreAsyncOpConversion
|
||||
.cast<RankedTensorType>()
|
||||
.getShape();
|
||||
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
|
||||
// magic 128 bytes
|
||||
const uint32_t bytesPerCacheline = 128;
|
||||
uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8;
|
||||
uint32_t numBox{1};
|
||||
for (int i = 0; i < rank; ++i) {
|
||||
auto dim = getDimOfOrder(dstOrder, i);
|
||||
auto tNumElems = shapePerCTA[dim];
|
||||
if (i == 0 && tNumElems * bytesPerElem > 128) {
|
||||
tNumElems = 128 / bytesPerElem;
|
||||
if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
|
||||
tNumElems = bytesPerCacheline / bytesPerElem;
|
||||
numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
|
||||
}
|
||||
boxDims.emplace_back(tNumElems);
|
||||
@@ -574,6 +594,268 @@ struct StoreAsyncOpConversion
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Lowers a StoreAsyncOp whose source tensor carries an MMA encoding.
///
/// The routine works in two phases:
///   1. It fills in a `TMAInfo` record (data type, rank, argument indices of
///      the global address / dims / strides, box dimensions, swizzle mode)
///      for the destination tensor and appends it to `tmaMetadata`.
///   2. For every box along the first dimension of `dstOrder` and every
///      repetition along M, it packs the per-thread source values into
///      2-element vectors, writes them to shared memory with StoreMatrixOp,
///      emits a FenceAsyncSharedOp, and then emits a TMAStoreTiledOp that is
///      predicated on `threadId % 32 == 0`.
///
/// \param op       the StoreAsyncOp being lowered; erased before returning.
/// \param adaptor  adaptor giving access to the already-converted operands.
/// \param rewriter rewriter used to create all replacement operations.
/// \returns success(); an unsupported shared layout aborts through
///          llvm::report_fatal_error instead of returning failure.
LogicalResult
lowerStoreAsyncWithSlice(triton::nvidia_gpu::StoreAsyncOp op,
                         OpAdaptor adaptor,
                         ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  MLIRContext *ctx = rewriter.getContext();

  auto dst = op.getDst();
  auto src = op.getSrc();
  auto srcTy = src.getType().cast<RankedTensorType>();
  auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
  auto dstTensorTy = makeTensorPtr.getResult()
                         .getType()
                         .cast<triton::PointerType>()
                         .getPointeeType()
                         .cast<RankedTensorType>();
  auto tensorShape = dstTensorTy.getShape();
  auto dstOrder = makeTensorPtr.getOrder();
  auto dstElemTy = dstTensorTy.getElementType();

  auto rank = srcTy.getRank();
  // The store async op only supports tensors with rank <= 5.
  // Reference:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
  // NOTE(review): the code below indexes dstOrder[1], instrShape[1] and
  // warpsPerCTA[1], so it presumably expects rank >= 2 -- confirm.
  assert(rank > 0 && rank <= 5);

  auto moduleOp = op->getParentOfType<ModuleOp>();
  assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");

  auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
  assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");

  int numTMADescs = getNumTMADescs(llFuncOp);
  assert(numTMADescs > 0);

  auto ctaLayout = getCTALayout(dstTensorTy.getEncoding());
  // The order of smem should be consistent with gmem.
  SmallVector<unsigned> sharedOrder;
  for (auto o : makeTensorPtr.getOrder()) {
    sharedOrder.emplace_back(o);
  }
  auto sharedLayout = SharedEncodingAttr::get(ctx, tensorShape, sharedOrder,
                                              ctaLayout, dstElemTy);

  mlir::triton::gpu::TMAInfo tmaInfo;

  tmaInfo.tensorDataType = getCUtensorMapDataType(dstElemTy);
  tmaInfo.tensorRank = rank;
  assert(tmaMetadata);

  // The descriptor for this op is the next free slot in the metadata list;
  // descriptor arguments occupy the last `numTMADescs` function arguments.
  unsigned TMADescIdx = tmaMetadata->size();
  unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();

  unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
  tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
  tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;

  // Returns the position inside `order` at which the value `i` is stored.
  auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
    auto it = std::find(order.begin(), order.end(), i);
    assert(it != order.end());
    return std::distance(order.begin(), it);
  };

  std::vector<int32_t> globalDimsArgIdx;
  std::vector<int32_t> globalStridesArgIdx;
  // constant values are mapped to (-1 - value)
  for (int i = 0; i < rank; ++i) {
    int32_t argIdx = -1;
    auto dim = getDimOfOrder(dstOrder, i);
    argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
    globalDimsArgIdx.emplace_back(argIdx);
    // handle constant stride
    argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
    globalStridesArgIdx.emplace_back(argIdx);
  }

  tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
  tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
  std::vector<uint32_t> boxDims;
  auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
  auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
  auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
  auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);

  auto srcLayout = srcTy.getEncoding();
  // This routine is only reached for MMA-encoded sources (see
  // matchAndRewrite), so use the checked cast rather than a dyn_cast whose
  // null result would be dereferenced below.
  auto mmaLayout = srcLayout.cast<MmaEncodingAttr>();

  unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);

  auto instrShape = mmaLayout.getInstrShape();
  auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
  // Number of repetitions along dimension 0 needed to cover the CTA tile, and
  // the share of per-thread elements belonging to each repetition.
  uint32_t repM =
      ceil<unsigned>(shapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
  uint32_t numElemsPerRep = numElems / repM;

  const uint32_t bytesPerCacheline = 128;
  uint32_t bytesPerElem = dstElemTy.getIntOrFloatBitWidth() / 8;
  uint32_t numBox{1};
  for (int i = 0; i < rank; ++i) {
    auto dim = getDimOfOrder(dstOrder, i);
    auto tNumElems = shapePerCTA[dim];
    // Cap the first box dimension at bytesPerCacheline bytes and split the
    // remainder into `numBox` boxes.
    if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
      tNumElems = bytesPerCacheline / bytesPerElem;
      numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
    }
    // Slice the second box dimension so that each (repetition, warp-row)
    // pair issues its own store.
    if (i == 1) {
      tNumElems = tNumElems / repM / warpsPerCTA[0];
    }
    boxDims.emplace_back(tNumElems);
  }
  std::vector<uint32_t> elementStrides(rank, 1);
  tmaInfo.boxDims = boxDims;
  tmaInfo.elementStrides = elementStrides;

  CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
  assert(((dstElemTy.getIntOrFloatBitWidth() == 16 &&
           sharedLayout.getVec() == 8) ||
          (dstElemTy.getIntOrFloatBitWidth() == 32 &&
           sharedLayout.getVec() == 4)) &&
         "Unexpected shared layout for StoreAsyncOp");
  // Map the (perPhase, maxPhase) pair of the shared layout to a swizzle mode.
  if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
    swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
  else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
    swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
  else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
    swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
  else
    llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");
  tmaInfo.swizzle = swizzle;
  tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
  tmaInfo.l2Promotion =
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
  tmaInfo.oobFill =
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;

  tmaMetadata->emplace_back(tmaInfo);

  Value llDst = adaptor.getDst();
  auto srcShape = srcTy.getShape();
  auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3);
  Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
  smemBase = bitcast(smemBase, dstElemPtrTy);

  Value tmaDesc =
      llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
  auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
      typeConverter->convertType(rewriter.getI8Type()), 3);

  auto threadId = getThreadId(rewriter, loc);
  // Only threads whose id is a multiple of 32 issue the TMA store.
  Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0));

  auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
                                                      dst.getType());
  // Element count of one box, scaled back up by the slicing applied to the
  // second box dimension above.
  uint32_t boxStride =
      std::accumulate(boxDims.begin(), boxDims.end(), uint32_t{1},
                      std::multiplies<uint32_t>());
  boxStride = boxStride * repM * warpsPerCTA[0];

  Value clusterCTAId = getClusterCTAId(rewriter, loc);
  SmallVector<Value> multiDimClusterCTAId =
      delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);

  // rowStride in bytes
  uint32_t rowStrideInBytes = shapePerCTA[dstOrder[0]] * bytesPerElem;
  uint32_t swizzlingByteWidth =
      std::min<uint32_t>(rowStrideInBytes, bytesPerCacheline);

  unsigned numElemsPerSwizzlingRow = swizzlingByteWidth / bytesPerElem;
  unsigned leadingDimOffset =
      numElemsPerSwizzlingRow * shapePerCTA[dstOrder[1]];

  uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];

  Value warpId = udiv(threadId, i32_val(32));
  Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
                       i32_val(srcShape[0] / instrShape[0]));
  auto srcOrder = triton::gpu::getOrder(srcLayout);
  unsigned inVec =
      srcOrder == sharedLayout.getOrder()
          ? triton::gpu::getContigPerThread(srcLayout)[srcOrder[0]]
          : 1;
  unsigned outVec = sharedLayout.getVec();
  unsigned minVec = std::min(outVec, inVec);
  // The packing loop below builds four 2-element words per store.
  assert(minVec == 2);

  auto wordTy = vec_ty(dstElemTy, minVec);

  auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                     rewriter, srcTy);
  for (uint32_t b = 0; b < numBox; ++b) {
    for (uint32_t rep = 0; rep < repM; ++rep) {
      Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
                            i32_val(rep * rowsPerRep));
      uint32_t elemIdxOffset = rep * numElemsPerRep;

      // Each iteration consumes 8 source elements: four words of two
      // elements, written with a single StoreMatrixOp.
      for (unsigned idx = 0; idx < numElemsPerRep / numBox; idx += 8) {
        uint32_t elemIdx = elemIdxOffset + b * numElemsPerRep / numBox + idx;

        Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
            loc, i32_ty, threadId, rowOfWarp,
            i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset,
            numElemsPerSwizzlingRow, true);

        Value addr = gep(dstElemPtrTy, smemBase, offset);
        Value words[4];
        for (unsigned i = 0; i < 8; ++i) {
          if (i % minVec == 0)
            words[i / 2] = undef(wordTy);
          words[i / 2] = insert_element(
              wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec));
        }

        rewriter.create<triton::nvgpu::StoreMatrixOp>(
            loc, bitcast(addr, ptrI8SharedTy),
            ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty),
                       bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)});
      }
      rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);

      SmallVector<Value> coord;
      // raw coord
      for (int i = 0; i < rank; ++i) {
        auto dim = getDimOfOrder(dstOrder, i);
        coord.push_back(llCoord[dim]);
      }
      // coord with box and cta offset
      for (int i = 0; i < rank; ++i) {
        auto dim = getDimOfOrder(dstOrder, i);
        if (i == 0) {
          coord[i] = add(coord[i], i32_val(b * boxDims[i]));
          auto CTAOffset =
              mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
          coord[i] = add(coord[i], CTAOffset);
        } else {
          Value blockOffset = i32_val(rep * instrShape[0] * warpsPerCTA[0]);
          Value warpOffset = mul(warpId0, i32_val(instrShape[0]));
          coord[i] = add(add(coord[i], add(blockOffset, warpOffset)),
                         mul(multiDimClusterCTAId[dim],
                             i32_val(boxDims[i] * repM * warpsPerCTA[0])));
        }
      }
      // Shared-memory offset (in elements) of the slice this warp stores for
      // the current box and repetition.
      Value srcOffset =
          add(i32_val(b * boxStride + rep * instrShape[0] * warpsPerCTA[0] *
                                          instrShape[1] * warpsPerCTA[1] /
                                          numBox),
              mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow)));
      Value srcPtrBase = gep(dstElemPtrTy, smemBase, srcOffset);
      auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
      rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr,
                                                      pred, coord);
    }
  }
  rewriter.eraseOp(op);
  return success();
}
|
||||
|
||||
private:
|
||||
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
|
||||
if (ty.isF16()) {
|
||||
@@ -1136,6 +1418,9 @@ struct InsertSliceAsyncV2OpConversion
|
||||
auto rank = resultTy.getRank() - 1;
|
||||
|
||||
// TODO: support any valid rank in (3, 4, 5)
|
||||
// The store async op only supports tensors with rank <= 5.
|
||||
// Reference:
|
||||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
|
||||
assert(rank > 0 && rank <= 5);
|
||||
SmallVector<unsigned> shape;
|
||||
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();
|
||||
|
||||
Reference in New Issue
Block a user