[Backend] Refactor mfma selection (#441)

* Select mfma dimensions and instruction from static table

* Extend mfmaLayout to include version and instrShape

* Simplify generateMFMAOp by searching the mfma instruction in the table

* Fix getNonKDim() and non_k_dim

* Break instrShape into MDim and NDim
This commit is contained in:
Lixun Zhang
2024-01-16 21:05:35 -06:00
committed by GitHub
parent d2f8bc1740
commit 02a2f24dd5
15 changed files with 457 additions and 411 deletions

View File

@@ -612,7 +612,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
dotOperandLayout.getOpIdx() == 0 &&
dotOperandLayout.getKWidth() == 4 &&
dotOperandLayout.getParent() == mfmaLayout &&
(mfmaLayout.getNonKDim() == 32 || mfmaLayout.getNonKDim() == 16) && mfmaLayout.getIsTransposed() &&
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
mfmaLayout.getIsTransposed() &&
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
}
#endif

View File

@@ -395,7 +395,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
int nonKDimIdx = opIdx == 0 ? 0 : 1;
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
auto nonKDim = mfmaLayout.getNonKDim();
int nonKDim = mfmaLayout.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

View File

@@ -24,6 +24,7 @@
#include "../DotOpToLLVM.h"
#include "../Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -36,27 +37,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MfmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
enum class MatrixCoreType : uint8_t {
// D = AB + C
FP32_FP8_FP8_FP32,
FP32_FP8_BF8_FP32,
FP32_BF8_FP8_FP32,
FP32_BF8_BF8_FP32,
FP32_FP16_FP16_FP32,
FP32_BF16_BF16_FP32,
FP32_BF16_BF16_FP32_1K,
FP32_FP32_FP32_FP32,
FP64_FP64_FP64_FP64,
INT32_INT8_INT8_INT32,
INT32_INT8_INT8_INT32_CDNA3,
NOT_APPLICABLE,
};
struct MFMAInstrDescr {
MatrixCoreType coreType;
unsigned size;
};
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
struct DotOpMFMAConversionHelper {
@@ -80,159 +60,14 @@ struct DotOpMFMAConversionHelper {
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
}
Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
default:
llvm::report_fatal_error("MFMA 32x32 data type not supported");
}
}
Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_FP8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF8_BF8_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP64_FP64_FP64_FP64:
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
default:
llvm::report_fatal_error("MFMA data type not supported");
}
}
Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
Value valC) const {
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
switch (coreType) {
case MatrixCoreType::FP32_FP16_FP16_FP32:
return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32:
return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::FP32_FP32_FP32_FP32:
return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
case MatrixCoreType::INT32_INT8_INT8_INT32:
return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
loc, TypeRange{resType},
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
default:
llvm::report_fatal_error("MFMA4 data type not supported");
}
}
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
Value valC) const {
switch (mfmaDescr.size) {
case 32:
return generateMFMA32Op(mfmaDescr.coreType, valA, valB, valC);
break;
case 16:
return generateMFMA16Op(mfmaDescr.coreType, valA, valB, valC);
break;
case 4:
return generateMFMA4Op(mfmaDescr.coreType, valA, valB, valC);
default:
llvm::report_fatal_error("MFMA nonkDim size is not supported");
}
return Value();
auto resType = valC.getType();
Value zeroFlag = i32_val(0);
OperationState loweredOp(loc, mfmaInsnName);
loweredOp.addTypes(resType);
loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
return rewriter.create(loweredOp)->getResult(0);
}
int getNumSubmatrices(Type elementType, int nonKDim) const {
@@ -254,64 +89,6 @@ struct DotOpMFMAConversionHelper {
return -1;
}
// TODO unify this function with Utility.cpp:supportMFMATypes
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
auto aOperandTy = op.getA().getType();
auto aTensorTy = aOperandTy.cast<RankedTensorType>();
auto aElemTy = aTensorTy.getElementType();
auto bOperandTy = op.getB().getType();
auto bTensorTy = bOperandTy.cast<RankedTensorType>();
auto bElemTy = bTensorTy.getElementType();
auto dotOpEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto mfmaEncoding = dotOpEncoding.getParent().cast<MfmaEncodingAttr>();
if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E4M3FNUZ())
return MatrixCoreType::FP32_FP8_FP8_FP32;
if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E5M2FNUZ())
return MatrixCoreType::FP32_FP8_BF8_FP32;
if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E4M3FNUZ())
return MatrixCoreType::FP32_BF8_FP8_FP32;
if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E5M2FNUZ())
return MatrixCoreType::FP32_BF8_BF8_FP32;
if (aElemTy.isF16())
return MatrixCoreType::FP32_FP16_FP16_FP32;
if (aElemTy.isF32())
return MatrixCoreType::FP32_FP32_FP32_FP32;
if (aElemTy.isBF16()) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
} else {
assert((nonKDim == 32 && kWidth == 2) ||
(nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
return MatrixCoreType::FP32_BF16_BF16_FP32;
}
}
if (aElemTy.isInteger(8)) {
auto nonKDim = mfmaEncoding.getNonKDim();
auto kWidth = dotOpEncoding.getKWidth();
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
} else {
assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
return MatrixCoreType::INT32_INT8_INT8_INT32;
}
}
if (aElemTy.isF64())
return MatrixCoreType::FP64_FP64_FP64_FP64;
return MatrixCoreType::NOT_APPLICABLE;
}
static MFMAInstrDescr getMatrixInstrDescr(DotOp op) {
MFMAInstrDescr descr;
auto tensorTy = op.getD().getType().cast<RankedTensorType>();
auto encoding = tensorTy.getEncoding().cast<MfmaEncodingAttr>();
descr.coreType = getMatrixCoreTypeFromDot(op);
descr.size = encoding.getNonKDim();
return descr;
}
Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks,
bool zeroSubBlocks) const {
assert((numSubBlocks & (numSubBlocks - 1)) == 0 &&
@@ -385,9 +162,9 @@ struct DotOpMFMAConversionHelper {
// Conduct the Dot conversion.
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto nonKDim = mfmaLayout.getNonKDim();
auto nonKDim = mfmaLayout.getMDim();
auto mfmaVersion = mfmaLayout.getVersionMajor();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
auto mfmaInstrDescr = getMatrixInstrDescr(op);
Value a = op.getA();
Value b = op.getB();
@@ -395,7 +172,16 @@ struct DotOpMFMAConversionHelper {
auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();
auto elemTy = aTensorTy.getElementType();
auto elemTyA = aTensorTy.getElementType();
auto elemTyB = bTensorTy.getElementType();
StringRef mfmaInsnName;
auto maybeMfmaInsn =
MfmaInsn::selectMfma(nonKDim, elemTyA, elemTyB, mfmaVersion);
if (failed(maybeMfmaInsn))
llvm::report_fatal_error("No match found in MFMA database\n");
else
mfmaInsnName = (*maybeMfmaInsn).getInsnName();
auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
@@ -442,10 +228,9 @@ struct DotOpMFMAConversionHelper {
}
acc = zeroAuxiliarBlocks(subBlocks, acc);
for (size_t k = 0; k < numRepK; k++) {
acc =
mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
acc = mfmaLayout.getIsTransposed()
? generateMFMAOp(mfmaInsnName, hb[{n, k}], ha[{m, k}], acc)
: generateMFMAOp(mfmaInsnName, ha[{m, k}], hb[{n, k}], acc);
}
acc = reduceSubBlocks(subBlocks, acc);
for (unsigned v = 0; v < elemsPerVec; ++v) {

View File

@@ -772,11 +772,11 @@ public:
void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
SmallVector<SmallVector<unsigned>> &offsets,
unsigned ctaOffsetX, unsigned ctaOffsetY) const {
auto nonKDim = mfmaLayout.getNonKDim();
int mfmaMDim = mfmaLayout.getMDim();
// MFMA output tile consists of repeated "dot operand B" layout groups along
// row axis. This variable defines number of these groups.
DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
unsigned numGroups = groups.at(nonKDim);
unsigned numGroups = groups.at(mfmaMDim);
const unsigned elemsPerThreadPerGroup = 4;
auto warpSize = getWarpSize(mfmaLayout);
@@ -784,7 +784,7 @@ public:
auto shapePerCta = getShapePerCTATile(mfmaLayout);
for (unsigned block = 0; block < numGroups; block++) {
unsigned rowOrColOffset =
block * elemsPerThreadPerGroup * warpSize / nonKDim;
block * elemsPerThreadPerGroup * warpSize / mfmaMDim;
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
if (mfmaLayout.getIsTransposed()) {
offsets.push_back(
@@ -1191,12 +1191,12 @@ private:
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
int nonKDim = mfmaLayout.getNonKDim();
int mfmaMDim = mfmaLayout.getMDim();
Value threadId = getThreadId(rewriter, loc);
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
Value effectiveWarpSize = warpSize;
if (nonKDim == 4) {
if (mfmaMDim == 4) {
const int uniqueValuesPerWarp = 4;
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
}
@@ -1204,22 +1204,22 @@ private:
Value warpId = udiv(threadId, warpSize);
Value warpId0 =
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / nonKDim));
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / mfmaMDim));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
i32_val(shape[1] / nonKDim));
i32_val(shape[1] / mfmaMDim));
Value offWarp0 = mul(warpId0, i32_val(nonKDim));
Value offWarp1 = mul(warpId1, i32_val(nonKDim));
Value offWarp0 = mul(warpId0, i32_val(mfmaMDim));
Value offWarp1 = mul(warpId1, i32_val(mfmaMDim));
SmallVector<Value> multiDimBase(2);
if (mfmaLayout.getIsTransposed()) {
multiDimBase[1] =
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp1);
multiDimBase[0] = add(urem(laneId, i32_val(nonKDim)), offWarp0);
add(mul(i32_val(4), udiv(laneId, i32_val(mfmaMDim))), offWarp1);
multiDimBase[0] = add(urem(laneId, i32_val(mfmaMDim)), offWarp0);
} else {
multiDimBase[0] =
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp0);
multiDimBase[1] = add(urem(laneId, i32_val(nonKDim)), offWarp1);
add(mul(i32_val(4), udiv(laneId, i32_val(mfmaMDim))), offWarp0);
multiDimBase[1] = add(urem(laneId, i32_val(mfmaMDim)), offWarp1);
}
return multiDimBase;
}
@@ -1236,7 +1236,7 @@ private:
for (unsigned d = 0; d < 2; ++d) {
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getNonKDim());
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getMDim());
}
for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {

View File

@@ -771,7 +771,8 @@ private:
auto srcMfma =
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(),
{warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(),
srcMfma.getIsTransposed());
auto newDstType = RankedTensorType::get(

View File

@@ -104,7 +104,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
int mfmaMDim = mfmaLayout.getMDim();
if (32 == mfmaMDim) {
cols = 2;
rows = 32;
} else {
@@ -240,7 +241,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
switch (mfmaLayout.getNonKDim()) {
switch (mfmaLayout.getMDim()) {
case 32:
rows = 16;
cols = 1;
@@ -349,13 +350,19 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
} else
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (mfmaLayout.getNonKDim() == 32) {
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
2 * mfmaLayout.getWarpsPerCTA()[1]};
int mfmaMDim = mfmaLayout.getMDim();
SmallVector<unsigned> threadsPerWarp;
if (32 == mfmaMDim) {
threadsPerWarp = {2, 32};
} else {
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
4 * mfmaLayout.getWarpsPerCTA()[1]};
threadsPerWarp = {4, 16};
}
if (mfmaLayout.getIsTransposed())
threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
else
threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};
} else {
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
}
@@ -393,9 +400,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
}
llvm::report_fatal_error("Unexpected MMA layout version found");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
auto nonKDim = mfmaLayout.getNonKDim();
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
auto mfmaMDim = mfmaLayout.getMDim();
auto mfmaNDim = mfmaLayout.getNDim();
return {mfmaMDim * mfmaLayout.getWarpsPerCTA()[0],
mfmaNDim * mfmaLayout.getWarpsPerCTA()[1]};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -719,11 +727,11 @@ bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
}
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
auto nonKDimA = mfmaA.getNonKDim();
auto nonKDimA = mfmaA.getMDim();
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
auto isTransposedA = mfmaA.getIsTransposed();
auto nonKDimB = mfmaB.getNonKDim();
auto nonKDimB = mfmaB.getMDim();
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
auto isTransposedB = mfmaB.getIsTransposed();
@@ -913,19 +921,22 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
assert(rank == 2 && "Unexpected rank of mfma layout");
SmallVector<unsigned> elemsPerThread(rank);
auto nonKDim = getNonKDim();
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
auto mfmaMDim = getMDim();
auto mfmaNDim = getNDim();
auto elemsPerThreadPerTile = (mfmaMDim == 32 ? 16 : 4);
if (getIsTransposed()) {
unsigned elemsCol =
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]) *
elemsPerThreadPerTile;
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
unsigned elemsRow =
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]);
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
} else {
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
unsigned elemsCol =
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]);
unsigned elemsRow =
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]) *
elemsPerThreadPerTile;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
@@ -1053,7 +1064,7 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
SmallVector<int64_t>
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
int64_t nonKDim = mfmaEncoding.getNonKDim();
int64_t nonKDim = mfmaEncoding.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
int64_t kWidth = getKWidth();
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
@@ -1367,32 +1378,46 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
unsigned nonKDim = 0;
unsigned versionMajor = 0;
unsigned versionMinor = 0;
SmallVector<unsigned> warpsPerCTA;
SmallVector<unsigned> instrShape;
bool isTransposed;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "nonKDim") {
if (parseUInt(parser, attr, nonKDim, "nonKDim").failed())
if (attr.getName() == "versionMajor") {
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
return {};
}
if (attr.getName() == "versionMinor") {
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
return {};
}
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
} else if (attr.getName() == "isTransposed") {
}
if (attr.getName() == "instrShape") {
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
return {};
}
if (attr.getName() == "isTransposed") {
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
return {};
}
}
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
warpsPerCTA, isTransposed);
return parser.getChecked<MfmaEncodingAttr>(
parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
instrShape[0], instrShape[1], isTransposed);
}
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "nonKDim = " << getNonKDim() << ", "
<< "version = " << getVersionMajor() << "." << getVersionMinor()
<< ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
<< "instrShape = [" << getMDim() << ", " << getNDim() << "], "
<< "isTransposed = " << getIsTransposed() << "}>";
}

