[MFMA] Switch between MFMA types (#352)

This PR introduces the matrix_instr_nonkdim flag to switch
between MFMA 16 and MFMA 32.
This commit is contained in:
Alexander Efimov
2023-10-18 16:57:34 +02:00
committed by GitHub
parent 4d539d7dae
commit 20f316b19a
14 changed files with 371 additions and 247 deletions

View File

@@ -378,18 +378,21 @@ bool supportMMA(triton::DotOp op, int version) {
}
#ifdef USE_ROCM
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
static bool supportMFMAGranularity(int m, int n, int k) {
// these limitations are dtype dependent, in future we may relax them
const int granularityMN = nonKDim;
const int granularityK = nonKDim == 32 ? 8 : 16;
if (m % granularityMN != 0 || n % granularityMN != 0)
return false;
if (k % granularityK != 0)
return false;
return true;
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
for (const auto &mfmaType : mfmaTypes) {
auto [granularityMN, granularityK] = mfmaType;
if (m % granularityMN != 0 || n % granularityMN != 0)
continue;
if (k % granularityK != 0)
continue;
return true;
}
return false;
}
bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
bool supportMFMA(triton::DotOp op) {
auto aTy = op.getA().getType().cast<RankedTensorType>();
auto bTy = op.getB().getType().cast<RankedTensorType>();
@@ -403,7 +406,7 @@ bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
auto bShape = bTy.getShape();
assert(aShape[1] == bShape[0]);
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
return false;
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||