[BACKEND] Fix nPerWarp == 8 in MMA16816SmemLoader (#2109)

commit 780266c3a2
parent 8fa11a75d3
Author: Qingyi Liu
Date:   2023-08-16 08:32:27 +08:00
Committed by: GitHub

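The bug: for dot operand B with per-CTA shape (k, n) = (16, 8), nPerWarp comes out as 8, and the ldmatrix.x4 path in computeLdmatrixMatOffs computes a second matrix offset along n that steps outside the tile. The fix threads nPerWarp through MMA16816SmemLoader and drops that term when kOrder == 0 && nPerWarp == 8. Below is a minimal standalone sketch (plain C++, not Triton code; the tile shape and warp counts are assumed example values) of how nPerWarp is derived at the call site and when the guard fires:

#include <algorithm>
#include <cstdio>

int main() {
  // Assumed example: B tile (k, n) = (16, 8) per CTA, one warp along n.
  int shapePerCTA[2] = {16, 8};
  int warpsPerCTA[2] = {2, 1};
  // Mirrors the call-site computation in getLoadMatrixFn.
  int nPerWarp = std::max<int>(shapePerCTA[1] / warpsPerCTA[1], 8);
  // kOrder == 0 corresponds to the B operand (k is dim 0 of (k, n)).
  int kOrder = 0;
  bool guardActive = (kOrder == 0 && nPerWarp == 8);
  std::printf("nPerWarp = %d, ldmatrix.x4 guard %s\n", nPerWarp,
              guardActive ? "active" : "inactive");
  return 0;
}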
@@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
 // Data loader for mma.16816 instruction.
 class MMA16816SmemLoader {
 public:
-  MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
+  MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
                      ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
                      int kWidth, ArrayRef<Value> smemStrides,
                      ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
@@ -93,6 +93,8 @@ private:
   int inWarpMatOffset;
   // Offset in number of matrices to increment on non-k dim across warps
   int warpMatOffset;
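+  // Number of elements along the non-k dim handled by each warp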
+  int nPerWarp;
 };
 
 SmallVector<Value>
@@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
   // address (s0,s1) annotates.
   Value matOff[2];
-  matOff[kOrder ^ 1] = add(
-      mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
-      mul(nkMatArr,
-          i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
+  // When B's shape (k, n) is (16, 8) and ldmatrix.x4 is used, the shared
+  // memory access will be out of bounds. In the future we should change this
+  // case to ldmatrix.x2.
+  if (kOrder == 0 && nPerWarp == 8) {
+    matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset));
+  } else {
+    matOff[kOrder ^ 1] = add(
+        mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
+        mul(nkMatArr,
+            i32_val(
+                inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
+  }
   matOff[kOrder] = kMatArr;
   // Physical offset (before swizzling)
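To see what the guard changes, here is the same offset arithmetic with the MLIR Values replaced by plain ints (all values are assumed for illustration): in the (16, 8) B case only one 8-wide matrix exists along n, so the nkMatArr term of the general path would address a matrix that is not there.

#include <cstdio>

int main() {
  // Assumed illustrative values: first warp, lane-derived matrix index 1,
  // unit offsets between matrices.
  int warpId = 0, nkMatArr = 1;
  int warpMatOffset = 1, inWarpMatOffset = 1;
  // General path: the offset spans two matrices along the non-k dim.
  int general = warpId * warpMatOffset + nkMatArr * inWarpMatOffset;
  // Guarded path (kOrder == 0 && nPerWarp == 8): only one matrix along n,
  // so the in-warp term is dropped to stay inside the tile.
  int guarded = warpId * warpMatOffset;
  std::printf("general = %d, guarded = %d\n", general, guarded);
  return 0;
}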
@@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
 }
 
 MMA16816SmemLoader::MMA16816SmemLoader(
-    int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
-    uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
-    ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
-    ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
-    ConversionPatternRewriter &rewriter,
+    int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
+    ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder, int kWidth,
+    ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
+    ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
+    int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
     TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
-    : order(order.begin(), order.end()),
+    : nPerWarp(nPerWarp), order(order.begin(), order.end()),
       warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
       kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
       instrShape(instrShape.begin(), instrShape.end()),
@@ -490,6 +500,7 @@ std::function<void(int, int)> getLoadMatrixFn(
     bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
     ConversionPatternRewriter &rewriter, Location loc) {
   auto tensorTy = tensor.getType().cast<RankedTensorType>();
+  auto shapePerCTA = getShapePerCTA(tensorTy);
   Type eltTy = tensorTy.getElementType();
   // We assume that the input operand of Dot comes from a shared layout.
   // TODO(Superjomn) Consider other layouts if needed later.
@@ -511,13 +522,16 @@ std::function<void(int, int)> getLoadMatrixFn(
   if (kWidth != (4 / elemBytes))
     assert(vecPhase == 1 || vecPhase == 4 * kWidth);
+  int nPerWarp =
+      std::max<int>(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8);
+
 
   // (a, b) is the coordinate.
   auto load = [=, &rewriter, &vals](int a, int b) {
     MMA16816SmemLoader loader(
-        warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
-        kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
-        instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
-        typeConverter, loc);
+        nPerWarp, warpsPerTile, sharedLayout.getOrder(),
+        mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides,
+        tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
+        maxPhase, elemBytes, rewriter, typeConverter, loc);
     // Offset of a slice within the original tensor in shared memory
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
     SmallVector<Value> offs =
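
For reference, the out-of-bounds condition the diff comment describes can be counted directly: ldmatrix.x4 loads four 8x8 matrices, but a (16, 8) tile only contains two. A small sketch (plain C++, shapes assumed as in the comment above):

#include <cstdio>

int main() {
  int k = 16, n = 8;                  // assumed B tile shape per CTA
  int matsInTile = (k / 8) * (n / 8); // 8x8 matrices the tile holds: 2
  int matsPerX4 = 4;                  // ldmatrix.x4 reads four 8x8 matrices
  std::printf("tile holds %d matrices, ldmatrix.x4 reads %d -> %s\n",
              matsInTile, matsPerX4,
              matsInTile < matsPerX4 ? "out of bounds" : "in bounds");
  return 0;
}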