[Triton] Mfma16 support (#251)

* [MFMA] Support mfma with NM size 16

This PR adds code for emitting MFMA instructions with size 16.

* add control over mfma type with MFMA_TYPE=16 env var
This commit is contained in:
Alexander Efimov
2023-10-09 20:59:54 +02:00
committed by GitHub
parent e801638b40
commit 7e34c244c2
11 changed files with 252 additions and 107 deletions

View File

@@ -103,10 +103,19 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
return {8, 4};
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (mfmaLayout.getIsTransposed()) {
return {32, 2};
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
cols = 2;
rows = 32;
} else {
return {2, 32};
cols = 4;
rows = 16;
}
if (mfmaLayout.getIsTransposed()) {
return {rows, cols};
} else {
return {cols, rows};
}
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
@@ -228,10 +237,20 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
llvm_unreachable("Unexpected mma version");
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
rows = 16;
cols = 1;
} else if (mfmaLayout.getNonKDim() == 16) {
rows = 4;
cols = 1;
} else
llvm_unreachable("Unexpected mfma non-k dim");
if (mfmaLayout.getIsTransposed()) {
return {1, 16};
return {cols, rows};
} else {
return {16, 1};
return {rows, cols};
}
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
@@ -320,8 +339,13 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
} else
assert(0 && "Unimplemented usage of MmaEncodingAttr");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
2 * mfmaLayout.getWarpsPerCTA()[1]};
if (mfmaLayout.getNonKDim() == 32) {
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
2 * mfmaLayout.getWarpsPerCTA()[1]};
} else {
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
4 * mfmaLayout.getWarpsPerCTA()[1]};
}
} else {
assert(0 && "Unimplemented usage of getThreadsPerCTA");
}
@@ -359,8 +383,9 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
}
assert(0 && "Unexpected MMA layout version found");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
return {32 * mfmaLayout.getWarpsPerCTA()[0],
32 * mfmaLayout.getWarpsPerCTA()[1]};
auto nonKDim = mfmaLayout.getNonKDim();
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -818,14 +843,20 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
assert(rank == 2 && "Unexpected rank of mfma layout");
SmallVector<unsigned> elemsPerThread(rank);
auto nonKDim = getNonKDim();
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
if (getIsTransposed()) {
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]) * 16;
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]);
unsigned elemsCol =
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
elemsPerThreadPerTile;
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
} else {
unsigned elemsCol = ceil<unsigned>(shape[1], 32 * getWarpsPerCTA()[1]);
unsigned elemsRow = ceil<unsigned>(shape[0], 32 * getWarpsPerCTA()[0]) * 16;
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
unsigned elemsRow =
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
elemsPerThreadPerTile;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
}
@@ -953,11 +984,13 @@ SmallVector<int64_t>
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
int64_t nonKDim = mfmaEncoding.getNonKDim();
int64_t kDim = getKWidth();
assert(nonKDim == 32 || nonKDim == 16);
int64_t kWidth = getKWidth();
int64_t kDim = kWidth * (nonKDim == 32 ? 2 : 4);
if (getOpIdx() == 0)
return {nonKDim, kDim*2};
return {nonKDim, kDim};
else
return {kDim*2, nonKDim};
return {kDim, nonKDim};
}
SmallVector<int64_t>

View File

@@ -160,25 +160,42 @@ public:
/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @param mfmaVersion
/// @param nonKDim
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot, int mfmaVersion) const {
int64_t nonKDim = 32;
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot,
int mfmaVersion,
int64_t nonKDim) const {
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
auto opType = dot.getA().getType().cast<RankedTensorType>();
auto elemType = opType.getElementType();
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion == 2)
if (nonKDim == 32) {
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion == 2)
kDim = 8;
}
if (elemType.isInteger(8))
kDim = 8;
} else {
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
kDim = 16;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 8;
if (mfmaVersion == 2)
kDim = 16;
}
if (elemType.isInteger(8))
kDim = 16;
}
if (elemType.isInteger(8))
kDim = 8;
assert(kDim != -1);
return {nonKDim, kDim};
}
@@ -193,7 +210,17 @@ public:
!oldRetType.getEncoding().isa<ttg::BlockedEncodingAttr>())
return failure();
if (!supportMFMA(dotOp))
// TODO: replace nonKDim with some heuristic in the chooseMfmaDimensions
// function
int64_t externalNonKDim = 32;
const char *mfmaType = std::getenv("MFMA_TYPE");
if (mfmaType) {
externalNonKDim = std::stol(mfmaType);
assert(externalNonKDim == 32 || externalNonKDim == 16);
}
if (!supportMFMA(dotOp, externalNonKDim))
return failure();
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
@@ -212,7 +239,8 @@ public:
ttg::MfmaEncodingAttr mfmaEnc;
auto [nonKDim, kDim] = chooseMfmaDimensions(dotOp, mfmaVersion);
auto [nonKDim, kDim] =
chooseMfmaDimensions(dotOp, mfmaVersion, externalNonKDim);
auto warpsPerTile = warpsPerTileMI200(dotOp, retShape, numWarps);
@@ -239,7 +267,15 @@ public:
.getOrder();
// kWidth is a number of consecutive elements per one instruction per one thread
auto kWidth = kDim / 2;
auto kWidth = kDim;
// in mfma 32x32 case argument matrix groups elements in 2 groups
// in mfma 16x16 case argument matrix groups elements in 4 groups
if (nonKDim == 32) {
kWidth /= 2;
} else {
assert(nonKDim == 16);
kWidth /= 4;
}
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
ttg::DotOperandEncodingAttr::get(ctx, 0, mfmaEnc, kWidth));