mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Backend] Refactor sharedToDotOperandMFMA lowering (#439)
* Remove unnecessary xor computations for k-major swizzled tensors * Support mfma16 and mfma4 in the fast path * Choose warpsPerCTA according to nonKDim * Set maxPhase=4 for mfma4 * Fix tests For now, we do not disable swizzling for k-major tensors * Remove fastPathComputeOffsetsTy1 * Enable k-major + disabled swizzling in the normal path
This commit is contained in:
@@ -109,13 +109,17 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
|
||||
* elemsPerInstr[1].
|
||||
*
|
||||
* Total offset of element is a sum of following values:
|
||||
* 1. Offset of wave block in tensor
|
||||
* 2. Offset of wave inside one wave block
|
||||
* 1. Offset of wave-block in tensor
|
||||
* 2. Offset of wave inside one wave-block
|
||||
* 3. Offset of tile in one wave
|
||||
* 4. Offset of one lane data in a tile
|
||||
* 5. Offset of particular element of tensor processed by one lane
|
||||
*
|
||||
* This function computes these offsets for each axis independently
|
||||
* Note that this function returns the offsets of elements in the first
|
||||
* wave-block. The offsets of elements in later wave-blocks can be computed
|
||||
* by adding a constant stride to the xor-ed offsets of elements in the
|
||||
* first wave-block.
|
||||
*
|
||||
* @param rewriter
|
||||
* @param loc
|
||||
@@ -140,46 +144,36 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
|
||||
auto numM = reps[0];
|
||||
auto numK = reps[1];
|
||||
const int loadsPerThread = numOfElems / loadVecSize;
|
||||
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numM * numK *
|
||||
loadsPerThread);
|
||||
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);
|
||||
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value nonKDim = i32_val(iNonKDim);
|
||||
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
|
||||
|
||||
for (int block = 0; block < numM; ++block) {
|
||||
Value blockVOffset = i32_val(block * elemsPerInstr[0] * warpsPerGroup);
|
||||
Value blockHOffset = _0;
|
||||
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
|
||||
Value waveHOffset = _0;
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileVOffset = _0;
|
||||
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileVOffset = _0;
|
||||
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
|
||||
|
||||
Value laneVOffset = urem(laneId, nonKDim);
|
||||
Value laneHOffset;
|
||||
if (iNonKDim == 32)
|
||||
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
|
||||
else
|
||||
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
|
||||
Value laneVOffset = urem(laneId, nonKDim);
|
||||
Value laneHOffset;
|
||||
if (iNonKDim == 32)
|
||||
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
|
||||
else
|
||||
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
|
||||
|
||||
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
Value elemVOffset = _0;
|
||||
Value elemHOffset = i32_val(loadId * loadVecSize);
|
||||
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
Value elemVOffset = _0;
|
||||
Value elemHOffset = i32_val(loadId * loadVecSize);
|
||||
|
||||
Value sliceVOffset = add(
|
||||
add(add(add(blockVOffset, waveVOffset), tileVOffset), laneVOffset),
|
||||
elemVOffset);
|
||||
Value sliceHOffset = add(
|
||||
add(add(add(blockHOffset, waveHOffset), tileHOffset), laneHOffset),
|
||||
elemHOffset);
|
||||
Value sliceVOffset =
|
||||
add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset);
|
||||
Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset);
|
||||
|
||||
Value row = add(sliceVOffset, smemOffsets[0]);
|
||||
Value col = add(sliceHOffset, smemOffsets[1]);
|
||||
Value row = add(sliceVOffset, smemOffsets[0]);
|
||||
Value col = add(sliceHOffset, smemOffsets[1]);
|
||||
|
||||
mapping[numK * loadsPerThread * block + loadsPerThread * tile +
|
||||
loadId] = {row, col};
|
||||
}
|
||||
mapping[loadsPerThread * tile + loadId] = {row, col};
|
||||
}
|
||||
}
|
||||
return mapping;
|
||||
@@ -321,8 +315,6 @@ std::optional<int> findConstValue(Value val) {
|
||||
bool fastPathAvailable(const SharedMemoryObject &smemObj,
|
||||
const SharedEncodingAttr &srcEncoding,
|
||||
const MfmaEncodingAttr &dstEncoding) {
|
||||
if (dstEncoding.getNonKDim() != 32)
|
||||
return false;
|
||||
if (srcEncoding.getMaxPhase() > 1)
|
||||
return false;
|
||||
auto stride0 = findConstValue(smemObj.strides[0]);
|
||||
@@ -338,52 +330,6 @@ bool fastPathAvailable(const SharedMemoryObject &smemObj,
|
||||
return true;
|
||||
}
|
||||
|
||||
// Computes offsets for operand A or transposed operand B
|
||||
// @param rewriter
|
||||
// @param loc
|
||||
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
|
||||
// @param waveM wave id for the "non K" axis
|
||||
// @param laneId lane id in warp [0..63]
|
||||
// @param warpsPerGroup number of warps in one block
|
||||
// @param numOfElems number of elements accessed by thread per repetition
|
||||
// @param reps number of instruction repetitions needed to fully cover the dot operand
|
||||
// @param cSwizzleOffset
|
||||
llvm::SmallVector<Value>
|
||||
fastPathComputeOffsetsTy1(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
|
||||
const int loadVecSize = numOfElems;
|
||||
const int loadsPerThread = 1; // 1 is just in case if we decide to use different loadVecSize
|
||||
auto numM = reps[0];
|
||||
auto numK = reps[1];
|
||||
SmallVector<Value> offsets(numM * numK * loadsPerThread);
|
||||
int lineSize = elemsPerInstr[1] * numK;
|
||||
int blockSize = elemsPerInstr[0] * warpsPerGroup * lineSize;
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value waveHalf = udiv(laneId, _32);
|
||||
|
||||
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[0] * lineSize));
|
||||
Value colOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
|
||||
|
||||
for (int block = 0; block < numM; ++block) {
|
||||
Value blockOffset = i32_val(block * blockSize);
|
||||
for (int tile = 0; tile < numK; ++tile) {
|
||||
Value tileOffset = i32_val(tile * elemsPerInstr[1]);
|
||||
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
Value rowOffset =
|
||||
add(mul(urem(laneId, _32), i32_val(lineSize)), i32_val(loadId * loadVecSize));
|
||||
Value elemOffset = add(rowOffset, colOffset);
|
||||
Value offset =
|
||||
add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
|
||||
offsets[numK * loadsPerThread * block + loadsPerThread * tile + loadId] = offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
// Computes offsets for operand B or transposed operand A
|
||||
// @param rewriter
|
||||
// @param loc
|
||||
@@ -395,19 +341,18 @@ fastPathComputeOffsetsTy1(ConversionPatternRewriter &rewriter, Location loc,
|
||||
// @param reps number of instruction repetitions needed to fully cover the dot operand
|
||||
// @param cSwizzleOffset
|
||||
llvm::SmallVector<Value>
|
||||
fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
|
||||
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
|
||||
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
|
||||
Value laneId, int warpsPerGroup, int numOfElems,
|
||||
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
|
||||
auto numK = reps[0];
|
||||
auto numN = reps[1];
|
||||
SmallVector<Value> offsets(numK * numN * numOfElems);
|
||||
|
||||
int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
|
||||
Value _0 = i32_val(0);
|
||||
Value _32 = i32_val(32);
|
||||
Value _nonKDim = i32_val(elemsPerInstr[1]);
|
||||
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
|
||||
Value colOffset = urem(laneId, _32);
|
||||
Value colOffset = urem(laneId, _nonKDim);
|
||||
|
||||
for (int block = 0; block < numN; ++block) {
|
||||
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
|
||||
@@ -415,7 +360,7 @@ fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
|
||||
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
|
||||
for (int elem = 0; elem < numOfElems; ++elem) {
|
||||
Value halfOffset =
|
||||
select(icmp_uge(laneId, _32), i32_val(numOfElems * lineSize), _0);
|
||||
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
|
||||
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
|
||||
Value elemOffset = add(rowOffset, colOffset);
|
||||
Value offset =
|
||||
@@ -427,12 +372,19 @@ fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
|
||||
return offsets;
|
||||
}
|
||||
|
||||
bool isTransposed(::llvm::ArrayRef<unsigned> order) {
|
||||
bool isColMajor(::llvm::ArrayRef<unsigned> order) {
|
||||
assert(order.size() == 2 && (order[0] & ~1ul) == 0 &&
|
||||
order[0] + order[1] == 1);
|
||||
return order[0] == 0;
|
||||
}
|
||||
|
||||
// Reports whether a dot operand is stored k-major.
// @param order memory order of the operand; only order[0] is read here
// @param opIdx operand index as passed by convertLayout
// @return true when order[0] + opIdx == 1, false otherwise. For inputs where
//         both values are 0 or 1 this means (order[0] == 1, opIdx == 0) or
//         (order[0] == 0, opIdx == 1).
// NOTE(review): unlike isColMajor, no assertion validates `order` here;
// presumably callers only pass a 2-element permutation - confirm.
bool isKMajor(::llvm::ArrayRef<unsigned> order, int opIdx) {
|
||||
if (order[0] + opIdx == 1)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
}
|
||||
|
||||
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
Location loc, Value tensor, DotOperandEncodingAttr encoding,
|
||||
const SharedMemoryObject &smemObj,
|
||||
@@ -480,35 +432,38 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
SmallVector<Value> loadedValues;
|
||||
SmallVector<Value> offsets;
|
||||
Value smemBase;
|
||||
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
|
||||
bool isFastPath = fastPathAvailable(smemObj, sharedLayout, mfmaLayout);
|
||||
if (!isKMajor(order, opIdx) && isFastPath) {
|
||||
// fast path handles tensors that are not k-major, in which case swizzling
|
||||
// is disabled and offsets computation can be simplified
|
||||
// TODO (zhanglx): later when we enable vector access to LDS for non k-major
|
||||
// tensors, we'll refactor the scope of fast and normal path
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
if (opIdx == 0) {
|
||||
if (isTransposed(order)) { // HERE
|
||||
if (isColMajor(order)) {
|
||||
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
|
||||
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
||||
offsets = fastPathComputeOffsetsTy2(
|
||||
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
||||
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
|
||||
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
|
||||
spatialWaveId, lane, warpsPerGroupNonK,
|
||||
numOfElems, reps, cSwizzleOffset);
|
||||
} else {
|
||||
offsets = fastPathComputeOffsetsTy1(
|
||||
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
||||
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
|
||||
llvm_unreachable(
|
||||
"row major operand A should be handled in the normal path");
|
||||
}
|
||||
} else {
|
||||
if (isTransposed(order)) {
|
||||
SmallVector<int64_t> elemsPerInstr{mfmaInstrNonK, mfmaInstrK};
|
||||
SmallVector<int64_t> reps{numReps[1], numReps[0]};
|
||||
offsets = fastPathComputeOffsetsTy1(
|
||||
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
||||
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
|
||||
if (isColMajor(order)) {
|
||||
llvm_unreachable(
|
||||
"col major operand B should be handled in the normal path");
|
||||
} else {
|
||||
offsets = fastPathComputeOffsetsTy2(
|
||||
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
|
||||
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
|
||||
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
|
||||
spatialWaveId, lane, warpsPerGroupNonK,
|
||||
numOfElems, numReps, cSwizzleOffset);
|
||||
}
|
||||
}
|
||||
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
|
||||
} else { // normal path
|
||||
// Normal path handles tensors that are k-major, in which case swizzling
|
||||
// is enabled and it requires a 2-step method to compute the offsets.
|
||||
if (opIdx == 0) {
|
||||
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
|
||||
lane, warpsPerGroupNonK, numOfElems,
|
||||
@@ -525,18 +480,27 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
Type resElemTy = typeConverter->convertType(elemTy);
|
||||
Type smemPtrTy = getShemPtrTy(elemTy);
|
||||
|
||||
int loadsPerThread = offsets.size() / (numRepNonK * numRepK);
|
||||
int loadsPerThread = offsets.size() / numRepK / (isFastPath ? numRepNonK : 1);
|
||||
int elemsPerLoad = numOfElems / loadsPerThread;
|
||||
assert(numOfElems % loadsPerThread == 0);
|
||||
|
||||
for (int nonK = 0; nonK < numRepNonK; ++nonK) {
|
||||
Value blockVOffset = i32_val(nonK * mfmaInstrNonK * warpsPerGroupNonK);
|
||||
Value offAdjust = mul(blockVOffset, i32_val(shape[order[0]]));
|
||||
for (int k = 0; k < numRepK; ++k) {
|
||||
auto vecTy = vec_ty(resElemTy, numOfElems);
|
||||
Value valVec = undef(vecTy);
|
||||
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
|
||||
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
|
||||
Value loadOffset = offsets[nonK * loadsPerThread * numRepK +
|
||||
k * loadsPerThread + loadId];
|
||||
Value loadOffset;
|
||||
if (isFastPath)
|
||||
loadOffset = offsets[nonK * loadsPerThread * numRepK +
|
||||
k * loadsPerThread + loadId];
|
||||
else
|
||||
// In the normal path, we only computed the offsets of elements
|
||||
// in the first wave-block. Therefore, we update the offsets
|
||||
// of elements in later wave-blocks by adding a constant stride
|
||||
loadOffset = add(offAdjust, offsets[k * loadsPerThread + loadId]);
|
||||
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
|
||||
getShemPtrTy(loadVecTy));
|
||||
Value loadedValue = load(loadAddress);
|
||||
|
||||
@@ -94,8 +94,9 @@ warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
|
||||
return warpsPerTile(dotOp, shape, numWarps, {32, 32});
|
||||
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
|
||||
SmallVector<int64_t, 2> shapePerWarp) {
|
||||
return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
|
||||
}
|
||||
|
||||
SmallVector<unsigned, 2>
|
||||
@@ -263,7 +264,8 @@ public:
|
||||
|
||||
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);
|
||||
|
||||
auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);
|
||||
auto warpsPerTile =
|
||||
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
|
||||
|
||||
Reference in New Issue
Block a user