[MFMA] Support tile size 4x4 version 1 (#413)

This PR enables the 4x4 tile size in MFMA-based dot operations.

The supported tiled dot is (4x64) x (64x4) -> (4x4) in MFMA layout.
However, the actual dot operation must have at least 64 output elements. This is a limitation of other layouts that appear during result processing (e.g. the blocked layout cannot handle tensors smaller than the wave size).

For example, the following dots are supported: (4x64) x (64x16) -> (4x16), (16x64) x (64x4) -> (16x4) and (8x64) x (64x8) -> (8x8).
The following dots are not supported: (4x128) x (128x4) -> (4x4) and (4x64) x (64x8) -> (4x8).
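
The shape rule above can be restated as a small predicate. This is only a sketch of the rule as described in this message, not code from the PR: `isSupported4x4Dot` is a hypothetical name, and the 64-element minimum assumes a wave size of 64.

```cpp
// Hypothetical helper, not part of this commit: restates the shape rule
// described above for the 4x4 MFMA path.
constexpr bool isSupported4x4Dot(int m, int n, int k) {
  // Tile granularity: M and N in multiples of 4, K in multiples of 64.
  // Output size: at least 64 elements (assumed wave size of 64).
  return m % 4 == 0 && n % 4 == 0 && k % 64 == 0 && m * n >= 64;
}

// Supported examples from the message above.
static_assert(isSupported4x4Dot(4, 16, 64), "(4x64) x (64x16) -> (4x16)");
static_assert(isSupported4x4Dot(16, 4, 64), "(16x64) x (64x4) -> (16x4)");
static_assert(isSupported4x4Dot(8, 8, 64), "(8x64) x (64x8) -> (8x8)");

// Unsupported examples: both produce fewer than 64 output elements.
static_assert(!isSupported4x4Dot(4, 4, 128), "(4x128) x (128x4) -> (4x4)");
static_assert(!isSupported4x4Dot(4, 8, 64), "(4x64) x (64x8) -> (4x8)");
```

Both rejected examples pass the granularity test and fail only on output size, which is why the 64-element minimum is listed as a separate condition.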

This is a first version of dot using MFMA 4x4 instructions, with redundancy and reductions.
Alexander Efimov
2023-12-12 18:23:55 +01:00
committed by GitHub
parent a944811b6d
commit 605a90c58e
10 changed files with 332 additions and 129 deletions


@@ -426,7 +426,7 @@ bool supportMMA(triton::DotOp op, int version) {
 #ifdef USE_ROCM
 static bool supportMFMAGranularity(int m, int n, int k) {
   // these limitations are dtype dependent, in future we may relax them
-  const static std::pair<int, int> mfmaTypes[2] = {{32, 8}, {16, 16}};
+  const static std::pair<int, int> mfmaTypes[] = {{32, 8}, {16, 16}, {4, 64}};
   for (const auto &mfmaType : mfmaTypes) {
     auto [granularityMN, granularityK] = mfmaType;
     if (m % granularityMN != 0 || n % granularityMN != 0)