[Backend] Refactor sharedToDotOperandMFMA lowering (#439)

* Remove unnecessary xor computations for k-major swizzled tensors

* Support mfma16 and mfma4 in the fast path

* Choose warpsPerCTA according to nonKDim

* Set maxPhase=4 for mfma4

* Fix tests

For now, we do not disable swizzling for k-major tensors

* Remove fastPathComputeOffsetsTy1

* Enable k-major + disabled swizzling in the normal path
This commit is contained in:
Lixun Zhang
2024-01-12 12:50:18 -06:00
committed by GitHub
parent a7bb38ea79
commit 2e217c5a5c
5 changed files with 86 additions and 116 deletions

View File

@@ -109,13 +109,17 @@ swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
* elemsPerInstr[1].
*
* Total offset of element is a sum of following values:
* 1. Offset of wave block in tensor
* 2. Offset of wave inside one wave block
* 1. Offset of wave-block in tensor
* 2. Offset of wave inside one wave-block
* 3. Offset of tile in one wave
* 4. Offset of one lane data in a tile
* 5. Offset of particular element of tensor processed by one lane
*
* This function computes these offsets for each axis independently.
* Note that this function returns the offsets of elements in the first
* wave-block. The offsets of elements in later wave-blocks can be computed
* by adding a constant stride to the xor-ed offsets of elements in the
* first wave-block.
*
* @param rewriter
* @param loc
@@ -140,46 +144,36 @@ computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
auto numM = reps[0];
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numM * numK *
loadsPerThread);
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);
Value _0 = i32_val(0);
Value _32 = i32_val(32);
Value nonKDim = i32_val(iNonKDim);
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
for (int block = 0; block < numM; ++block) {
Value blockVOffset = i32_val(block * elemsPerInstr[0] * warpsPerGroup);
Value blockHOffset = _0;
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
Value waveHOffset = _0;
for (int tile = 0; tile < numK; ++tile) {
Value tileVOffset = _0;
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
for (int tile = 0; tile < numK; ++tile) {
Value tileVOffset = _0;
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
Value laneVOffset = urem(laneId, nonKDim);
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
Value laneVOffset = urem(laneId, nonKDim);
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
Value sliceVOffset = add(
add(add(add(blockVOffset, waveVOffset), tileVOffset), laneVOffset),
elemVOffset);
Value sliceHOffset = add(
add(add(add(blockHOffset, waveHOffset), tileHOffset), laneHOffset),
elemHOffset);
Value sliceVOffset =
add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset);
Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset);
Value row = add(sliceVOffset, smemOffsets[0]);
Value col = add(sliceHOffset, smemOffsets[1]);
Value row = add(sliceVOffset, smemOffsets[0]);
Value col = add(sliceHOffset, smemOffsets[1]);
mapping[numK * loadsPerThread * block + loadsPerThread * tile +
loadId] = {row, col};
}
mapping[loadsPerThread * tile + loadId] = {row, col};
}
}
return mapping;
@@ -321,8 +315,6 @@ std::optional<int> findConstValue(Value val) {
bool fastPathAvailable(const SharedMemoryObject &smemObj,
const SharedEncodingAttr &srcEncoding,
const MfmaEncodingAttr &dstEncoding) {
if (dstEncoding.getNonKDim() != 32)
return false;
if (srcEncoding.getMaxPhase() > 1)
return false;
auto stride0 = findConstValue(smemObj.strides[0]);
@@ -338,52 +330,6 @@ bool fastPathAvailable(const SharedMemoryObject &smemObj,
return true;
}
// Computes offsets for operand A or transposed operand B
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
// @param waveM wave id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerGroup number of warps in one block
// @param numOfElems number of elements accessed by thread per repetition
// @param reps number of instruction repetitions needed to fully cover dot operand
// @param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsetsTy1(ConversionPatternRewriter &rewriter, Location loc,
                          const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
                          Value laneId, int warpsPerGroup, int numOfElems,
                          ArrayRef<int64_t> reps, Value cSwizzleOffset) {
  // One vectorized load covers all elements of a repetition; loadsPerThread
  // is kept as a knob in case a different loadVecSize is chosen later.
  const int loadVecSize = numOfElems;
  const int loadsPerThread = 1;
  auto numM = reps[0];
  auto numK = reps[1];
  SmallVector<Value> offsets(numM * numK * loadsPerThread);
  // Presumably the pitch (in elements) of one row of the operand in LDS:
  // numK MFMA tiles laid out side by side -- TODO confirm against callers.
  int lineSize = elemsPerInstr[1] * numK;
  int blockSize = elemsPerInstr[0] * warpsPerGroup * lineSize;
  Value _0 = i32_val(0);
  Value _32 = i32_val(32);
  Value waveOffset = mul(waveId, i32_val(elemsPerInstr[0] * lineSize));
  // Lanes [32..63] access a second group of numOfElems columns.
  Value colOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
  for (int block = 0; block < numM; ++block) {
    Value blockOffset = i32_val(block * blockSize);
    for (int tile = 0; tile < numK; ++tile) {
      Value tileOffset = i32_val(tile * elemsPerInstr[1]);
      for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
        // Row within the wave-block is laneId % 32; scale by the row pitch
        // and advance by the vectorized-load index.
        Value rowOffset = add(mul(urem(laneId, _32), i32_val(lineSize)),
                              i32_val(loadId * loadVecSize));
        Value elemOffset = add(rowOffset, colOffset);
        // Total linear offset: wave + block + tile + per-lane element.
        Value offset =
            add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
        offsets[numK * loadsPerThread * block + loadsPerThread * tile +
                loadId] = offset;
      }
    }
  }
  return offsets;
}
// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
@@ -395,19 +341,18 @@ fastPathComputeOffsetsTy1(ConversionPatternRewriter &rewriter, Location loc,
// @param reps number of instruction repetitions needed to fully cover dot operand
// @param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
auto numK = reps[0];
auto numN = reps[1];
SmallVector<Value> offsets(numK * numN * numOfElems);
int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
Value _0 = i32_val(0);
Value _32 = i32_val(32);
Value _nonKDim = i32_val(elemsPerInstr[1]);
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
Value colOffset = urem(laneId, _32);
Value colOffset = urem(laneId, _nonKDim);
for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
@@ -415,7 +360,7 @@ fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
for (int elem = 0; elem < numOfElems; ++elem) {
Value halfOffset =
select(icmp_uge(laneId, _32), i32_val(numOfElems * lineSize), _0);
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
Value elemOffset = add(rowOffset, colOffset);
Value offset =
@@ -427,12 +372,19 @@ fastPathComputeOffsetsTy2(ConversionPatternRewriter &rewriter, Location loc,
return offsets;
}
bool isTransposed(::llvm::ArrayRef<unsigned> order) {
/// Returns true iff the given 2-D `order` marks dimension 0 as the
/// fastest-varying dimension, i.e. the tensor is column-major.
bool isColMajor(::llvm::ArrayRef<unsigned> order) {
  // Sanity-check: `order` must be a 2-element permutation of {0, 1}.
  assert(order.size() == 2 && (order[0] & ~1ul) == 0 &&
         order[0] + order[1] == 1);
  const bool dim0IsFastest = (order[0] == 0);
  return dim0IsFastest;
}
/// Returns true iff the dot operand's fastest-varying dimension is K.
/// For operand A (opIdx == 0), K is dim 1, so k-major means order[0] == 1;
/// for operand B (opIdx == 1), K is dim 0, so k-major means order[0] == 0.
/// Both cases collapse to order[0] + opIdx == 1.
bool isKMajor(::llvm::ArrayRef<unsigned> order, int opIdx) {
  return order[0] + opIdx == 1;
}
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor, DotOperandEncodingAttr encoding,
const SharedMemoryObject &smemObj,
@@ -480,35 +432,38 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
SmallVector<Value> loadedValues;
SmallVector<Value> offsets;
Value smemBase;
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
bool isFastPath = fastPathAvailable(smemObj, sharedLayout, mfmaLayout);
if (!isKMajor(order, opIdx) && isFastPath) {
// fast path handles tensors that are not k-major, in which case swizzling
// is disabled and offsets computation can be simplified
// TODO (zhanglx): later when we enable vector access to LDS for non k-major
// tensors, we'll refactor the scope of fast and normal path
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
if (opIdx == 0) {
if (isTransposed(order)) { // HERE
if (isColMajor(order)) {
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy2(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, reps, cSwizzleOffset);
} else {
offsets = fastPathComputeOffsetsTy1(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
llvm_unreachable(
"row major operand A should be handled in the normal path");
}
} else {
if (isTransposed(order)) {
SmallVector<int64_t> elemsPerInstr{mfmaInstrNonK, mfmaInstrK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy1(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
if (isColMajor(order)) {
llvm_unreachable(
"col major operand B should be handled in the normal path");
} else {
offsets = fastPathComputeOffsetsTy2(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, cSwizzleOffset);
}
}
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
} else { // normal path
// Normal path handles tensors that are k-major, in which case swizzling
// is enabled and it requires a 2-step method to compute the offsets.
if (opIdx == 0) {
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
@@ -525,18 +480,27 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Type resElemTy = typeConverter->convertType(elemTy);
Type smemPtrTy = getShemPtrTy(elemTy);
int loadsPerThread = offsets.size() / (numRepNonK * numRepK);
int loadsPerThread = offsets.size() / numRepK / (isFastPath ? numRepNonK : 1);
int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);
for (int nonK = 0; nonK < numRepNonK; ++nonK) {
Value blockVOffset = i32_val(nonK * mfmaInstrNonK * warpsPerGroupNonK);
Value offAdjust = mul(blockVOffset, i32_val(shape[order[0]]));
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
Value loadOffset;
if (isFastPath)
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
else
// In the normal path, we only computed the offsets of elements
// in the first wave-block. Therefore, we update the offsets
// of elements in later wave-blocks by adding a constant stride
loadOffset = add(offAdjust, offsets[k * loadsPerThread + loadId]);
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value loadedValue = load(loadAddress);

View File

@@ -94,8 +94,9 @@ warpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
}
SmallVector<unsigned, 2>
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
return warpsPerTile(dotOp, shape, numWarps, {32, 32});
warpsPerTileMFMA(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
SmallVector<int64_t, 2> shapePerWarp) {
return warpsPerTile(dotOp, shape, numWarps, shapePerWarp);
}
SmallVector<unsigned, 2>
@@ -263,7 +264,8 @@ public:
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp);
auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps);
auto warpsPerTile =
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,