mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] Refactor dot pipeline to reduce code duplication (#400)
This PR: - simplifies the data types generated by `shared->mfma dot op` layout conversions and does not pack data types into int32 or int64 - reduces code duplication between the fast and normal paths - reduces code duplication between operand A and operand B Co-authored-by: Shucai Xiao <shucai.xiao@amd.com> Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
This commit is contained in:
@@ -1003,8 +1003,7 @@ DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
|
||||
}
|
||||
|
||||
SmallVector<int64_t>
|
||||
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape,
|
||||
Type elemType) const {
|
||||
DotOperandEncodingAttr::getMFMARep(ArrayRef<int64_t> operandShape) const {
|
||||
auto operandTileShape = getMFMAElemsPerInstr();
|
||||
auto warpsPerCTA = getParent().cast<MfmaEncodingAttr>().getWarpsPerCTA();
|
||||
if (getOpIdx() == 0)
|
||||
@@ -1033,7 +1032,7 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
int warpsPerCTAN = mfmaParent.getWarpsPerCTA()[1];
|
||||
constexpr int waveSize = 64;
|
||||
auto tileSize = getMFMAElemsPerInstr();
|
||||
auto rep = getMFMARep(shape, eltTy);
|
||||
auto rep = getMFMARep(shape);
|
||||
return rep[0] * rep[1];
|
||||
}
|
||||
auto shapePerCTA = getShapePerCTA(*this, shape);
|
||||
|
||||
Reference in New Issue
Block a user