mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Backend] Refactor mfma selection (#441)
* Select mfma dimensions and instruction from static table * Extend mfmaLayout to include version and instrShape * Simplify generateMFMAOp by searching the mfma instruction in the table * Fix getNonKDim() and non_k_dim * Break instrShape into MDim and NDim
This commit is contained in:
@@ -612,7 +612,8 @@ bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
|
||||
dotOperandLayout.getOpIdx() == 0 &&
|
||||
dotOperandLayout.getKWidth() == 4 &&
|
||||
dotOperandLayout.getParent() == mfmaLayout &&
|
||||
(mfmaLayout.getNonKDim() == 32 || mfmaLayout.getNonKDim() == 16) && mfmaLayout.getIsTransposed() &&
|
||||
(mfmaLayout.getMDim() == 32 || mfmaLayout.getMDim() == 16) &&
|
||||
mfmaLayout.getIsTransposed() &&
|
||||
(srcTy.getElementType().isF16() || srcTy.getElementType().isBF16());
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -395,7 +395,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
int nonKDimIdx = opIdx == 0 ? 0 : 1;
|
||||
|
||||
auto mfmaLayout = encoding.getParent().cast<MfmaEncodingAttr>();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
int nonKDim = mfmaLayout.getMDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
|
||||
#include "../DotOpToLLVM.h"
|
||||
#include "../Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
|
||||
|
||||
@@ -36,27 +37,6 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
|
||||
using ::mlir::triton::gpu::MfmaEncodingAttr;
|
||||
using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
enum class MatrixCoreType : uint8_t {
|
||||
// D = AB + C
|
||||
FP32_FP8_FP8_FP32,
|
||||
FP32_FP8_BF8_FP32,
|
||||
FP32_BF8_FP8_FP32,
|
||||
FP32_BF8_BF8_FP32,
|
||||
FP32_FP16_FP16_FP32,
|
||||
FP32_BF16_BF16_FP32,
|
||||
FP32_BF16_BF16_FP32_1K,
|
||||
FP32_FP32_FP32_FP32,
|
||||
FP64_FP64_FP64_FP64,
|
||||
INT32_INT8_INT8_INT32,
|
||||
INT32_INT8_INT8_INT32_CDNA3,
|
||||
NOT_APPLICABLE,
|
||||
};
|
||||
|
||||
struct MFMAInstrDescr {
|
||||
MatrixCoreType coreType;
|
||||
unsigned size;
|
||||
};
|
||||
|
||||
using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
|
||||
|
||||
struct DotOpMFMAConversionHelper {
|
||||
@@ -80,159 +60,14 @@ struct DotOpMFMAConversionHelper {
|
||||
return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
|
||||
}
|
||||
|
||||
Value generateMFMA32Op(MatrixCoreType coreType, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (coreType) {
|
||||
case MatrixCoreType::FP32_FP8_FP8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP8_BF8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_fp8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_FP8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_BF8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x16_bf8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x4bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x8bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_32x32x2f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x8i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
|
||||
return rewriter.create<ROCDL::mfma_i32_32x32x16_i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP64_FP64_FP64_FP64:
|
||||
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA 32x32 data type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Value generateMFMA16Op(MatrixCoreType coreType, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (coreType) {
|
||||
case MatrixCoreType::FP32_FP8_FP8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP8_BF8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_fp8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_FP8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_fp8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF8_BF8_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x32_bf8_bf8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x8bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x16bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_16x16x4f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x16i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3:
|
||||
return rewriter.create<ROCDL::mfma_i32_16x16x32_i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP64_FP64_FP64_FP64:
|
||||
return rewriter.create<ROCDL::mfma_f64_16x16x4f64>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA data type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Value generateMFMA4Op(MatrixCoreType coreType, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
switch (coreType) {
|
||||
case MatrixCoreType::FP32_FP16_FP16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_4x4x4f16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_4x4x2bf16>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_BF16_BF16_FP32_1K:
|
||||
return rewriter.create<ROCDL::mfma_f32_4x4x4bf16_1k>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::FP32_FP32_FP32_FP32:
|
||||
return rewriter.create<ROCDL::mfma_f32_4x4x1f32>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
case MatrixCoreType::INT32_INT8_INT8_INT32:
|
||||
return rewriter.create<ROCDL::mfma_i32_4x4x4i8>(
|
||||
loc, TypeRange{resType},
|
||||
ValueRange{valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA4 data type not supported");
|
||||
}
|
||||
}
|
||||
|
||||
Value generateMFMAOp(MFMAInstrDescr mfmaDescr, Value valA, Value valB,
|
||||
Value generateMFMAOp(StringRef mfmaInsnName, Value valA, Value valB,
|
||||
Value valC) const {
|
||||
switch (mfmaDescr.size) {
|
||||
case 32:
|
||||
return generateMFMA32Op(mfmaDescr.coreType, valA, valB, valC);
|
||||
break;
|
||||
case 16:
|
||||
return generateMFMA16Op(mfmaDescr.coreType, valA, valB, valC);
|
||||
break;
|
||||
case 4:
|
||||
return generateMFMA4Op(mfmaDescr.coreType, valA, valB, valC);
|
||||
default:
|
||||
llvm::report_fatal_error("MFMA nonkDim size is not supported");
|
||||
}
|
||||
return Value();
|
||||
auto resType = valC.getType();
|
||||
Value zeroFlag = i32_val(0);
|
||||
OperationState loweredOp(loc, mfmaInsnName);
|
||||
loweredOp.addTypes(resType);
|
||||
loweredOp.addOperands({valA, valB, valC, zeroFlag, zeroFlag, zeroFlag});
|
||||
return rewriter.create(loweredOp)->getResult(0);
|
||||
}
|
||||
|
||||
int getNumSubmatrices(Type elementType, int nonKDim) const {
|
||||
@@ -254,64 +89,6 @@ struct DotOpMFMAConversionHelper {
|
||||
return -1;
|
||||
}
|
||||
|
||||
// TODO unify this function with Utility.cpp:supportMFMATypes
|
||||
static MatrixCoreType getMatrixCoreTypeFromDot(DotOp op) {
|
||||
auto aOperandTy = op.getA().getType();
|
||||
auto aTensorTy = aOperandTy.cast<RankedTensorType>();
|
||||
auto aElemTy = aTensorTy.getElementType();
|
||||
auto bOperandTy = op.getB().getType();
|
||||
auto bTensorTy = bOperandTy.cast<RankedTensorType>();
|
||||
auto bElemTy = bTensorTy.getElementType();
|
||||
|
||||
auto dotOpEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto mfmaEncoding = dotOpEncoding.getParent().cast<MfmaEncodingAttr>();
|
||||
if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E4M3FNUZ())
|
||||
return MatrixCoreType::FP32_FP8_FP8_FP32;
|
||||
if (aElemTy.isFloat8E4M3FNUZ() && bElemTy.isFloat8E5M2FNUZ())
|
||||
return MatrixCoreType::FP32_FP8_BF8_FP32;
|
||||
if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E4M3FNUZ())
|
||||
return MatrixCoreType::FP32_BF8_FP8_FP32;
|
||||
if (aElemTy.isFloat8E5M2FNUZ() && bElemTy.isFloat8E5M2FNUZ())
|
||||
return MatrixCoreType::FP32_BF8_BF8_FP32;
|
||||
if (aElemTy.isF16())
|
||||
return MatrixCoreType::FP32_FP16_FP16_FP32;
|
||||
if (aElemTy.isF32())
|
||||
return MatrixCoreType::FP32_FP32_FP32_FP32;
|
||||
if (aElemTy.isBF16()) {
|
||||
auto nonKDim = mfmaEncoding.getNonKDim();
|
||||
auto kWidth = dotOpEncoding.getKWidth();
|
||||
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4) {
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32_1K;
|
||||
} else {
|
||||
assert((nonKDim == 32 && kWidth == 2) ||
|
||||
(nonKDim == 16 && kWidth == 2) || (nonKDim == 4 && kWidth == 2));
|
||||
return MatrixCoreType::FP32_BF16_BF16_FP32;
|
||||
}
|
||||
}
|
||||
if (aElemTy.isInteger(8)) {
|
||||
auto nonKDim = mfmaEncoding.getNonKDim();
|
||||
auto kWidth = dotOpEncoding.getKWidth();
|
||||
if ((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 8) {
|
||||
return MatrixCoreType::INT32_INT8_INT8_INT32_CDNA3;
|
||||
} else {
|
||||
assert((nonKDim == 32 || nonKDim == 16 || nonKDim == 4) && kWidth == 4);
|
||||
return MatrixCoreType::INT32_INT8_INT8_INT32;
|
||||
}
|
||||
}
|
||||
if (aElemTy.isF64())
|
||||
return MatrixCoreType::FP64_FP64_FP64_FP64;
|
||||
return MatrixCoreType::NOT_APPLICABLE;
|
||||
}
|
||||
|
||||
static MFMAInstrDescr getMatrixInstrDescr(DotOp op) {
|
||||
MFMAInstrDescr descr;
|
||||
auto tensorTy = op.getD().getType().cast<RankedTensorType>();
|
||||
auto encoding = tensorTy.getEncoding().cast<MfmaEncodingAttr>();
|
||||
descr.coreType = getMatrixCoreTypeFromDot(op);
|
||||
descr.size = encoding.getNonKDim();
|
||||
return descr;
|
||||
}
|
||||
|
||||
Value processSubBlocks(int numSubBlocks, Value acc, bool reduceSubBlocks,
|
||||
bool zeroSubBlocks) const {
|
||||
assert((numSubBlocks & (numSubBlocks - 1)) == 0 &&
|
||||
@@ -385,9 +162,9 @@ struct DotOpMFMAConversionHelper {
|
||||
// Conduct the Dot conversion.
|
||||
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
|
||||
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
auto nonKDim = mfmaLayout.getMDim();
|
||||
auto mfmaVersion = mfmaLayout.getVersionMajor();
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
auto mfmaInstrDescr = getMatrixInstrDescr(op);
|
||||
|
||||
Value a = op.getA();
|
||||
Value b = op.getB();
|
||||
@@ -395,7 +172,16 @@ struct DotOpMFMAConversionHelper {
|
||||
auto aTensorTy = a.getType().cast<RankedTensorType>();
|
||||
auto bTensorTy = b.getType().cast<RankedTensorType>();
|
||||
auto dTensorTy = d.getType().cast<RankedTensorType>();
|
||||
auto elemTy = aTensorTy.getElementType();
|
||||
auto elemTyA = aTensorTy.getElementType();
|
||||
auto elemTyB = bTensorTy.getElementType();
|
||||
|
||||
StringRef mfmaInsnName;
|
||||
auto maybeMfmaInsn =
|
||||
MfmaInsn::selectMfma(nonKDim, elemTyA, elemTyB, mfmaVersion);
|
||||
if (failed(maybeMfmaInsn))
|
||||
llvm::report_fatal_error("No match found in MFMA database\n");
|
||||
else
|
||||
mfmaInsnName = (*maybeMfmaInsn).getInsnName();
|
||||
|
||||
auto aEncoding = aTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto bEncoding = bTensorTy.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
@@ -442,10 +228,9 @@ struct DotOpMFMAConversionHelper {
|
||||
}
|
||||
acc = zeroAuxiliarBlocks(subBlocks, acc);
|
||||
for (size_t k = 0; k < numRepK; k++) {
|
||||
acc =
|
||||
mfmaLayout.getIsTransposed()
|
||||
? generateMFMAOp(mfmaInstrDescr, hb[{n, k}], ha[{m, k}], acc)
|
||||
: generateMFMAOp(mfmaInstrDescr, ha[{m, k}], hb[{n, k}], acc);
|
||||
acc = mfmaLayout.getIsTransposed()
|
||||
? generateMFMAOp(mfmaInsnName, hb[{n, k}], ha[{m, k}], acc)
|
||||
: generateMFMAOp(mfmaInsnName, ha[{m, k}], hb[{n, k}], acc);
|
||||
}
|
||||
acc = reduceSubBlocks(subBlocks, acc);
|
||||
for (unsigned v = 0; v < elemsPerVec; ++v) {
|
||||
|
||||
@@ -772,11 +772,11 @@ public:
|
||||
void emitMfmaOffsetForCTA(const MfmaEncodingAttr &mfmaLayout,
|
||||
SmallVector<SmallVector<unsigned>> &offsets,
|
||||
unsigned ctaOffsetX, unsigned ctaOffsetY) const {
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
// MFMA output tile consists of repeated "dot operand B" layout groups along
|
||||
// row axis. This variable defines number of these groups.
|
||||
DenseMap<int, int> groups{{4, 1}, {16, 1}, {32, 4}};
|
||||
unsigned numGroups = groups.at(nonKDim);
|
||||
unsigned numGroups = groups.at(mfmaMDim);
|
||||
|
||||
const unsigned elemsPerThreadPerGroup = 4;
|
||||
auto warpSize = getWarpSize(mfmaLayout);
|
||||
@@ -784,7 +784,7 @@ public:
|
||||
auto shapePerCta = getShapePerCTATile(mfmaLayout);
|
||||
for (unsigned block = 0; block < numGroups; block++) {
|
||||
unsigned rowOrColOffset =
|
||||
block * elemsPerThreadPerGroup * warpSize / nonKDim;
|
||||
block * elemsPerThreadPerGroup * warpSize / mfmaMDim;
|
||||
for (unsigned elem = 0; elem < elemsPerThreadPerGroup; elem++) {
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
offsets.push_back(
|
||||
@@ -1191,12 +1191,12 @@ private:
|
||||
assert(_warpsPerCTA.size() == 2);
|
||||
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
|
||||
i32_val(_warpsPerCTA[1])};
|
||||
int nonKDim = mfmaLayout.getNonKDim();
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
Value warpSize = i32_val(triton::gpu::getWarpSize(mfmaLayout));
|
||||
Value effectiveWarpSize = warpSize;
|
||||
if (nonKDim == 4) {
|
||||
if (mfmaMDim == 4) {
|
||||
const int uniqueValuesPerWarp = 4;
|
||||
effectiveWarpSize = i32_val(uniqueValuesPerWarp);
|
||||
}
|
||||
@@ -1204,22 +1204,22 @@ private:
|
||||
|
||||
Value warpId = udiv(threadId, warpSize);
|
||||
Value warpId0 =
|
||||
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / nonKDim));
|
||||
urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / mfmaMDim));
|
||||
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
|
||||
i32_val(shape[1] / nonKDim));
|
||||
i32_val(shape[1] / mfmaMDim));
|
||||
|
||||
Value offWarp0 = mul(warpId0, i32_val(nonKDim));
|
||||
Value offWarp1 = mul(warpId1, i32_val(nonKDim));
|
||||
Value offWarp0 = mul(warpId0, i32_val(mfmaMDim));
|
||||
Value offWarp1 = mul(warpId1, i32_val(mfmaMDim));
|
||||
|
||||
SmallVector<Value> multiDimBase(2);
|
||||
if (mfmaLayout.getIsTransposed()) {
|
||||
multiDimBase[1] =
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp1);
|
||||
multiDimBase[0] = add(urem(laneId, i32_val(nonKDim)), offWarp0);
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(mfmaMDim))), offWarp1);
|
||||
multiDimBase[0] = add(urem(laneId, i32_val(mfmaMDim)), offWarp0);
|
||||
} else {
|
||||
multiDimBase[0] =
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(nonKDim))), offWarp0);
|
||||
multiDimBase[1] = add(urem(laneId, i32_val(nonKDim)), offWarp1);
|
||||
add(mul(i32_val(4), udiv(laneId, i32_val(mfmaMDim))), offWarp0);
|
||||
multiDimBase[1] = add(urem(laneId, i32_val(mfmaMDim)), offWarp1);
|
||||
}
|
||||
return multiDimBase;
|
||||
}
|
||||
@@ -1236,7 +1236,7 @@ private:
|
||||
for (unsigned d = 0; d < 2; ++d) {
|
||||
unsigned inPerCTA = std::min<unsigned>(tensorShape[d], shapePerCTA[d]);
|
||||
unsigned inPerWarp = ceil<unsigned>(inPerCTA, warpsPerCTA[d]);
|
||||
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getNonKDim());
|
||||
numWarpsPerDim[d] = ceil<unsigned>(inPerWarp, mfmaLayout.getMDim());
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < numWarpsPerDim[0]; ++i) {
|
||||
|
||||
@@ -771,7 +771,8 @@ private:
|
||||
auto srcMfma =
|
||||
srcType.getEncoding().dyn_cast<triton::gpu::MfmaEncodingAttr>();
|
||||
auto newMfmaEnc = triton::gpu::MfmaEncodingAttr::get(
|
||||
mod.getContext(), srcMfma.getNonKDim(), {warpsPerCtaX, warpsPerCtaY},
|
||||
mod.getContext(), srcMfma.getVersionMajor(), srcMfma.getVersionMinor(),
|
||||
{warpsPerCtaX, warpsPerCtaY}, srcMfma.getMDim(), srcMfma.getNDim(),
|
||||
srcMfma.getIsTransposed());
|
||||
|
||||
auto newDstType = RankedTensorType::get(
|
||||
|
||||
@@ -104,7 +104,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
}
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
if (32 == mfmaMDim) {
|
||||
cols = 2;
|
||||
rows = 32;
|
||||
} else {
|
||||
@@ -240,7 +241,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
switch (mfmaLayout.getNonKDim()) {
|
||||
switch (mfmaLayout.getMDim()) {
|
||||
case 32:
|
||||
rows = 16;
|
||||
cols = 1;
|
||||
@@ -349,13 +350,19 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
} else
|
||||
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
2 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
SmallVector<unsigned> threadsPerWarp;
|
||||
if (32 == mfmaMDim) {
|
||||
threadsPerWarp = {2, 32};
|
||||
} else {
|
||||
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
threadsPerWarp = {4, 16};
|
||||
}
|
||||
if (mfmaLayout.getIsTransposed())
|
||||
threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
|
||||
threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
else
|
||||
threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
|
||||
threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else {
|
||||
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
|
||||
}
|
||||
@@ -393,9 +400,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
}
|
||||
llvm::report_fatal_error("Unexpected MMA layout version found");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
auto mfmaMDim = mfmaLayout.getMDim();
|
||||
auto mfmaNDim = mfmaLayout.getNDim();
|
||||
return {mfmaMDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
mfmaNDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -719,11 +727,11 @@ bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
|
||||
}
|
||||
|
||||
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
|
||||
auto nonKDimA = mfmaA.getNonKDim();
|
||||
auto nonKDimA = mfmaA.getMDim();
|
||||
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
|
||||
auto isTransposedA = mfmaA.getIsTransposed();
|
||||
|
||||
auto nonKDimB = mfmaB.getNonKDim();
|
||||
auto nonKDimB = mfmaB.getMDim();
|
||||
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
|
||||
auto isTransposedB = mfmaB.getIsTransposed();
|
||||
|
||||
@@ -913,19 +921,22 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
assert(rank == 2 && "Unexpected rank of mfma layout");
|
||||
|
||||
SmallVector<unsigned> elemsPerThread(rank);
|
||||
auto nonKDim = getNonKDim();
|
||||
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
|
||||
auto mfmaMDim = getMDim();
|
||||
auto mfmaNDim = getNDim();
|
||||
auto elemsPerThreadPerTile = (mfmaMDim == 32 ? 16 : 4);
|
||||
if (getIsTransposed()) {
|
||||
unsigned elemsCol =
|
||||
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
|
||||
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]) *
|
||||
elemsPerThreadPerTile;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
|
||||
unsigned elemsRow =
|
||||
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]);
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
} else {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
|
||||
unsigned elemsCol =
|
||||
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow =
|
||||
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
|
||||
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]) *
|
||||
elemsPerThreadPerTile;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
@@ -1053,7 +1064,7 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
|
||||
SmallVector<int64_t>
|
||||
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
|
||||
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
|
||||
int64_t nonKDim = mfmaEncoding.getNonKDim();
|
||||
int64_t nonKDim = mfmaEncoding.getMDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
int64_t kWidth = getKWidth();
|
||||
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
|
||||
@@ -1367,32 +1378,46 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
unsigned nonKDim = 0;
|
||||
unsigned versionMajor = 0;
|
||||
unsigned versionMinor = 0;
|
||||
SmallVector<unsigned> warpsPerCTA;
|
||||
SmallVector<unsigned> instrShape;
|
||||
bool isTransposed;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "nonKDim") {
|
||||
if (parseUInt(parser, attr, nonKDim, "nonKDim").failed())
|
||||
if (attr.getName() == "versionMajor") {
|
||||
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "versionMinor") {
|
||||
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "warpsPerCTA") {
|
||||
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "isTransposed") {
|
||||
}
|
||||
if (attr.getName() == "instrShape") {
|
||||
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "isTransposed") {
|
||||
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
|
||||
warpsPerCTA, isTransposed);
|
||||
return parser.getChecked<MfmaEncodingAttr>(
|
||||
parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
|
||||
instrShape[0], instrShape[1], isTransposed);
|
||||
}
|
||||
|
||||
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "nonKDim = " << getNonKDim() << ", "
|
||||
<< "version = " << getVersionMajor() << "." << getVersionMinor()
|
||||
<< ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
|
||||
<< "instrShape = [" << getMDim() << ", " << getNDim() << "], "
|
||||
<< "isTransposed = " << getIsTransposed() << "}>";
|
||||
}
|
||||
|
||||
|
||||
@@ -133,20 +133,23 @@ public:
|
||||
/// @brief Choose MFMA instruction parameters
|
||||
/// @param dot target dot operation
|
||||
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
|
||||
std::pair<unsigned, unsigned> chooseMfmaDimensions(tt::DotOp dot) const {
|
||||
// number of matrix elements along k dim per one MFMA intruction
|
||||
int64_t kDim = -1;
|
||||
unsigned kDim = 0;
|
||||
auto opType = dot.getA().getType().cast<RankedTensorType>();
|
||||
auto elemType = opType.getElementType();
|
||||
|
||||
auto dataTypeA = opType.getElementType();
|
||||
auto dataTypeB =
|
||||
dot.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
|
||||
auto resType = dot.getD().getType().cast<RankedTensorType>();
|
||||
auto resShape = resType.getShape();
|
||||
|
||||
int64_t nonKDim = -1;
|
||||
unsigned nonKDim = 0;
|
||||
if (enforcedNonKDim != 0) {
|
||||
nonKDim = enforcedNonKDim;
|
||||
} else {
|
||||
nonKDim = -1;
|
||||
nonKDim = 0;
|
||||
int minSize = std::min(resShape[0], resShape[1]);
|
||||
if (minSize >= 32)
|
||||
nonKDim = 32;
|
||||
@@ -154,77 +157,17 @@ public:
|
||||
nonKDim = 16;
|
||||
if (minSize < 16)
|
||||
nonKDim = 4;
|
||||
assert(nonKDim != -1);
|
||||
assert(nonKDim != 0);
|
||||
}
|
||||
switch (nonKDim) {
|
||||
case 32:
|
||||
if (elemType.isF32())
|
||||
kDim = 2;
|
||||
if (elemType.isF16())
|
||||
kDim = 8;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 4;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 8;
|
||||
}
|
||||
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
|
||||
assert(mfmaVersion == 3);
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
if (mfmaVersion == 3) {
|
||||
kDim = 16;
|
||||
}
|
||||
else {
|
||||
kDim = 8;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 16:
|
||||
if (elemType.isF32())
|
||||
kDim = 4;
|
||||
if (elemType.isF16())
|
||||
kDim = 16;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 8;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
|
||||
assert(mfmaVersion == 3);
|
||||
kDim = 32;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
if (mfmaVersion == 3) {
|
||||
kDim = 32;
|
||||
}
|
||||
else {
|
||||
kDim = 16;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
if (elemType.isF32())
|
||||
kDim = 16;
|
||||
if (elemType.isF16())
|
||||
kDim = 64;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 32;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 64;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
kDim = 64;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
|
||||
}
|
||||
assert(kDim != -1);
|
||||
assert(nonKDim != -1);
|
||||
|
||||
auto maybeMfmaInsn =
|
||||
MfmaInsn::selectMfma(nonKDim, dataTypeA, dataTypeB, mfmaVersion);
|
||||
if (failed(maybeMfmaInsn))
|
||||
llvm::report_fatal_error("No match found in MFMA database\n");
|
||||
else
|
||||
kDim = (*maybeMfmaInsn).getKDim();
|
||||
assert(kDim != 0);
|
||||
assert(nonKDim != 0);
|
||||
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
|
||||
assert(opType.getShape()[1] % kDim == 0);
|
||||
return {nonKDim, kDim};
|
||||
@@ -268,8 +211,10 @@ public:
|
||||
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
|
||||
warpsPerTile, isTransposed);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(
|
||||
oldRetType.getContext(),
|
||||
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
|
||||
/*instrShape*/ nonKDim, nonKDim, isTransposed);
|
||||
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
|
||||
|
||||
@@ -635,4 +635,180 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
|
||||
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
|
||||
}
|
||||
|
||||
// mfma instruction selection logic
|
||||
static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) {
|
||||
if (dataTypeA.isF32() && dataTypeB.isF32()) {
|
||||
return MfmaTypeId::Fp32TyId;
|
||||
}
|
||||
if (dataTypeA.isF16() && dataTypeB.isF16()) {
|
||||
return MfmaTypeId::Fp16TyId;
|
||||
}
|
||||
if (dataTypeA.isBF16() && dataTypeB.isBF16()) {
|
||||
return MfmaTypeId::Bf16TyId;
|
||||
}
|
||||
if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
|
||||
return MfmaTypeId::I8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
|
||||
return MfmaTypeId::Fp8Fp8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
|
||||
return MfmaTypeId::Fp8Bf8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
|
||||
return MfmaTypeId::Bf8Fp8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
|
||||
return MfmaTypeId::Bf8Bf8TyId;
|
||||
}
|
||||
llvm_unreachable("Unsupported input argument type.");
|
||||
}
|
||||
|
||||
using MfmaInsnGroupMap = llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnAttr,
|
||||
MfmaInsnGroupSelectKeyInfo>;
|
||||
|
||||
auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & {
|
||||
static MfmaInsnGroupMap MfmaInsnMap{
|
||||
// f32
|
||||
// mfma_f32_32x32x2f32
|
||||
{{32, MfmaTypeId::Fp32TyId, 1},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp32TyId, 2},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp32TyId, 3},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
// mfma_f32_16x16x4f32
|
||||
{{16, MfmaTypeId::Fp32TyId, 1},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp32TyId, 2},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp32TyId, 3},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
// mfma_f32_4x4x1f32
|
||||
{{4, MfmaTypeId::Fp32TyId, 1},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp32TyId, 2},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
// mfma_f32_4x4x1_16B_f32
|
||||
{{4, MfmaTypeId::Fp32TyId, 3},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
// f16
|
||||
// mfma_f32_32x32x8f16
|
||||
{{32, MfmaTypeId::Fp16TyId, 1},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp16TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp16TyId, 3},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
// mfma_f32_16x16x16xf16
|
||||
{{16, MfmaTypeId::Fp16TyId, 1},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp16TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp16TyId, 3},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
// mfma_f32_4x4x4f16
|
||||
{{4, MfmaTypeId::Fp16TyId, 1},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp16TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp16TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
// bf16
|
||||
// mfma_f32_32x32x4_bf16
|
||||
{{32, MfmaTypeId::Bf16TyId, 1},
|
||||
{32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}},
|
||||
// mfma_f32_32x32x8_bf16_1K
|
||||
{{32, MfmaTypeId::Bf16TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
|
||||
{{32, MfmaTypeId::Bf16TyId, 3},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
|
||||
// mfma_f32_16x16x8_bf16
|
||||
{{16, MfmaTypeId::Bf16TyId, 1},
|
||||
{16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}},
|
||||
// mfma_f32_16x16x16_bf16_1K
|
||||
{{16, MfmaTypeId::Bf16TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
|
||||
{{16, MfmaTypeId::Bf16TyId, 3},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
|
||||
// mfma_f32_4x4x2_bf16
|
||||
{{4, MfmaTypeId::Bf16TyId, 1},
|
||||
{4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}},
|
||||
// mfma_f32_4x4x4_bf16_1K
|
||||
{{4, MfmaTypeId::Bf16TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
|
||||
{{4, MfmaTypeId::Bf16TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
|
||||
// int8
|
||||
// mfma_f32_32x32x8i8
|
||||
{{32, MfmaTypeId::I8TyId, 1},
|
||||
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
|
||||
{{32, MfmaTypeId::I8TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
|
||||
// mfma_f32_32x32x16i8
|
||||
{{32, MfmaTypeId::I8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}},
|
||||
// mfma_f32_16x16x16i8
|
||||
{{16, MfmaTypeId::I8TyId, 1},
|
||||
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
|
||||
{{16, MfmaTypeId::I8TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
|
||||
// mfma_f32_16x16x32i8
|
||||
{{16, MfmaTypeId::I8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}},
|
||||
// mfma_f32_4x4x4i8
|
||||
{{4, MfmaTypeId::I8TyId, 1},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
{{4, MfmaTypeId::I8TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
{{4, MfmaTypeId::I8TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
// fp8 * pf8
|
||||
// mfma_f32_32x32x16_FP8_FP8
|
||||
{{32, MfmaTypeId::Fp8Fp8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_FP8_FP8
|
||||
{{16, MfmaTypeId::Fp8Fp8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_FP8_BF8
|
||||
{{32, MfmaTypeId::Fp8Bf8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_FP8_BF8
|
||||
{{16, MfmaTypeId::Fp8Bf8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_BF8_FP8
|
||||
{{32, MfmaTypeId::Bf8Fp8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_BF8_FP8
|
||||
{{16, MfmaTypeId::Bf8Fp8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_BF8_BF8
|
||||
{{32, MfmaTypeId::Bf8Bf8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_BF8_BF8
|
||||
{{16, MfmaTypeId::Bf8Bf8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}};
|
||||
return MfmaInsnMap;
|
||||
};
|
||||
|
||||
FailureOr<MfmaInsn> MfmaInsn::selectMfma(unsigned nonKDim, Type elementTypeA,
|
||||
Type elementTypeB, int mfmaVersion) {
|
||||
auto mfmaInsnAttrMap = getMfmaInsnGroupAttrMap();
|
||||
MfmaInsnGroupSelectKey key = {
|
||||
nonKDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion};
|
||||
auto it = mfmaInsnAttrMap.find(key);
|
||||
if (it == mfmaInsnAttrMap.end())
|
||||
return failure();
|
||||
return MfmaInsn(elementTypeA, elementTypeB, (*it).second);
|
||||
}
|
||||
|
||||
MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB,
|
||||
const MfmaInsnAttr &attr)
|
||||
: elementTypeA(elementTypeA), elementTypeB(elementTypeB), attr(attr) {}
|
||||
|
||||
unsigned MfmaInsn::getKDim() { return attr.k; }
|
||||
unsigned MfmaInsn::getMDim() { return attr.m; }
|
||||
unsigned MfmaInsn::getNDim() { return attr.n; }
|
||||
StringRef MfmaInsn::getInsnName() { return attr.insn; }
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user