View File

@@ -133,20 +133,23 @@ public:
/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
std::pair<unsigned, unsigned> chooseMfmaDimensions(tt::DotOp dot) const {
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
unsigned kDim = 0;
auto opType = dot.getA().getType().cast<RankedTensorType>();
auto elemType = opType.getElementType();
auto dataTypeA = opType.getElementType();
auto dataTypeB =
dot.getB().getType().cast<RankedTensorType>().getElementType();
auto resType = dot.getD().getType().cast<RankedTensorType>();
auto resShape = resType.getShape();
int64_t nonKDim = -1;
unsigned nonKDim = 0;
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = -1;
nonKDim = 0;
int minSize = std::min(resShape[0], resShape[1]);
if (minSize >= 32)
nonKDim = 32;
@@ -154,77 +157,17 @@ public:
nonKDim = 16;
if (minSize < 16)
nonKDim = 4;
assert(nonKDim != -1);
assert(nonKDim != 0);
}
switch (nonKDim) {
case 32:
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion >= 2)
kDim = 8;
}
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
assert(mfmaVersion == 3);
kDim = 16;
}
if (elemType.isInteger(8)) {
if (mfmaVersion == 3) {
kDim = 16;
}
else {
kDim = 8;
}
}
break;
case 16:
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
kDim = 16;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 8;
if (mfmaVersion >= 2)
kDim = 16;
}
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
assert(mfmaVersion == 3);
kDim = 32;
}
if (elemType.isInteger(8)) {
if (mfmaVersion == 3) {
kDim = 32;
}
else {
kDim = 16;
}
}
break;
case 4:
if (elemType.isF32())
kDim = 16;
if (elemType.isF16())
kDim = 64;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 32;
if (mfmaVersion >= 2)
kDim = 64;
}
if (elemType.isInteger(8)) {
kDim = 64;
}
break;
default:
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
}
assert(kDim != -1);
assert(nonKDim != -1);
auto maybeMfmaInsn =
MfmaInsn::selectMfma(nonKDim, dataTypeA, dataTypeB, mfmaVersion);
if (failed(maybeMfmaInsn))
llvm::report_fatal_error("No match found in MFMA database\n");
else
kDim = (*maybeMfmaInsn).getKDim();
assert(kDim != 0);
assert(nonKDim != 0);
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
assert(opType.getShape()[1] % kDim == 0);
return {nonKDim, kDim};
@@ -268,8 +211,10 @@ public:
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
warpsPerTile, isTransposed);
mfmaEnc = ttg::MfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ nonKDim, nonKDim, isTransposed);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

