ROCm/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp
Lixun Zhang 02a2f24dd5 [Backend] Refactor mfma selection (#441)
* Select mfma dimensions and instruction from static table

* Extend mfmaLayout to include version and instrShape

* Simplify generateMFMAOp by searching the mfma instruction in the table

* Fix getNonKDim() and non_k_dim

* Break instrShape into MDim and NDim
2024-01-16 21:05:35 -06:00


/*
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#ifdef USE_ROCM
#include "../ConvertLayoutOpToLLVM.h"
#include "../Utility.h"
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getOrder;
using ::mlir::triton::gpu::getShapePerCTA;
using ::mlir::triton::gpu::SharedEncodingAttr;
namespace {
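// Returns an LDS (address space 3) pointer type for the given element type.
// bf16 elements are accessed through i16 pointers here, presumably because
// this lowering path treats bf16 data in LDS as raw 16-bit values.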
Type getShemPtrTy(Type elemTy) {
if (elemTy.isBF16()) {
auto ctx = elemTy.getContext();
return ptr_ty(type::i16Ty(ctx), 3);
}
return ptr_ty(elemTy, 3);
}
// Get waveId inside block of waves.
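// For example (hypothetical values): with wpt = {4, 2}, nonKIdx = 1,
// elemPerInstrNonK = 32, tensorSizeNonK = 64 and waveId = 5, the wave id is
// first reduced along dim 0 (5 / 4 = 1), then wrapped by warpsPerCTA along
// dim 1 (1 % 2 = 1) and by the number of mfma tiles that fit along the non-K
// axis (1 % (64 / 32) = 1), giving wave id 1 within its block.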
Value getWaveIdInBlock(ConversionPatternRewriter &rewriter, Location loc,
Value waveId, const ArrayRef<unsigned int> &wpt,
int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx) {
if (nonKIdx == 1)
waveId = udiv(waveId, i32_val(wpt[0]));
return urem(urem(waveId, i32_val(wpt[nonKIdx])),
i32_val(tensorSizeNonK / elemPerInstrNonK));
}
} // namespace
namespace SharedToDotOperandMFMA {
/**
* @brief swizzles tensor element indexes according to the pattern encoded in
* SharedEncodingAttr
*
* @param rewriter
* @param loc
* @param row row of the target tensor element relative to the start of smemObj
* @param col col of the target tensor element relative to the start of smemObj
* @param smemObj shared memory object, contains info about tensor in LDS
* @param attr layout attribute, contains swizzling info
* @return swizzled row, col indexes in tensor notation
*/
std::pair<mlir::Value, mlir::Value>
swizzleIndexes(ConversionPatternRewriter &rewriter, Location loc, Value row,
Value col, SharedMemoryObject smemObj, SharedEncodingAttr attr) {
(void)smemObj; // unused in current pattern
bool transposed = (attr.getOrder()[0] != 1);
if (transposed) {
// tensor is column-major, so swap col and row in the computations
std::swap(row, col);
}
auto vec = i32_val(attr.getVec());
auto perPhase = i32_val(attr.getPerPhase());
auto maxPhase = i32_val(attr.getMaxPhase());
// The original algorithm is taken from the getSwizzledSharedPtrs function
// (TritonGPUToLLVMBase.h). The basic algorithm for a row-major tensor is as
// follows:
//
// phase = (row // perPhase) % maxPhase
// colOffSwizzled = ((col // vec) ^ phase) * vec
// colOffOrdered = col % vec
// colOff = colOffSwizzled + colOffOrdered
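//
// For example (hypothetical swizzling parameters): with vec = 4, perPhase = 2,
// maxPhase = 8, row = 5, col = 10:
//   phase          = (5 / 2) % 8        = 2
//   colOffSwizzled = ((10 / 4) ^ 2) * 4 = 0
//   colOffOrdered  = 10 % 4             = 2
//   colOff         = 0 + 2              = 2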
auto phase = urem(udiv(row, perPhase), maxPhase);
auto colOffSwizzled = mul(xor_(udiv(col, vec), phase), vec);
auto colOffOrdered = urem(col, vec);
auto colOff = add(colOffSwizzled, colOffOrdered);
if (transposed)
return {colOff, row};
else
return {row, colOff};
}
/**
* @brief This function maps a particular load of an mfma dot operand to
* element indexes (row, col)
*
* The whole tensor is broken into "blocks" of waves along the "non-K" axis.
* One block could be processed by multiple waves.
* One wave works on a piece of the tensor of size elemsPerInstr[0] x K.
* Each of these pieces is broken into "tiles" of size elemsPerInstr[0] x
* elemsPerInstr[1].
*
* The total offset of an element is the sum of the following values:
* 1. Offset of the wave-block in the tensor
* 2. Offset of the wave inside one wave-block
* 3. Offset of the tile in one wave
* 4. Offset of one lane's data in a tile
* 5. Offset of the particular element of the tensor processed by one lane
*
* This function computes these offsets for each axis independently.
* Note that this function returns the offsets of elements in the first
* wave-block. The offsets of elements in later wave-blocks can be computed
* by adding a constant stride to the xor-ed offsets of elements in the
* first wave-block.
* A worked example follows the function definition below.
*
* @param rewriter
* @param loc
* @param elemsPerInstr operand tile shape consumed by one MFMA instruction
* @param waveId id component of the 2d wave grid along the non-K axis
* @param laneId lane id in the warp [0..63]
* @param warpsPerGroup number of warps in one block
* @param numOfElems number of elements accessed by a thread per repetition
* @param reps number of instruction repetitions needed to fully cover the dot
* operand
* @param smemOffsets offsets of the tensor in LDS
* @param loadVecSize number of elements loaded by one operation
* @param iNonKDim non-K dimension of the dot operand
* @return vector (i-th element corresponds to the i-th load instruction) of
* 2-element vectors (tensor row and col).
*/
llvm::SmallVector<llvm::SmallVector<Value>>
computeTensorElemMapping(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, ArrayRef<Value> smemOffsets,
int loadVecSize, unsigned iNonKDim) {
auto numM = reps[0];
auto numK = reps[1];
const int loadsPerThread = numOfElems / loadVecSize;
llvm::SmallVector<llvm::SmallVector<Value>> mapping(numK * loadsPerThread);
Value _0 = i32_val(0);
Value _32 = i32_val(32);
Value nonKDim = i32_val(iNonKDim);
Value waveVOffset = mul(waveId, i32_val(elemsPerInstr[0]));
for (int tile = 0; tile < numK; ++tile) {
Value tileVOffset = _0;
Value tileHOffset = i32_val(tile * elemsPerInstr[1]);
Value laneVOffset = urem(laneId, nonKDim);
Value laneHOffset;
if (iNonKDim == 32)
laneHOffset = select(icmp_uge(laneId, _32), i32_val(numOfElems), _0);
else
laneHOffset = mul(udiv(laneId, nonKDim), i32_val(numOfElems));
for (int loadId = 0; loadId < loadsPerThread; ++loadId) {
Value elemVOffset = _0;
Value elemHOffset = i32_val(loadId * loadVecSize);
Value sliceVOffset =
add(add(add(tileVOffset, laneVOffset), elemVOffset), waveVOffset);
Value sliceHOffset = add(add(tileHOffset, laneHOffset), elemHOffset);
Value row = add(sliceVOffset, smemOffsets[0]);
Value col = add(sliceHOffset, smemOffsets[1]);
mapping[loadsPerThread * tile + loadId] = {row, col};
}
}
return mapping;
}
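// Worked example for computeTensorElemMapping (hypothetical values, assuming
// an operand A tile with elemsPerInstr = {32, 8}, numOfElems = 4,
// loadVecSize = 4, waveId = 1, laneId = 37, smemOffsets = {0, 0}):
//   waveVOffset = 1 * 32 = 32
//   tile 0: tileVOffset = 0, tileHOffset = 0
//   laneVOffset = 37 % 32 = 5
//   laneHOffset = (37 >= 32) ? 4 : 0 = 4       (iNonKDim == 32 branch)
//   loadId 0: elemVOffset = 0, elemHOffset = 0
//   row = 0 + 5 + 0 + 32 + 0 = 37, col = 0 + 4 + 0 + 0 = 4
// so lane 37 of wave 1 loads the vector starting at tensor element (37, 4).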
bool isSwizzled(SharedEncodingAttr layout) { return layout.getMaxPhase() != 1; }
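// Linearizes a (row, col) tensor coordinate into a scalar LDS offset: the
// coordinate is swizzled first, then combined with the LDS strides.
// For example (hypothetical strides {128, 1}): a swizzled coordinate (37, 4)
// maps to offset 37 * 128 + 4 * 1 = 4740.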
Value computeOffset(ConversionPatternRewriter &rewriter, Location loc,
Value row, Value col, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout) {
auto [swizzledRow, swizzledCol] =
swizzleIndexes(rewriter, loc, row, col, smemObj, srcLayout);
auto &strides = smemObj.strides;
Value rowOffset = mul(swizzledRow, strides[0]);
Value colOffset = mul(swizzledCol, strides[1]);
return add(rowOffset, colOffset);
}
llvm::SmallVector<Value>
computeOffsetsAType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
SmallVector<Value> strides{smemObj.strides[0], smemObj.strides[1]};
SmallVector<Value> offsets{smemObj.offsets[0], smemObj.offsets[1]};
int vectorSize = 1;
if (srcLayout.getOrder()[0] == 1) {
if (isSwizzled(srcLayout))
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
else
vectorSize = numOfElems;
}
auto mapping = computeTensorElemMapping(rewriter, loc, elemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
reps, offsets, vectorSize, nonKDim);
llvm::SmallVector<Value> aOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
Value row = mapping[i][0];
Value col = mapping[i][1];
aOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
}
return aOffsets;
}
llvm::SmallVector<Value>
computeOffsetsBType(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, SharedMemoryObject smemObj,
SharedEncodingAttr srcLayout, unsigned nonKDim) {
// Transpose elemsPerInstr, reps, and offsets, because the layout of operand B
// is the transpose of the operand A layout.
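// For example (hypothetical values): elemsPerInstr = {8, 32} (K x non-K for
// operand B) becomes tElemsPerInstr = {32, 8}, and reps = {4, 2} (K reps x
// non-K reps) becomes tReps = {2, 4}, matching the (non-K, K) convention used
// for operand A.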
SmallVector<int64_t> tElemsPerInstr{elemsPerInstr[1], elemsPerInstr[0]};
SmallVector<int64_t> tReps{reps[1], reps[0]};
SmallVector<Value> toffsets{smemObj.offsets[1], smemObj.offsets[0]};
int vectorSize = 1;
if (srcLayout.getOrder()[0] == 0) {
if (isSwizzled(srcLayout))
vectorSize = std::min(static_cast<int>(srcLayout.getVec()), numOfElems);
else
vectorSize = numOfElems;
}
auto mapping = computeTensorElemMapping(rewriter, loc, tElemsPerInstr, waveId,
laneId, warpsPerGroup, numOfElems,
tReps, toffsets, vectorSize, nonKDim);
llvm::SmallVector<Value> bOffsets(mapping.size());
for (int i = 0; i < mapping.size(); ++i) {
// swap row and col, because the layout of operand B is the transpose of the
// operand A layout
Value row = mapping[i][1];
Value col = mapping[i][0];
bOffsets[i] = computeOffset(rewriter, loc, row, col, smemObj, srcLayout);
}
return bOffsets;
}
Value computeBasePtr(ConversionPatternRewriter &rewriter, Location loc,
const SharedMemoryObject &smemObj) {
Value base = smemObj.base;
Type type = base.getType();
for (int i = 0; i < smemObj.strides.size(); ++i) {
Value offset = sub(i32_val(0), mul(smemObj.offsets[i], smemObj.strides[i]));
base = gep(type, base, offset);
}
return base;
}
/**
* @brief try to find out whether a value is an integer constant
*
* Traces the def-use chain and returns the integer if the value can be proven
* constant. The current implementation can trace chains of
* insertValue->extractValue operations.
*
* @param val value for which we want to find a constant
* @return the found integer value, or an empty std::optional
*/
std::optional<int> findConstValue(Value val) {
while (val && !val.getDefiningOp<LLVM::ConstantOp>()) {
LLVM::ExtractValueOp extractValOp =
val.getDefiningOp<LLVM::ExtractValueOp>();
if (!extractValOp)
return std::optional<int>();
auto extractPosArr = extractValOp.getPosition();
if (extractPosArr.size() > 1)
return std::optional<int>();
int extractPos = extractPosArr[0];
int insertPos = -1;
LLVM::InsertValueOp insertValOp;
Value container = extractValOp.getOperand();
do {
insertValOp = container.getDefiningOp<LLVM::InsertValueOp>();
if (!insertValOp)
return std::optional<int>();
auto insertPosArr = insertValOp.getPosition();
if (insertPosArr.size() > 1)
return std::optional<int>();
insertPos = insertPosArr[0];
container = insertValOp.getContainer();
} while (insertPos != extractPos);
val = insertValOp.getValue();
}
if (!val)
return std::optional<int>();
auto cOp = val.getDefiningOp<LLVM::ConstantOp>();
assert(cOp);
auto valAttr = cOp.getValueAttr();
auto intAttr = dyn_cast<mlir::IntegerAttr>(valAttr);
assert(intAttr);
return intAttr.getInt();
}
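// Example of a def-use chain that findConstValue can trace (illustrative
// MLIR LLVM-dialect IR, syntax approximate):
//   %c   = llvm.mlir.constant(42 : i32) : i32
//   %s   = llvm.insertvalue %c, %undef[1] : !llvm.struct<(i32, i32)>
//   %val = llvm.extractvalue %s[1] : !llvm.struct<(i32, i32)>
// findConstValue applied to %val returns 42.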
bool fastPathAvailable(const SharedMemoryObject &smemObj,
const SharedEncodingAttr &srcEncoding,
const MfmaEncodingAttr &dstEncoding) {
if (srcEncoding.getMaxPhase() > 1)
return false;
auto stride0 = findConstValue(smemObj.strides[0]);
auto stride1 = findConstValue(smemObj.strides[1]);
auto offset0 = findConstValue(smemObj.offsets[0]);
auto offset1 = findConstValue(smemObj.offsets[1]);
bool allValuesDefined = stride0.has_value() && stride1.has_value() &&
offset0.has_value() && offset1.has_value();
if (!allValuesDefined)
return false;
if (offset0.value() != 0 || offset1.value() != 0)
return false;
return true;
}
// Computes offsets for operand B or transposed operand A
// @param rewriter
// @param loc
// @param elemsPerInstr operand tile shape consumed by one MFMA instruction
// @param waveId wave id for the "non K" axis
// @param laneId lane id in warp [0..63]
// @param warpsPerGroup number of warps per horizontal axis
// @param numOfElems number of elements accessed by a thread per repetition
// @param reps number of instruction repetitions needed to fully cover the dot
// operand
// @param cSwizzleOffset
llvm::SmallVector<Value>
fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc,
const ArrayRef<int64_t> &elemsPerInstr, Value waveId,
Value laneId, int warpsPerGroup, int numOfElems,
ArrayRef<int64_t> reps, Value cSwizzleOffset) {
auto numK = reps[0];
auto numN = reps[1];
SmallVector<Value> offsets(numK * numN * numOfElems);
int lineSize = warpsPerGroup * elemsPerInstr[1] * numN;
Value _nonKDim = i32_val(elemsPerInstr[1]);
Value waveOffset = mul(waveId, i32_val(elemsPerInstr[1]));
Value colOffset = urem(laneId, _nonKDim);
for (int block = 0; block < numN; ++block) {
Value blockOffset = i32_val(block * elemsPerInstr[1] * warpsPerGroup);
for (int tile = 0; tile < numK; ++tile) {
Value tileOffset = i32_val(tile * elemsPerInstr[0] * lineSize);
for (int elem = 0; elem < numOfElems; ++elem) {
Value halfOffset =
mul(udiv(laneId, _nonKDim), i32_val(numOfElems * lineSize));
Value rowOffset = add(i32_val(elem * lineSize), halfOffset);
Value elemOffset = add(rowOffset, colOffset);
Value offset =
add(add(add(waveOffset, blockOffset), tileOffset), elemOffset);
offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset;
}
}
}
return offsets;
}
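// Worked example for fastPathComputeOffsets (hypothetical values, assuming
// operand B with elemsPerInstr = {8, 32}, reps = {2, 1}, warpsPerGroup = 2,
// numOfElems = 4, waveId = 1, laneId = 37):
//   lineSize   = 2 * 32 * 1 = 64
//   waveOffset = 1 * 32 = 32
//   colOffset  = 37 % 32 = 5
//   block 0, tile 0, elem 0:
//     halfOffset = (37 / 32) * (4 * 64) = 256
//     rowOffset  = 0 * 64 + 256 = 256
//     offset     = 32 + 0 + 0 + (256 + 5) = 293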
bool isColMajor(::llvm::ArrayRef<unsigned> order) {
assert(order.size() == 2 && (order[0] & ~1ul) == 0 &&
order[0] + order[1] == 1);
return order[0] == 0;
}
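// K-major means the K dimension is the fastest-varying (contiguous) one:
// operand A (opIdx == 0, shape M x K) is K-major when order[0] == 1, while
// operand B (opIdx == 1, shape K x N) is K-major when order[0] == 0.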
bool isKMajor(::llvm::ArrayRef<unsigned> order, int opIdx) {
return order[0] + opIdx == 1;
}
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor, DotOperandEncodingAttr encoding,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx");
int kDimIdx = opIdx == 0 ? 1 : 0;
int nonKDimIdx = opIdx == 0 ? 0 : 1;
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
int nonKDim = mfmaLayout.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
ArrayRef<int64_t> shape = aTensorTy.getShape();
auto sharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
auto elemTy = aTensorTy.getElementType();
auto elemsPerInstr = encoding.getMFMAElemsPerInstr();
auto mfmaInstrNonK = elemsPerInstr[nonKDimIdx];
auto mfmaInstrK = elemsPerInstr[kDimIdx];
auto numReps = encoding.getMFMARep(shape);
auto numRepNonK = numReps[nonKDimIdx];
auto numRepK = numReps[kDimIdx];
unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
assert(iWaveSize == 64);
Value waveSize = i32_val(iWaveSize);
Value linearWaveId = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);
Value spatialWaveId =
getWaveIdInBlock(rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK,
shape[nonKDimIdx], nonKDimIdx);
int numOfElems = mfmaInstrNonK * mfmaInstrK / iWaveSize;
assert(numOfElems >= 1);
unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
int warpsPerGroupNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
elemTy = typeConverter->convertType(elemTy);
SmallVector<Value> loadedValues;
SmallVector<Value> offsets;
Value smemBase;
bool isFastPath = fastPathAvailable(smemObj, sharedLayout, mfmaLayout);
if (!isKMajor(order, opIdx) && isFastPath) {
// The fast path handles tensors that are not k-major, in which case swizzling
// is disabled and the offset computation can be simplified.
// TODO (zhanglx): later when we enable vector access to LDS for non k-major
// tensors, we'll refactor the scope of fast and normal path
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
if (opIdx == 0) {
if (isColMajor(order)) {
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, reps, cSwizzleOffset);
} else {
llvm_unreachable(
"row major operand A should be handled in the normal path");
}
} else {
if (isColMajor(order)) {
llvm_unreachable(
"col major operand B should be handled in the normal path");
} else {
offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr,
spatialWaveId, lane, warpsPerGroupNonK,
numOfElems, numReps, cSwizzleOffset);
}
}
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
} else { // normal path
// Normal path handles tensors that are k-major, in which case swizzling
// is enabled and it requires a 2-step method to compute the offsets.
if (opIdx == 0) {
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
} else {
assert(opIdx == 1);
offsets = computeOffsetsBType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
}
smemBase = computeBasePtr(rewriter, loc, smemObj);
}
Type resElemTy = typeConverter->convertType(elemTy);
Type smemPtrTy = getShemPtrTy(elemTy);
int loadsPerThread = offsets.size() / numRepK / (isFastPath ? numRepNonK : 1);
int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);
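// In the normal path offsets.size() == numRepK * loadsPerThread, so each
// thread issues loadsPerThread vector loads of elemsPerLoad elements per
// repetition. In the fast path offsets were computed per element
// (offsets.size() == numRepNonK * numRepK * numOfElems), so loadsPerThread
// equals numOfElems and elemsPerLoad is 1, i.e. the loads are scalar.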
for (int nonK = 0; nonK < numRepNonK; ++nonK) {
Value blockVOffset = i32_val(nonK * mfmaInstrNonK * warpsPerGroupNonK);
Value offAdjust = mul(blockVOffset, i32_val(shape[order[0]]));
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset;
if (isFastPath)
loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
else
// In the normal path, we only computed the offsets of elements
// in the first wave-block. Therefore, we update the offsets
// of elements in later wave-blocks by adding a constant stride
loadOffset = add(offAdjust, offsets[k * loadsPerThread + loadId]);
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value loadedValue = load(loadAddress);
if (loadsPerThread > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(elemTy, loadedValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = loadedValue;
}
}
loadedValues.push_back(valVec);
}
}
MLIRContext *ctx = mfmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
auto result =
typeConverter->packLLElements(loc, loadedValues, rewriter, structTy);
return result;
}
} // namespace SharedToDotOperandMFMA
#endif // ifdef USE_ROCM