mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Triton] Mfma16 support (#251)
* [MFMA] Support mfma with NM size 16 This PR enables code emission of MFMA instructions with size 16. * add control over mfma type with MFMA_TYPE=16 env var
This commit is contained in:
@@ -103,10 +103,19 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
return {8, 4};
|
||||
}
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {32, 2};
|
||||
unsigned rows, cols;
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
cols = 2;
|
||||
rows = 32;
|
||||
} else {
|
||||
return {2, 32};
|
||||
cols = 4;
|
||||
rows = 16;
|
||||
}
|
||||
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {rows, cols};
|
||||
} else {
|
||||
return {cols, rows};
|
||||
}
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
@@ -228,10 +237,20 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
llvm_unreachable("Unexpected mma version");
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
rows = 16;
|
||||
cols = 1;
|
||||
} else if (mfmaLayout.getNonKDim() == 16) {
|
||||
rows = 4;
|
||||
cols = 1;
|
||||
} else
|
||||
llvm_unreachable("Unexpected mfma non-k dim");
|
||||
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
return {1, 16};
|
||||
return {cols, rows};
|
||||
} else {
|
||||
return {16, 1};
|
||||
return {rows, cols};
|
||||
}
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
@@ -320,8 +339,13 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
} else
|
||||
assert(0 && "Unimplemented usage of MmaEncodingAttr");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
2 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
2 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else {
|
||||
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
}
|
||||
} else {
|
||||
assert(0 && "Unimplemented usage of getThreadsPerCTA");
|
||||
}
|
||||
@@ -359,8 +383,9 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
}
|
||||
assert(0 && "Unexpected MMA layout version found");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
return {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
32 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -818,14 +843,20 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
assert(rank == 2 && "Unexpected rank of mfma layout");
|
||||
|
||||
SmallVector<unsigned> elemsPerThread(rank);
|
||||
auto nonKDim = getNonKDim();
|
||||
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
|
||||
if (getIsTransposed()) {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]) * 16;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]);
|
||||
unsigned elemsCol =
|
||||
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
|
||||
elemsPerThreadPerTile;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
} else {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow =
|
||||
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
|
||||
elemsPerThreadPerTile;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
}
|
||||
@@ -953,11 +984,13 @@ SmallVector<int64_t>
|
||||
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
|
||||
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
|
||||
int64_t nonKDim = mfmaEncoding.getNonKDim();
|
||||
int64_t kDim = getKWidth();
|
||||
assert(nonKDim == 32 || nonKDim == 16);
|
||||
int64_t kWidth = getKWidth();
|
||||
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
|
||||
if (getOpIdx() == 0)
|
||||
return {nonKDim, kDim*2};
|
||||
return {nonKDim, kDim};
|
||||
else
|
||||
return {kDim*2, nonKDim};
|
||||
return {kDim, nonKDim};
|
||||
}
|
||||
|
||||
SmallVector<int64_t>
|
||||
|
||||
@@ -160,25 +160,42 @@ public:
|
||||
/// @brief Choose MFMA instruction parameters
|
||||
/// @param dot target dot operation
|
||||
/// @param mfmaVersion
|
||||
/// @param nonKDim
|
||||
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot, int mfmaVersion) const {
|
||||
int64_t nonKDim = 32;
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot,
|
||||
int mfmaVersion,
|
||||
int64_t nonKDim) const {
|
||||
// number of matrix elements along k dim per one MFMA instruction
|
||||
int64_t kDim = -1;
|
||||
auto opType = dot.getA().getType().cast<RankedTensorType>();
|
||||
auto elemType = opType.getElementType();
|
||||
if (elemType.isF32())
|
||||
kDim = 2;
|
||||
if (elemType.isF16())
|
||||
kDim = 8;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 4;
|
||||
if (mfmaVersion == 2)
|
||||
if (nonKDim == 32) {
|
||||
if (elemType.isF32())
|
||||
kDim = 2;
|
||||
if (elemType.isF16())
|
||||
kDim = 8;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 4;
|
||||
if (mfmaVersion == 2)
|
||||
kDim = 8;
|
||||
}
|
||||
if (elemType.isInteger(8))
|
||||
kDim = 8;
|
||||
} else {
|
||||
if (elemType.isF32())
|
||||
kDim = 4;
|
||||
if (elemType.isF16())
|
||||
kDim = 16;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 8;
|
||||
if (mfmaVersion == 2)
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isInteger(8))
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isInteger(8))
|
||||
kDim = 8;
|
||||
assert(kDim != -1);
|
||||
return {nonKDim, kDim};
|
||||
}
|
||||
@@ -193,7 +210,17 @@ public:
|
||||
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
|
||||
return failure();
|
||||
|
||||
if (!supportMFMA(dotOp))
|
||||
// TODO replace with nonKDim with some heuristic in chooseMfmaDimensions
|
||||
// function
|
||||
int64_t externalNonKDim = 32;
|
||||
|
||||
const char *mfmaType = std::getenv("MFMA_TYPE");
|
||||
if (mfmaType) {
|
||||
externalNonKDim = std::stol(mfmaType);
|
||||
assert(externalNonKDim == 32 || externalNonKDim == 16);
|
||||
}
|
||||
|
||||
if (!supportMFMA(dotOp, externalNonKDim))
|
||||
return failure();
|
||||
|
||||
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
|
||||
@@ -212,7 +239,8 @@ public:
|
||||
|
||||
ttg::MfmaEncodingAttr mfmaEnc;
|
||||
|
||||
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp, mfmaVersion);
|
||||
auto [nonKDim, kDim] =
|
||||
chooseMfmaDimensions(dotOp, mfmaVersion, externalNonKDim);
|
||||
|
||||
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
|
||||
|
||||
@@ -239,7 +267,15 @@ public:
|
||||
.getOrder();
|
||||
|
||||
// kWidth is a number of consecutive elements per one instruction per one thread
|
||||
auto kWidth = kDim / 2;
|
||||
auto kWidth = kDim;
|
||||
// in mfma 32x32 case argument matrix groups elements in 2 groups
|
||||
// in mfma 16x16 case argument matrix groups elements in 4 groups
|
||||
if (nonKDim == 32) {
|
||||
kWidth /= 2;
|
||||
} else {
|
||||
assert(nonKDim == 16);
|
||||
kWidth /= 4;
|
||||
}
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));
|
||||
|
||||
Reference in New Issue
Block a user