View File

@@ -635,4 +635,180 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
}
// mfma instruction selection logic
//
// Maps the (A, B) element-type pair of a dot operation to the MfmaTypeId
// used as part of the lookup key into the MFMA instruction table. Only the
// matched pairs below are supported; any other combination aborts.
static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) {
  // Hoist the fp8/bf8 predicates so the combination checks below read flat.
  const bool aIsFp8 = dataTypeA.isFloat8E4M3FNUZ();
  const bool aIsBf8 = dataTypeA.isFloat8E5M2FNUZ();
  const bool bIsFp8 = dataTypeB.isFloat8E4M3FNUZ();
  const bool bIsBf8 = dataTypeB.isFloat8E5M2FNUZ();

  // Homogeneous "wide" types: both operands must share the element type.
  if (dataTypeA.isF32() && dataTypeB.isF32())
    return MfmaTypeId::Fp32TyId;
  if (dataTypeA.isF16() && dataTypeB.isF16())
    return MfmaTypeId::Fp16TyId;
  if (dataTypeA.isBF16() && dataTypeB.isBF16())
    return MfmaTypeId::Bf16TyId;
  if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8))
    return MfmaTypeId::I8TyId;

  // 8-bit float types may be mixed; each (A, B) ordering is a distinct id.
  if (aIsFp8 && bIsFp8)
    return MfmaTypeId::Fp8Fp8TyId;
  if (aIsFp8 && bIsBf8)
    return MfmaTypeId::Fp8Bf8TyId;
  if (aIsBf8 && bIsFp8)
    return MfmaTypeId::Bf8Fp8TyId;
  if (aIsBf8 && bIsBf8)
    return MfmaTypeId::Bf8Bf8TyId;

  llvm_unreachable("Unsupported input argument type.");
}
// Lookup table type for MFMA instruction selection. The key is
// {nonKDim, operand-type id, mfmaVersion}; the mapped MfmaInsnAttr carries
// the instruction's m/n/k dimensions and its ROCDL operation name
// (the fourth field is presumably kBase/elements-per-op — confirm against
// the MfmaInsnAttr declaration).
using MfmaInsnGroupMap = llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnAttr,
                                        MfmaInsnGroupSelectKeyInfo>;

