mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[MFMA] Switch between MFMA types (#352)
This PR introduces the matrix_instr_nonkdim flag to switch between MFMA 16 and MFMA 32.
This commit is contained in:
@@ -378,18 +378,21 @@ bool supportMMA(triton::DotOp op, int version) {
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool supportMFMAGranularity(int m, int n, int k, int64_t nonKDim) {
|
||||
static bool supportMFMAGranularity(int m, int n, int k) {
|
||||
// these limitations are dtype dependent, in future we may relax them
|
||||
const int granularityMN = nonKDim;
|
||||
const int granularityK = nonKDim == 32 ? 8 : 16;
|
||||
if (m % granularityMN != 0 || n % granularityMN != 0)
|
||||
return false;
|
||||
if (k % granularityK != 0)
|
||||
return false;
|
||||
return true;
|
||||
const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
|
||||
for (const auto &mfmaType : mfmaTypes) {
|
||||
auto [granularityMN, granularityK] = mfmaType;
|
||||
if (m % granularityMN != 0 || n % granularityMN != 0)
|
||||
continue;
|
||||
if (k % granularityK != 0)
|
||||
continue;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
|
||||
bool supportMFMA(triton::DotOp op) {
|
||||
auto aTy = op.getA().getType().cast<RankedTensorType>();
|
||||
auto bTy = op.getB().getType().cast<RankedTensorType>();
|
||||
|
||||
@@ -403,7 +406,7 @@ bool supportMFMA(triton::DotOp op, int64_t nonKDim) {
|
||||
auto bShape = bTy.getShape();
|
||||
|
||||
assert(aShape[1] == bShape[0]);
|
||||
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1], nonKDim))
|
||||
if (!supportMFMAGranularity(aShape[0], bShape[1], aShape[1]))
|
||||
return false;
|
||||
|
||||
return aElemTy.isF16() || aElemTy.isBF16() || aElemTy.isF32() ||
|
||||
|
||||
Reference in New Issue
Block a user