mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix nPerWarp == 8 in MMA16816SmemLoader (#2109)
This commit is contained in:
@@ -19,7 +19,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
// Data loader for mma.16816 instruction.
|
||||
class MMA16816SmemLoader {
|
||||
public:
|
||||
MMA16816SmemLoader(int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
MMA16816SmemLoader(int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder,
|
||||
int kWidth, ArrayRef<Value> smemStrides,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
@@ -93,6 +93,8 @@ private:
|
||||
int inWarpMatOffset;
|
||||
// Offset in number of matrices to increment on non-k dim across warps
|
||||
int warpMatOffset;
|
||||
|
||||
int nPerWarp;
|
||||
};
|
||||
|
||||
SmallVector<Value>
|
||||
@@ -131,10 +133,18 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
|
||||
// address (s0,s1) annotates.
|
||||
|
||||
Value matOff[2];
|
||||
matOff[kOrder ^ 1] = add(
|
||||
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
|
||||
// When B's shape (k, n) is (16, 8) and ldmatrix.x4 is used, the shared memory
|
||||
// access will be out of bounds. In the future we should change this case to
|
||||
// ldmatrix.x2
|
||||
if (kOrder == 0 && nPerWarp == 8) {
|
||||
matOff[kOrder ^ 1] = mul(warpId, i32_val(warpMatOffset));
|
||||
} else {
|
||||
matOff[kOrder ^ 1] = add(
|
||||
mul(warpId, i32_val(warpMatOffset)), // warp offset (kOrder=1)
|
||||
mul(nkMatArr,
|
||||
i32_val(
|
||||
inWarpMatOffset))); // matrix offset inside a warp (kOrder=1)
|
||||
}
|
||||
matOff[kOrder] = kMatArr;
|
||||
|
||||
// Physical offset (before swizzling)
|
||||
@@ -390,13 +400,13 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> ptrs, Type matTy,
|
||||
}
|
||||
|
||||
MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
int warpsPerTile, ArrayRef<uint32_t> order, ArrayRef<uint32_t> warpsPerCTA,
|
||||
uint32_t kOrder, int kWidth, ArrayRef<Value> smemStrides,
|
||||
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
|
||||
ArrayRef<int> matShape, int perPhase, int maxPhase, int elemBytes,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
int nPerWarp, int warpsPerTile, ArrayRef<uint32_t> order,
|
||||
ArrayRef<uint32_t> warpsPerCTA, uint32_t kOrder, int kWidth,
|
||||
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
|
||||
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
|
||||
: order(order.begin(), order.end()),
|
||||
: nPerWarp(nPerWarp), order(order.begin(), order.end()),
|
||||
warpsPerCTA(warpsPerCTA.begin(), warpsPerCTA.end()), kOrder(kOrder),
|
||||
kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
@@ -490,6 +500,7 @@ std::function<void(int, int)> getLoadMatrixFn(
|
||||
bool isA, TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
ConversionPatternRewriter &rewriter, Location loc) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
auto shapePerCTA = getShapePerCTA(tensorTy);
|
||||
Type eltTy = tensorTy.getElementType();
|
||||
// We assume that the input operand of Dot comes from a shared layout.
|
||||
// TODO(Superjomn) Consider other layouts if needed later.
|
||||
@@ -511,13 +522,16 @@ std::function<void(int, int)> getLoadMatrixFn(
|
||||
if (kWidth != (4 / elemBytes))
|
||||
assert(vecPhase == 1 || vecPhase == 4 * kWidth);
|
||||
|
||||
int nPerWarp =
|
||||
std::max<int>(shapePerCTA[1] / mmaLayout.getWarpsPerCTA()[1], 8);
|
||||
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &rewriter, &vals](int a, int b) {
|
||||
MMA16816SmemLoader loader(
|
||||
warpsPerTile, sharedLayout.getOrder(), mmaLayout.getWarpsPerCTA(),
|
||||
kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/,
|
||||
instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter,
|
||||
typeConverter, loc);
|
||||
nPerWarp, warpsPerTile, sharedLayout.getOrder(),
|
||||
mmaLayout.getWarpsPerCTA(), kOrder, kWidth, smemObj.strides,
|
||||
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
|
||||
maxPhase, elemBytes, rewriter, typeConverter, loc);
|
||||
// Offset of a slice within the original tensor in shared memory
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
SmallVector<Value> offs =
|
||||
|
||||
Reference in New Issue
Block a user