// Returns the lazily-constructed, process-lifetime singleton table mapping
// (dim, type, version) keys to concrete MFMA instructions. Note: returns a
// const reference; callers should bind with `const auto &` to avoid copying.
auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & {
  static MfmaInsnGroupMap MfmaInsnMap{
      // f32
      // mfma_f32_32x32x2f32
      {{32, MfmaTypeId::Fp32TyId, 1},
       {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
      {{32, MfmaTypeId::Fp32TyId, 2},
       {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
      {{32, MfmaTypeId::Fp32TyId, 3},
       {32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
      // mfma_f32_16x16x4f32
      {{16, MfmaTypeId::Fp32TyId, 1},
       {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
      {{16, MfmaTypeId::Fp32TyId, 2},
       {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
      {{16, MfmaTypeId::Fp32TyId, 3},
       {16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
      // mfma_f32_4x4x1f32
      {{4, MfmaTypeId::Fp32TyId, 1},
       {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
      {{4, MfmaTypeId::Fp32TyId, 2},
       {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
      // mfma_f32_4x4x1_16B_f32 (same ROCDL op name on version 3)
      {{4, MfmaTypeId::Fp32TyId, 3},
       {4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
      // f16
      // mfma_f32_32x32x8f16
      {{32, MfmaTypeId::Fp16TyId, 1},
       {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
      {{32, MfmaTypeId::Fp16TyId, 2},
       {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
      {{32, MfmaTypeId::Fp16TyId, 3},
       {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
      // mfma_f32_16x16x16f16
      {{16, MfmaTypeId::Fp16TyId, 1},
       {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
      {{16, MfmaTypeId::Fp16TyId, 2},
       {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
      {{16, MfmaTypeId::Fp16TyId, 3},
       {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
      // mfma_f32_4x4x4f16
      {{4, MfmaTypeId::Fp16TyId, 1},
       {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
      {{4, MfmaTypeId::Fp16TyId, 2},
       {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
      {{4, MfmaTypeId::Fp16TyId, 3},
       {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
      // bf16
      // Note: bf16 instructions differ per hardware generation — version 1
      // uses the 2-element-k variants, versions 2/3 the "_1k" variants.
      // mfma_f32_32x32x4_bf16
      {{32, MfmaTypeId::Bf16TyId, 1},
       {32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}},
      // mfma_f32_32x32x8_bf16_1K
      {{32, MfmaTypeId::Bf16TyId, 2},
       {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
      {{32, MfmaTypeId::Bf16TyId, 3},
       {32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
      // mfma_f32_16x16x8_bf16
      {{16, MfmaTypeId::Bf16TyId, 1},
       {16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}},
      // mfma_f32_16x16x16_bf16_1K
      {{16, MfmaTypeId::Bf16TyId, 2},
       {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
      {{16, MfmaTypeId::Bf16TyId, 3},
       {16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
      // mfma_f32_4x4x2_bf16
      {{4, MfmaTypeId::Bf16TyId, 1},
       {4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}},
      // mfma_f32_4x4x4_bf16_1K
      {{4, MfmaTypeId::Bf16TyId, 2},
       {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
      {{4, MfmaTypeId::Bf16TyId, 3},
       {4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
      // int8
      // mfma_i32_32x32x8i8
      {{32, MfmaTypeId::I8TyId, 1},
       {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
      {{32, MfmaTypeId::I8TyId, 2},
       {32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
      // mfma_i32_32x32x16i8 (CDNA3 only)
      {{32, MfmaTypeId::I8TyId, 3},
       {32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}},
      // mfma_i32_16x16x16i8
      {{16, MfmaTypeId::I8TyId, 1},
       {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
      {{16, MfmaTypeId::I8TyId, 2},
       {16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
      // mfma_i32_16x16x32i8 (CDNA3 only)
      {{16, MfmaTypeId::I8TyId, 3},
       {16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}},
      // mfma_i32_4x4x4i8
      {{4, MfmaTypeId::I8TyId, 1},
       {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
      {{4, MfmaTypeId::I8TyId, 2},
       {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
      {{4, MfmaTypeId::I8TyId, 3},
       {4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
      // fp8 * bf8 combinations (version 3 / CDNA3 only)
      // mfma_f32_32x32x16_FP8_FP8
      {{32, MfmaTypeId::Fp8Fp8TyId, 3},
       {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}},
      // mfma_f32_16x16x32_FP8_FP8
      {{16, MfmaTypeId::Fp8Fp8TyId, 3},
       {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}},
      // mfma_f32_32x32x16_FP8_BF8
      {{32, MfmaTypeId::Fp8Bf8TyId, 3},
       {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}},
      // mfma_f32_16x16x32_FP8_BF8
      {{16, MfmaTypeId::Fp8Bf8TyId, 3},
       {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}},
      // mfma_f32_32x32x16_BF8_FP8
      {{32, MfmaTypeId::Bf8Fp8TyId, 3},
       {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}},
      // mfma_f32_16x16x32_BF8_FP8
      {{16, MfmaTypeId::Bf8Fp8TyId, 3},
       {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}},
      // mfma_f32_32x32x16_BF8_BF8
      {{32, MfmaTypeId::Bf8Bf8TyId, 3},
       {32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}},
      // mfma_f32_16x16x32_BF8_BF8
      {{16, MfmaTypeId::Bf8Bf8TyId, 3},
       {16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}};
  return MfmaInsnMap;
};
/// Selects the MFMA instruction matching the requested tile size, operand
/// element types, and hardware MFMA version.
///
/// \param nonKDim      the M/N dimension of one MFMA tile (32, 16, or 4)
/// \param elementTypeA element type of the A operand
/// \param elementTypeB element type of the B operand
/// \param mfmaVersion  MFMA hardware generation (1, 2, or 3)
/// \return the matching MfmaInsn, or failure() if the combination is not in
///         the instruction table.
FailureOr<MfmaInsn> MfmaInsn::selectMfma(unsigned nonKDim, Type elementTypeA,
                                         Type elementTypeB, int mfmaVersion) {
  // Bind by const reference: the accessor returns the singleton table by
  // const&, and a plain `auto` here would deep-copy the whole DenseMap on
  // every selection.
  const auto &mfmaInsnAttrMap = getMfmaInsnGroupAttrMap();
  MfmaInsnGroupSelectKey key = {
      nonKDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion};
  auto it = mfmaInsnAttrMap.find(key);
  if (it == mfmaInsnAttrMap.end())
    return failure();
  return MfmaInsn(elementTypeA, elementTypeB, it->second);
}
// Value constructor: records the operand element types and the
// table-provided instruction attributes (dimensions + ROCDL op name).
MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB,
                   const MfmaInsnAttr &attr)
    : elementTypeA(elementTypeA), elementTypeB(elementTypeB), attr(attr) {}
// Accessors over the selected instruction's table attributes.
unsigned MfmaInsn::getKDim() { return attr.k; }         // K (reduction) dim per instruction
unsigned MfmaInsn::getMDim() { return attr.m; }         // output tile M dim
unsigned MfmaInsn::getNDim() { return attr.n; }         // output tile N dim
StringRef MfmaInsn::getInsnName() { return attr.insn; } // ROCDL operation name
} // namespace mlir