[MFMA] Refactor dot pipeline to reduce code duplication (#400)

This PR:
- simplifies data types generated by `shared->mfma dot op` layout conversions: data is no longer packed into int32 or int64
- reduces code duplication between the fast and normal paths
- reduces code duplication between the operand A and operand B loaders

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
Alexander Efimov
2023-12-13 22:33:02 +01:00
committed by GitHub
parent 605a90c58e
commit f2afd65e8c
6 changed files with 132 additions and 308 deletions

View File

@@ -884,7 +884,7 @@ private:
return success();
}
// shared -> mma_operand
// shared -> mma_operand/mfma_operand
LogicalResult
lowerSharedToDotOperand(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
@@ -1060,7 +1060,7 @@ private:
}
#ifdef USE_ROCM
// shared -> dot_operand if the result layout is mma
// shared -> dot_operand if the result layout is mfma
Value lowerSharedToDotOperandMFMA(
triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, const MfmaEncodingAttr &mfmaLayout,

View File

@@ -41,16 +41,14 @@ Type getShemPtrTy(Type elemTy) {
return ptr_ty(elemTy, 3);
}
// Get a waveId for M axis.
Value getWaveM(ConversionPatternRewriter &rewriter, Location loc, Value wave,
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int M) {
return urem(urem(wave, i32_val(wpt[0])), i32_val(M / elemPerInstr));
}
// Get a waveId for N axis.
Value getWaveN(ConversionPatternRewriter &rewriter, Location loc, Value wave,
const ArrayRef<unsigned int> &wpt, int elemPerInstr, int N) {
Value waveMN = udiv(wave, i32_val(wpt[0]));
return urem(urem(waveMN, i32_val(wpt[1])), i32_val(N / elemPerInstr));
// Get waveId inside block of waves.
Value getWaveIdInBlock(ConversionPatternRewriter &rewriter, Location loc,
                       Value waveId, const ArrayRef<unsigned int> &wpt,
                       int elemPerInstrNonK, int tensorSizeNonK, int nonKIdx) {
  // For operand B (nonKIdx == 1) the linear wave id advances fastest along M,
  // so strip the M component first; the remainder then indexes the non-K axis.
  Value id = (nonKIdx == 1) ? udiv(waveId, i32_val(wpt[0])) : waveId;
  Value idInCTA = urem(id, i32_val(wpt[nonKIdx]));
  // Wrap around the number of instruction tiles that fit along the non-K axis.
  return urem(idInCTA, i32_val(tensorSizeNonK / elemPerInstrNonK));
}
} // namespace
@@ -435,319 +433,137 @@ bool isTransposed(::llvm::ArrayRef<unsigned> order) {
return order[0] == 0;
}
Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value thread,
DotOperandEncodingAttr encoding,
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
const SharedMemoryObject &smemObj) {
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
Location loc, Value tensor, DotOperandEncodingAttr encoding,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
assert((opIdx == 0 || opIdx == 1) && "unexpected operand idx");
int kDimIdx = opIdx == 0 ? 1 : 0;
int nonKDimIdx = opIdx == 0 ? 0 : 1;
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
ArrayRef<int64_t> shape = aTensorTy.getShape();
auto sharedLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
auto aElemTy = aTensorTy.getElementType();
auto aElemsPerInstr = encoding.getMFMAElemsPerInstr();
auto mfmaInstrM = aElemsPerInstr[0];
auto mfmaInstrK = aElemsPerInstr[1];
auto elemTy = aTensorTy.getElementType();
auto elemsPerInstr = encoding.getMFMAElemsPerInstr();
auto mfmaInstrNonK = elemsPerInstr[nonKDimIdx];
auto mfmaInstrK = elemsPerInstr[kDimIdx];
auto numReps = encoding.getMFMARep(shape, aElemTy);
auto numRepM = numReps[0];
auto numRepK = numReps[1];
auto numReps = encoding.getMFMARep(shape);
auto numRepNonK = numReps[nonKDimIdx];
auto numRepK = numReps[kDimIdx];
unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
assert(iWaveSize == 64);
Value waveSize = i32_val(iWaveSize);
Value wave = udiv(thread, waveSize);
Value linearWaveId = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);
Value waveM =
getWaveM(rewriter, loc, wave, warpsPerCTA, mfmaInstrM, shape[0]);
int numOfElems = mfmaInstrM * mfmaInstrK / iWaveSize;
Value spatialWaveId =
getWaveIdInBlock(rewriter, loc, linearWaveId, warpsPerCTA, mfmaInstrNonK,
shape[nonKDimIdx], nonKDimIdx);
int numOfElems = mfmaInstrNonK * mfmaInstrK / iWaveSize;
assert(numOfElems >= 1);
unsigned int maxNumWarps = shape[0] / mfmaInstrM;
int warpsPerGroupM = std::min(warpsPerCTA[0], maxNumWarps);
aElemTy = typeConverter->convertType(aElemTy);
SmallVector<Value> ha;
unsigned int maxNumWarps = shape[nonKDimIdx] / mfmaInstrNonK;
int warpsPerGroupNonK = std::min(warpsPerCTA[nonKDimIdx], maxNumWarps);
elemTy = typeConverter->convertType(elemTy);
SmallVector<Value> loadedValues;
SmallVector<Value> offsets;
Value smemBase;
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offsets;
if (isTransposed(order)) {
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrM};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy2(rewriter, loc, elemsPerInstr, waveM,
lane, warpsPerGroupM, numOfElems,
reps, cSwizzleOffset);
if (opIdx == 0) {
if (isTransposed(order)) { // HERE
SmallVector<int64_t> elemsPerInstr{mfmaInstrK, mfmaInstrNonK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy2(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
} else {
offsets = fastPathComputeOffsetsTy1(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
}
} else {
offsets = fastPathComputeOffsetsTy1(rewriter, loc, aElemsPerInstr, waveM,
lane, warpsPerGroupM, numOfElems,
numReps, cSwizzleOffset);
}
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type smemPtrTy = getShemPtrTy(aElemTy);
Type resElemTy = typeConverter->convertType(aElemTy);
int loadsPerThread = offsets.size() / (numRepM * numRepK);
const int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);
for (int m = 0; m < numRepM; ++m) {
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
Value loadOffset =
offsets[m * loadsPerThread * numRepK + k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value vectorValue = load(loadAddress);
if (numOfElems > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(aElemTy, vectorValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
valVec = bitcast(valVec, resElemTy);
}
}
if (aElemTy == i8_ty && numOfElems == 4)
valVec = bitcast(valVec, i32_ty);
if (aElemTy == i8_ty && numOfElems == 8)
valVec = bitcast(valVec, i64_ty);
ha.push_back(valVec);
if (isTransposed(order)) {
SmallVector<int64_t> elemsPerInstr{mfmaInstrNonK, mfmaInstrK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy1(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, reps, cSwizzleOffset);
} else {
offsets = fastPathComputeOffsetsTy2(
rewriter, loc, elemsPerInstr, spatialWaveId, lane,
warpsPerGroupNonK, numOfElems, numReps, cSwizzleOffset);
}
}
smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
} else { // normal path
SmallVector<Value> offsets = computeOffsetsAType(
rewriter, loc, aElemsPerInstr, waveM, lane, warpsPerGroupM, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
if (opIdx == 0) {
offsets = computeOffsetsAType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
} else {
assert(opIdx == 1);
offsets = computeOffsetsBType(rewriter, loc, elemsPerInstr, spatialWaveId,
lane, warpsPerGroupNonK, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
}
smemBase = computeBasePtr(rewriter, loc, smemObj);
}
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = typeConverter->convertType(aElemTy);
Type resElemTy = typeConverter->convertType(elemTy);
Type smemPtrTy = getShemPtrTy(elemTy);
Type smemPtrTy = getShemPtrTy(aElemTy);
int loadsPerThread = offsets.size() / (numRepNonK * numRepK);
int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
int elemsPerLoad = numOfElems / loadsPerThread;
for (int m = 0; m < numRepM; ++m) {
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(aElemTy, elemsPerLoad);
Value loadOffset = offsets[m * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value vectorValue = load(loadAddress);
if (numOfElems > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(aElemTy, vectorValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(aElemTy, vectorValue, i32_val(0));
valVec = bitcast(valVec, resElemTy);
for (int nonK = 0; nonK < numRepNonK; ++nonK) {
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(elemTy, elemsPerLoad);
Value loadOffset = offsets[nonK * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value loadedValue = load(loadAddress);
if (loadsPerThread > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(elemTy, loadedValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = loadedValue;
}
if (aElemTy == i8_ty && numOfElems == 4)
valVec = bitcast(valVec, i32_ty);
if (aElemTy == i8_ty && numOfElems == 8)
valVec = bitcast(valVec, i64_ty);
ha.push_back(valVec);
}
loadedValues.push_back(valVec);
}
}
MLIRContext *ctx = mfmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(ha.size(), ha[0].getType()));
auto result = typeConverter->packLLElements(loc, ha, rewriter, structTy);
ctx, SmallVector<Type>(loadedValues.size(), loadedValues[0].getType()));
auto result =
typeConverter->packLLElements(loc, loadedValues, rewriter, structTy);
return result;
}
// Loads dot operand $b (K x N) from shared memory into per-lane values and
// packs them into one LLVM struct for the MFMA lowering. Counterpart of
// loadA for operand $a.
Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value thread,
DotOperandEncodingAttr encoding,
TritonGPUToLLVMTypeConverter *typeConverter, Value tensor,
const SharedMemoryObject &smemObj) {
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
// Only 32/16/4 non-K tile sizes are supported by this lowering.
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto bTensorTy = tensor.getType().cast<RankedTensorType>();
ArrayRef<int64_t> shape = bTensorTy.getShape();
auto sharedLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
auto order = sharedLayout.getOrder();
auto bElemTy = bTensorTy.getElementType();
// For operand B, index 0 is the K dimension and index 1 is the N dimension.
auto bElemsPerInstr = encoding.getMFMAElemsPerInstr();
auto mfmaInstrK = bElemsPerInstr[0];
auto mfmaInstrN = bElemsPerInstr[1];
auto numReps = encoding.getMFMARep(shape, bElemTy);
auto numRepK = numReps[0];
auto numRepN = numReps[1];
unsigned iWaveSize = triton::gpu::getWarpSize(mfmaLayout);
assert(iWaveSize == 64);
Value waveSize = i32_val(iWaveSize);
// Split the thread id into a wave id and a lane id within the wave.
Value wave = udiv(thread, waveSize);
Value lane = urem(thread, waveSize);
Value waveN =
getWaveN(rewriter, loc, wave, warpsPerCTA, mfmaInstrN, shape[1]);
// Elements of one instruction tile are spread evenly over the 64 lanes.
int numOfElems = mfmaInstrK * mfmaInstrN / iWaveSize;
assert(numOfElems >= 1);
unsigned int maxNumWarps = shape[1] / mfmaInstrN;
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
bElemTy = typeConverter->convertType(bElemTy);
// Values produced by this lane, one entry per (n, k) repetition.
SmallVector<Value> hb;
if (fastPathAvailable(smemObj, sharedLayout, mfmaLayout)) {
// Fast path: offsets are computed directly, without the generic path's
// base-pointer reconstruction.
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
llvm::SmallVector<Value> offsets;
// NOTE(review): these two locals shadow the identical maxNumWarps /
// warpsPerGroupN computed above — redundant but harmless.
unsigned int maxNumWarps = shape[1] / mfmaInstrN;
int warpsPerGroupN = std::min(warpsPerCTA[1], maxNumWarps);
if (isTransposed(order)) {
// Transposed shared order: swap per-instruction dims and repetitions.
SmallVector<int64_t> elemsPerInstr{mfmaInstrN, mfmaInstrK};
SmallVector<int64_t> reps{numReps[1], numReps[0]};
offsets = fastPathComputeOffsetsTy1(rewriter, loc, elemsPerInstr, waveN,
lane, warpsPerGroupN, numOfElems,
reps, cSwizzleOffset);
} else {
offsets = fastPathComputeOffsetsTy2(rewriter, loc, bElemsPerInstr, waveN,
lane, warpsPerGroupN, numOfElems,
numReps, cSwizzleOffset);
}
Value smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter);
Type resElemTy = typeConverter->convertType(bElemTy);
Type smemPtrTy = getShemPtrTy(bElemTy);
// Each (n, k) repetition is covered by loadsPerThread vector loads of
// elemsPerLoad elements each.
const int loadsPerThread = offsets.size() / (numRepN * numRepK);
const int elemsPerLoad = numOfElems / loadsPerThread;
assert(numOfElems % loadsPerThread == 0);
for (int n = 0; n < numRepN; ++n) {
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
Value loadOffset =
offsets[n * loadsPerThread * numRepK + k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value vectorValue = load(loadAddress);
if (numOfElems > 1) {
// Scatter the loaded vector into the per-repetition result vector.
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(bElemTy, vectorValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
// Single element per repetition: keep it as a scalar.
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
valVec = bitcast(valVec, resElemTy);
}
}
// Pack i8 operands into a single integer — presumably to match the mfma
// intrinsic operand type; confirm against the intrinsic signatures.
if (bElemTy == i8_ty && numOfElems == 4)
valVec = bitcast(valVec, i32_ty);
if (bElemTy == i8_ty && numOfElems == 8)
valVec = bitcast(valVec, i64_ty);
hb.push_back(valVec);
}
}
} else { // normal path
// General path: offsets go through the B-type offset computation and the
// base pointer is reconstructed from the shared memory object.
llvm::SmallVector<Value> offsets = computeOffsetsBType(
rewriter, loc, bElemsPerInstr, waveN, lane, warpsPerGroupN, numOfElems,
numReps, smemObj, sharedLayout, nonKDim);
Value smemBase = computeBasePtr(rewriter, loc, smemObj);
Type resElemTy = typeConverter->convertType(bElemTy);
Type smemPtrTy = getShemPtrTy(bElemTy);
int loadsPerThread = offsets.size() / (numReps[0] * numReps[1]);
int elemsPerLoad = numOfElems / loadsPerThread;
for (int n = 0; n < numRepN; ++n) {
for (int k = 0; k < numRepK; ++k) {
auto vecTy = vec_ty(resElemTy, numOfElems);
Value valVec = undef(vecTy);
for (unsigned loadId = 0; loadId < loadsPerThread; ++loadId) {
auto loadVecTy = vec_ty(bElemTy, elemsPerLoad);
Value loadOffset = offsets[n * loadsPerThread * numRepK +
k * loadsPerThread + loadId];
Value loadAddress = bitcast(gep(smemPtrTy, smemBase, loadOffset),
getShemPtrTy(loadVecTy));
Value vectorValue = load(loadAddress);
if (numOfElems > 1) {
for (int elemId = 0; elemId < elemsPerLoad; ++elemId) {
Value elemVal =
extract_element(bElemTy, vectorValue, i32_val(elemId));
elemVal = bitcast(elemVal, resElemTy);
valVec = insert_element(vecTy, valVec, elemVal,
i32_val(loadId * elemsPerLoad + elemId));
}
} else {
valVec = extract_element(bElemTy, vectorValue, i32_val(0));
valVec = bitcast(valVec, resElemTy);
}
}
if (bElemTy == i8_ty && numOfElems == 4)
valVec = bitcast(valVec, i32_ty);
if (bElemTy == i8_ty && numOfElems == 8)
valVec = bitcast(valVec, i64_ty);
hb.push_back(valVec);
}
}
}
// Pack all per-repetition values into one homogeneous LLVM struct result.
MLIRContext *ctx = mfmaLayout.getContext();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(hb.size(), hb[0].getType()));
auto result = typeConverter->packLLElements(loc, hb, rewriter, structTy);
return result;
}
// Converts a shared-memory tensor to an MFMA dot-operand value by dispatching
// to the per-operand loader: opIdx 0 is operand $a, opIdx 1 is operand $b.
Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
                    Location loc, Value tensor, DotOperandEncodingAttr encoding,
                    const SharedMemoryObject &smemObj,
                    TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
  if (opIdx == 0)
    return loadA(rewriter, loc, thread, encoding, typeConverter, tensor,
                 smemObj);
  if (opIdx == 1)
    return loadB(rewriter, loc, thread, encoding, typeConverter, tensor,
                 smemObj);
  assert(false && "unexpected operand idx");
  return Value();
}
} // namespace SharedToDotOperandMFMA
#endif // ifdef USE_ROCM

View File

@@ -380,8 +380,11 @@ struct DotOpMFMAConversionHelper {
auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto repA = aEncoding.getMFMARep(aTensorTy.getShape(), elemTy);
auto repB = bEncoding.getMFMARep(bTensorTy.getShape(), elemTy);
auto kWidth = aEncoding.getKWidth();
assert(kWidth == bEncoding.getKWidth());
auto repA = aEncoding.getMFMARep(aTensorTy.getShape());
auto repB = bEncoding.getMFMARep(bTensorTy.getShape());
assert(repA[1] == repB[0]);
@@ -394,9 +397,9 @@ struct DotOpMFMAConversionHelper {
auto numRepK = repA[1];
ValueTable ha = getValuesFromDotOperandLayoutStruct(
loadedA, numRepM, numRepK, aTensorTy.getElementType());
loadedA, numRepM, numRepK, kWidth, aTensorTy.getElementType());
ValueTable hb = getValuesFromDotOperandLayoutStruct(
loadedB, numRepN, numRepK, aTensorTy.getElementType());
loadedB, numRepN, numRepK, kWidth, aTensorTy.getElementType());
auto dstElemTy = dTensorTy.getElementType();
auto fc =
typeConverter->unpackLLElements(loc, loadedC, rewriter, dstElemTy);
@@ -441,13 +444,29 @@ struct DotOpMFMAConversionHelper {
return success();
}
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1,
/**
* @brief Converts dot operand structure to value table and converts types appropriate for mfma instructions
*/
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1, int kWidth,
Type type) const {
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);
ValueTable vals;
for (int i = 0; i < n0; i++) {
for (int j = 0; j < n1; j++) {
vals[{i, j}] = elems[n1 * i + j];
auto rawElems = elems[n1 * i + j];
Value convertedElems;
if (type.isF32()) {
convertedElems = extract_element(type, rawElems, i32_val(0));
} else if (type.getIntOrFloatBitWidth() == 8) {
if (kWidth == 4)
convertedElems = bitcast(rawElems, i32_ty);
if (kWidth == 8)
convertedElems = bitcast(rawElems, i64_ty);
} else {
assert(type.isBF16() || type.isF16());
convertedElems = rawElems;
}
vals[{i, j}] = convertedElems;
}
}
return vals;

