[BACKEND] Optimize performance for f16 epilogue with TMA store (#2135)

1. Optimize the conversion and packing for 2xf32 -> 2xf16 (see the illustrative sketch after the file summary below).
2. Split the TMA store block into multiple slices of size 64x64.
3. Distribute the TMA stores across all the warps.
4. Fix some naming issues.
ivanyinwz
2023-08-22 03:44:11 +08:00
committed by GitHub
parent ea8416164f
commit ec801ce18e
6 changed files with 353 additions and 25 deletions
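
A rough, illustrative sketch of the packed 2xf32 -> 2xf16 conversion from item 1, written as CUDA for clarity. The helper name pack_f32x2_to_f16x2 is hypothetical, and the actual change is emitted at the LLVM/PTX level in one of the other modified files.

#include <cuda_fp16.h>

// Illustrative only: pack two f32 accumulator values into a single f16x2
// register. __float22half2_rn can lower to one packed cvt.rn.f16x2.f32
// instruction instead of two scalar converts followed by an insert.
__device__ __forceinline__ __half2 pack_f32x2_to_f16x2(float lo, float hi) {
  return __float22half2_rn(make_float2(lo, hi));
}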


@@ -3,6 +3,10 @@
#include "ConvertLayoutOpToLLVM.h"
#include "LoadStoreOpToLLVM.h"
#include "Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include <numeric>
@@ -12,6 +16,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::delinearize;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::linearize;
using ::mlir::triton::gpu::getCTALayout;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::getTotalElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;
@@ -404,6 +409,18 @@ struct StoreAsyncOpConversion
LogicalResult
matchAndRewrite(triton::nvidia_gpu::StoreAsyncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto srcEncoding = srcTy.getEncoding();
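// Sources with an MMA layout take the sliced path below, which splits the
// epilogue tile into per-warp slices so every warp issues its own TMA store;
// other layouts keep the original single-store lowering.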
if (srcEncoding.isa<MmaEncodingAttr>()) {
return lowerStoreAsyncWithSlice(op, adaptor, rewriter);
} else {
return lowerStoreAsync(op, adaptor, rewriter);
}
}
LogicalResult lowerStoreAsync(triton::nvidia_gpu::StoreAsyncOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
@@ -413,6 +430,9 @@ struct StoreAsyncOpConversion
auto elemTy = srcTy.getElementType();
auto rank = srcTy.getRank();
// The store async op only supports tensors with rank <= 5.
// Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
assert(rank > 0 && rank <= 5);
auto moduleOp = op->getParentOfType<ModuleOp>();
@@ -475,14 +495,14 @@ struct StoreAsyncOpConversion
.cast<RankedTensorType>()
.getShape();
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
// The innermost box dimension is capped at one 128-byte cacheline.
const uint32_t bytesPerCacheline = 128;
uint32_t bytesPerElem = elemTy.getIntOrFloatBitWidth() / 8;
uint32_t numBox{1};
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
auto tNumElems = shapePerCTA[dim];
if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
tNumElems = bytesPerCacheline / bytesPerElem;
numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
}
boxDims.emplace_back(tNumElems);
@@ -574,6 +594,268 @@ struct StoreAsyncOpConversion
return success();
}
LogicalResult
lowerStoreAsyncWithSlice(triton::nvidia_gpu::StoreAsyncOp op,
OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
MLIRContext *ctx = rewriter.getContext();
auto dst = op.getDst();
auto src = op.getSrc();
auto srcTy = src.getType().cast<RankedTensorType>();
auto makeTensorPtr = tensorPtrMap->lookup(op.getOperation());
auto dstTensorTy = makeTensorPtr.getResult()
.getType()
.cast<triton::PointerType>()
.getPointeeType()
.cast<RankedTensorType>();
auto tensorShape = dstTensorTy.getShape();
auto dstOrder = makeTensorPtr.getOrder();
auto dstElemTy = dstTensorTy.getElementType();
auto rank = srcTy.getRank();
// The store async op only supports tensors with rank <= 5.
// Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
assert(rank > 0 && rank <= 5);
auto moduleOp = op->getParentOfType<ModuleOp>();
assert(moduleOp && "Parent ModuleOp not found for StoreAsyncOp");
auto llFuncOp = op->getParentOfType<LLVM::LLVMFuncOp>();
assert(llFuncOp && "LLVMFuncOp not found for StoreAsyncOp");
int numTMADescs = getNumTMADescs(llFuncOp);
assert(numTMADescs > 0);
auto ctaLayout = getCTALayout(dstTensorTy.getEncoding());
// The order of smem should be consistent with gmem.
SmallVector<unsigned> sharedOrder;
for (auto o : makeTensorPtr.getOrder()) {
sharedOrder.emplace_back(o);
}
auto sharedLayout = SharedEncodingAttr::get(ctx, tensorShape, sharedOrder,
ctaLayout, dstElemTy);
mlir::triton::gpu::TMAInfo tmaInfo;
tmaInfo.tensorDataType = getCUtensorMapDataType(dstElemTy);
tmaInfo.tensorRank = rank;
assert(tmaMetadata);
unsigned TMADescIdx = tmaMetadata->size();
unsigned numFuncArgs = llFuncOp.getBody().front().getNumArguments();
unsigned globalAddressArgIdx = getArgIdx(makeTensorPtr.getBase());
tmaInfo.globalAddressArgIdx = globalAddressArgIdx;
tmaInfo.TMADescArgIdx = numFuncArgs - numTMADescs + TMADescIdx;
auto getDimOfOrder = [](ArrayRef<int32_t> order, int32_t i) {
auto it = std::find(order.begin(), order.end(), i);
assert(it != order.end());
return std::distance(order.begin(), it);
};
std::vector<int32_t> globalDimsArgIdx;
std::vector<int32_t> globalStridesArgIdx;
// constant values are mapped to (-1 - value)
for (int i = 0; i < rank; ++i) {
int32_t argIdx = -1;
auto dim = getDimOfOrder(dstOrder, i);
argIdx = getArgIdx(makeTensorPtr.getShape()[dim]);
globalDimsArgIdx.emplace_back(argIdx);
// handle constant stride
argIdx = getArgIdx(makeTensorPtr.getStrides()[dim]);
globalStridesArgIdx.emplace_back(argIdx);
}
tmaInfo.globalDimsArgIdx = globalDimsArgIdx;
tmaInfo.globalStridesArgIdx = globalStridesArgIdx;
std::vector<uint32_t> boxDims;
auto CTAsPerCGA = sharedLayout.getCTALayout().getCTAsPerCGA();
auto CTAOrder = sharedLayout.getCTALayout().getCTAOrder();
auto CTASplitNum = sharedLayout.getCTALayout().getCTASplitNum();
auto shapePerCTA = getShapePerCTA(CTASplitNum, tensorShape);
auto srcLayout = srcTy.getEncoding();
auto mmaLayout = srcLayout.dyn_cast<MmaEncodingAttr>();
unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy);
auto instrShape = mmaLayout.getInstrShape();
auto warpsPerCTA = mmaLayout.getWarpsPerCTA();
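// repM is how many times the MMA tile (instrShape[0] rows per warp across
// warpsPerCTA[0] warps) repeats along dim 0 of the CTA tile; the store is
// sliced per repetition and per warp.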
uint32_t repM =
ceil<unsigned>(shapePerCTA[0], instrShape[0] * warpsPerCTA[0]);
uint32_t numElemsPerRep = numElems / repM;
const uint32_t bytesPerCacheline = 128;
uint32_t bytesPerElem = dstElemTy.getIntOrFloatBitWidth() / 8;
uint32_t numBox{1};
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
auto tNumElems = shapePerCTA[dim];
if (i == 0 && tNumElems * bytesPerElem > bytesPerCacheline) {
tNumElems = bytesPerCacheline / bytesPerElem;
numBox = (shapePerCTA[dim] + tNumElems - 1) / tNumElems;
}
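// Shrink the box along the second dimension so each box covers only one
// warp's rows of a single repetition.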
if (i == 1) {
tNumElems = tNumElems / repM / warpsPerCTA[0];
}
boxDims.emplace_back(tNumElems);
}
std::vector<uint32_t> elementStrides(rank, 1);
tmaInfo.boxDims = boxDims;
tmaInfo.elementStrides = elementStrides;
CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE;
assert(((dstElemTy.getIntOrFloatBitWidth() == 16 &&
sharedLayout.getVec() == 8) or
(dstElemTy.getIntOrFloatBitWidth() == 32 &&
sharedLayout.getVec() == 4)) &&
"Unexpected shared layout for StoreAsyncOp");
if (sharedLayout.getPerPhase() == 4 && sharedLayout.getMaxPhase() == 2)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_32B;
else if (sharedLayout.getPerPhase() == 2 && sharedLayout.getMaxPhase() == 4)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_64B;
else if (sharedLayout.getPerPhase() == 1 && sharedLayout.getMaxPhase() == 8)
swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B;
else
llvm::report_fatal_error("Unsupported shared layout for StoreAsyncOp");
tmaInfo.swizzle = swizzle;
tmaInfo.interleave = CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE;
tmaInfo.l2Promotion =
CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
tmaInfo.oobFill =
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
tmaMetadata->emplace_back(tmaInfo);
Value llDst = adaptor.getDst();
Value llSrc = adaptor.getSrc();
auto srcShape = srcTy.getShape();
auto dstElemPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3);
Value smemBase = getSharedMemoryBase(loc, rewriter, op.getOperation());
smemBase = bitcast(smemBase, dstElemPtrTy);
SmallVector<Value> offsetVals;
for (auto i = 0; i < srcShape.size(); ++i) {
offsetVals.emplace_back(i32_val(0));
}
Value tmaDesc =
llFuncOp.getBody().front().getArgument(tmaInfo.TMADescArgIdx);
auto ptrI8SharedTy = LLVM::LLVMPointerType::get(
typeConverter->convertType(rewriter.getI8Type()), 3);
auto threadId = getThreadId(rewriter, loc);
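// Only lane 0 of each warp issues the TMA store; since the stores are
// distributed across all warps, every warp contributes its own slice.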
Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0));
auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter,
dst.getType());
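// boxStride = elements per box times all repetitions and warps; it is used
// as the shared-memory offset stride between consecutive boxes (b * boxStride).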
uint32_t boxStride = std::accumulate(boxDims.begin(), boxDims.end(), 1,
std::multiplies<uint32_t>());
boxStride = boxStride * repM * warpsPerCTA[0];
Value clusterCTAId = getClusterCTAId(rewriter, loc);
SmallVector<Value> multiDimClusterCTAId =
delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder);
// rowStride in bytes
uint32_t rowStrideInBytes = shapePerCTA[dstOrder[0]] * bytesPerElem;
uint32_t swizzlingByteWidth =
std::min<uint32_t>(rowStrideInBytes, bytesPerCacheline);
unsigned numElemsPerSwizzlingRow = swizzlingByteWidth / bytesPerElem;
unsigned leadingDimOffset =
numElemsPerSwizzlingRow * shapePerCTA[dstOrder[1]];
uint32_t rowsPerRep = getShapePerCTATile(mmaLayout)[0];
Value warpId = udiv(threadId, i32_val(32));
Value warpId0 = urem(urem(warpId, i32_val(warpsPerCTA[0])),
i32_val(srcShape[0] / instrShape[0]));
auto srcOrder = triton::gpu::getOrder(srcLayout);
unsigned inVec =
srcOrder == sharedLayout.getOrder()
? triton::gpu::getContigPerThread(srcLayout)[srcOrder[0]]
: 1;
unsigned outVec = sharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
assert(minVec == 2);
auto wordTy = vec_ty(dstElemTy, minVec);
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
rewriter, srcTy);
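// For every box and repetition: stage this warp's slice into swizzled shared
// memory with stmatrix, fence the async proxy, then issue the per-warp TMA
// store to global memory.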
for (uint32_t b = 0; b < numBox; ++b) {
for (int rep = 0; rep < repM; ++rep) {
Value rowOfWarp = add(mul(warpId0, i32_val(instrShape[0])),
i32_val(rep * rowsPerRep));
uint32_t elemIdxOffset = rep * numElemsPerRep;
for (unsigned idx = 0; idx < numElemsPerRep / numBox; idx += 8) {
uint32_t elemIdx = elemIdxOffset + b * numElemsPerRep / numBox + idx;
Value offset = rewriter.create<triton::nvgpu::OffsetOfStmatrixV4Op>(
loc, i32_ty, threadId, rowOfWarp,
i32_val(b * numElemsPerRep / numBox + idx), leadingDimOffset,
numElemsPerSwizzlingRow, true);
Value addr = gep(dstElemPtrTy, smemBase, offset);
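// Pack pairs of elements into two-element vectors (32-bit words for f16),
// the register operand format expected by stmatrix.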
Value words[4];
for (unsigned i = 0; i < 8; ++i) {
if (i % minVec == 0)
words[i / 2] = undef(wordTy);
words[i / 2] = insert_element(
wordTy, words[i / 2], inVals[elemIdx + i], i32_val(i % minVec));
}
rewriter.create<triton::nvgpu::StoreMatrixOp>(
loc, bitcast(addr, ptrI8SharedTy),
ValueRange{bitcast(words[0], i32_ty), bitcast(words[1], i32_ty),
bitcast(words[2], i32_ty), bitcast(words[3], i32_ty)});
}
rewriter.create<triton::nvgpu::FenceAsyncSharedOp>(loc, 0);
SmallVector<Value> coord;
// raw coord
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
coord.push_back(llCoord[dim]);
}
// coord with box and cta offset
for (int i = 0; i < rank; ++i) {
auto dim = getDimOfOrder(dstOrder, i);
if (i == 0) {
coord[i] = add(coord[i], i32_val(b * boxDims[i]));
auto CTAOffset =
mul(multiDimClusterCTAId[dim], i32_val(numBox * boxDims[i]));
coord[i] = add(coord[i], CTAOffset);
} else {
Value blockOffset = i32_val(rep * instrShape[0] * warpsPerCTA[0]);
Value warpOffset = mul(warpId0, i32_val(instrShape[0]));
coord[i] = add(add(coord[i], add(blockOffset, warpOffset)),
mul(multiDimClusterCTAId[dim],
i32_val(boxDims[i] * repM * warpsPerCTA[0])));
}
}
Value srcOffset =
add(i32_val(b * boxStride + rep * instrShape[0] * warpsPerCTA[0] *
instrShape[1] * warpsPerCTA[1] /
numBox),
mul(warpId0, i32_val(instrShape[0] * numElemsPerSwizzlingRow)));
auto srcPtrTy = ptr_ty(getTypeConverter()->convertType(dstElemTy), 3);
Value srcPtrBase = gep(srcPtrTy, smemBase, srcOffset);
auto addr = bitcast(srcPtrBase, ptrI8SharedTy);
rewriter.create<triton::nvgpu::TMAStoreTiledOp>(loc, tmaDesc, addr,
pred, coord);
}
}
rewriter.eraseOp(op);
return success();
}
private:
CUtensorMapDataType getCUtensorMapDataType(Type ty) const {
if (ty.isF16()) {
@@ -1136,6 +1418,9 @@ struct InsertSliceAsyncV2OpConversion
auto rank = resultTy.getRank() - 1;
// TODO: support any valid rank in (3, 4, 5)
// The store async op only supports tensors with rank <= 5.
// Reference:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tensor-dimension-size-and-format
assert(rank > 0 && rank <= 5);
SmallVector<unsigned> shape;
auto axis = op->getAttrOfType<IntegerAttr>("axis").getInt();