[MFMA] Refactor dot pipeline to reduce code duplication (#400)

This PR:
- simplifies data types generated by `shared->mfma dot op` layout conversions. Do not pack data types in int32 or int64
- reduce code duplication between fast/normal path
- reduce code duplication between operand A and operand B

Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
Alexander Efimov
2023-12-13 22:33:02 +01:00
committed by GitHub
parent 605a90c58e
commit f2afd65e8c
6 changed files with 132 additions and 308 deletions

View File

@@ -1003,8 +1003,7 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
}
SmallVector<int64_t>
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape,
Type elemType) const {
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape) const {
auto operandTileShape = getMFMAElemsPerInstr();
auto warpsPerCTA = getParent().cast<MfmaEncodingAttr>().getWarpsPerCTA();
if (getOpIdx() == 0)
@@ -1033,7 +1032,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
constexpr int waveSize = 64;
auto tileSize = getMFMAElemsPerInstr();
auto rep = getMFMARep(shape, eltTy);
auto rep = getMFMARep(shape);
return rep[0] * rep[1];
}
auto shapePerCTA = getShapePerCTA(*this, shape);