View File

@@ -165,16 +165,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
#ifdef USE_ROCM
if (auto mfmaParent = dotOpLayout.getParent().dyn_cast<MfmaEncodingAttr>()) {
if (elemTy.isF32())
return elemTy;
if (elemTy.isInteger(16)) // aka BF16
return vec_ty(elemTy, dotOpLayout.getKWidth());
if (elemTy.isF16())
return vec_ty(elemTy, 4);
if (elemTy.isInteger(8) && dotOpLayout.getKWidth() == 4)
return IntegerType::get(ctx, 32);
if (elemTy.isInteger(8) && dotOpLayout.getKWidth() == 8)
return IntegerType::get(ctx, 64);
return vec_ty(elemTy, dotOpLayout.getKWidth());
}
#endif

View File

@@ -1003,8 +1003,7 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
}
SmallVector<int64_t>
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape,
Type elemType) const {
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape) const {
auto operandTileShape = getMFMAElemsPerInstr();
auto warpsPerCTA = getParent().cast<MfmaEncodingAttr>().getWarpsPerCTA();
if (getOpIdx() == 0)
@@ -1033,7 +1032,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
constexpr int waveSize = 64;
auto tileSize = getMFMAElemsPerInstr();
auto rep = getMFMARep(shape, eltTy);
auto rep = getMFMARep(shape);
return rep[0] * rep[1];
}
auto shapePerCTA = getShapePerCTA(*this, shape);