mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[Backend] Refactor mfma selection (#441)
* Select mfma dimensions and instruction from static table * Extend mfmaLayout to include version and instrShape * Simplify generateMFMAOp by searching the mfma instruction in the table * Fix getNonKDim() and non_k_dim * Break instrShape into MDim and NDim
This commit is contained in:
@@ -104,7 +104,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
}
|
||||
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
if (32 == mfmaMDim) {
|
||||
cols = 2;
|
||||
rows = 32;
|
||||
} else {
|
||||
@@ -240,7 +241,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
}
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
unsigned rows, cols;
|
||||
switch (mfmaLayout.getNonKDim()) {
|
||||
switch (mfmaLayout.getMDim()) {
|
||||
case 32:
|
||||
rows = 16;
|
||||
cols = 1;
|
||||
@@ -349,13 +350,19 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
} else
|
||||
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
if (mfmaLayout.getNonKDim() == 32) {
|
||||
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
2 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
int mfmaMDim = mfmaLayout.getMDim();
|
||||
SmallVector<unsigned> threadsPerWarp;
|
||||
if (32 == mfmaMDim) {
|
||||
threadsPerWarp = {2, 32};
|
||||
} else {
|
||||
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
threadsPerWarp = {4, 16};
|
||||
}
|
||||
if (mfmaLayout.getIsTransposed())
|
||||
threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
|
||||
threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
else
|
||||
threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
|
||||
threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else {
|
||||
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
|
||||
}
|
||||
@@ -393,9 +400,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
|
||||
}
|
||||
llvm::report_fatal_error("Unexpected MMA layout version found");
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
auto nonKDim = mfmaLayout.getNonKDim();
|
||||
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
auto mfmaMDim = mfmaLayout.getMDim();
|
||||
auto mfmaNDim = mfmaLayout.getNDim();
|
||||
return {mfmaMDim * mfmaLayout.getWarpsPerCTA()[0],
|
||||
mfmaNDim * mfmaLayout.getWarpsPerCTA()[1]};
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
auto parentLayout = dotLayout.getParent();
|
||||
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
|
||||
@@ -719,11 +727,11 @@ bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
|
||||
}
|
||||
|
||||
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
|
||||
auto nonKDimA = mfmaA.getNonKDim();
|
||||
auto nonKDimA = mfmaA.getMDim();
|
||||
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
|
||||
auto isTransposedA = mfmaA.getIsTransposed();
|
||||
|
||||
auto nonKDimB = mfmaB.getNonKDim();
|
||||
auto nonKDimB = mfmaB.getMDim();
|
||||
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
|
||||
auto isTransposedB = mfmaB.getIsTransposed();
|
||||
|
||||
@@ -913,19 +921,22 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
|
||||
assert(rank == 2 && "Unexpected rank of mfma layout");
|
||||
|
||||
SmallVector<unsigned> elemsPerThread(rank);
|
||||
auto nonKDim = getNonKDim();
|
||||
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
|
||||
auto mfmaMDim = getMDim();
|
||||
auto mfmaNDim = getNDim();
|
||||
auto elemsPerThreadPerTile = (mfmaMDim == 32 ? 16 : 4);
|
||||
if (getIsTransposed()) {
|
||||
unsigned elemsCol =
|
||||
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
|
||||
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]) *
|
||||
elemsPerThreadPerTile;
|
||||
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
|
||||
unsigned elemsRow =
|
||||
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]);
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
} else {
|
||||
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
|
||||
unsigned elemsCol =
|
||||
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]);
|
||||
unsigned elemsRow =
|
||||
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
|
||||
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]) *
|
||||
elemsPerThreadPerTile;
|
||||
elemsPerThread[0] = elemsRow;
|
||||
elemsPerThread[1] = elemsCol;
|
||||
@@ -1053,7 +1064,7 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
|
||||
SmallVector<int64_t>
|
||||
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
|
||||
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
|
||||
int64_t nonKDim = mfmaEncoding.getNonKDim();
|
||||
int64_t nonKDim = mfmaEncoding.getMDim();
|
||||
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
|
||||
int64_t kWidth = getKWidth();
|
||||
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
|
||||
@@ -1367,32 +1378,46 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parser.parseGreater().failed())
|
||||
return {};
|
||||
|
||||
unsigned nonKDim = 0;
|
||||
unsigned versionMajor = 0;
|
||||
unsigned versionMinor = 0;
|
||||
SmallVector<unsigned> warpsPerCTA;
|
||||
SmallVector<unsigned> instrShape;
|
||||
bool isTransposed;
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "nonKDim") {
|
||||
if (parseUInt(parser, attr, nonKDim, "nonKDim").failed())
|
||||
if (attr.getName() == "versionMajor") {
|
||||
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "versionMinor") {
|
||||
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "warpsPerCTA") {
|
||||
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
|
||||
return {};
|
||||
} else if (attr.getName() == "isTransposed") {
|
||||
}
|
||||
if (attr.getName() == "instrShape") {
|
||||
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "isTransposed") {
|
||||
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
|
||||
warpsPerCTA, isTransposed);
|
||||
return parser.getChecked<MfmaEncodingAttr>(
|
||||
parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
|
||||
instrShape[0], instrShape[1], isTransposed);
|
||||
}
|
||||
|
||||
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "nonKDim = " << getNonKDim() << ", "
|
||||
<< "version = " << getVersionMajor() << "." << getVersionMinor()
|
||||
<< ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
|
||||
<< "instrShape = [" << getMDim() << ", " << getNDim() << "], "
|
||||
<< "isTransposed = " << getIsTransposed() << "}>";
|
||||
}
|
||||
|
||||
|
||||
@@ -133,20 +133,23 @@ public:
|
||||
/// @brief Choose MFMA instruction parameters
|
||||
/// @param dot target dot operation
|
||||
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
|
||||
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
|
||||
std::pair<unsigned, unsigned> chooseMfmaDimensions(tt::DotOp dot) const {
|
||||
// number of matrix elements along k dim per one MFMA intruction
|
||||
int64_t kDim = -1;
|
||||
unsigned kDim = 0;
|
||||
auto opType = dot.getA().getType().cast<RankedTensorType>();
|
||||
auto elemType = opType.getElementType();
|
||||
|
||||
auto dataTypeA = opType.getElementType();
|
||||
auto dataTypeB =
|
||||
dot.getB().getType().cast<RankedTensorType>().getElementType();
|
||||
|
||||
auto resType = dot.getD().getType().cast<RankedTensorType>();
|
||||
auto resShape = resType.getShape();
|
||||
|
||||
int64_t nonKDim = -1;
|
||||
unsigned nonKDim = 0;
|
||||
if (enforcedNonKDim != 0) {
|
||||
nonKDim = enforcedNonKDim;
|
||||
} else {
|
||||
nonKDim = -1;
|
||||
nonKDim = 0;
|
||||
int minSize = std::min(resShape[0], resShape[1]);
|
||||
if (minSize >= 32)
|
||||
nonKDim = 32;
|
||||
@@ -154,77 +157,17 @@ public:
|
||||
nonKDim = 16;
|
||||
if (minSize < 16)
|
||||
nonKDim = 4;
|
||||
assert(nonKDim != -1);
|
||||
assert(nonKDim != 0);
|
||||
}
|
||||
switch (nonKDim) {
|
||||
case 32:
|
||||
if (elemType.isF32())
|
||||
kDim = 2;
|
||||
if (elemType.isF16())
|
||||
kDim = 8;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 4;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 8;
|
||||
}
|
||||
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
|
||||
assert(mfmaVersion == 3);
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
if (mfmaVersion == 3) {
|
||||
kDim = 16;
|
||||
}
|
||||
else {
|
||||
kDim = 8;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 16:
|
||||
if (elemType.isF32())
|
||||
kDim = 4;
|
||||
if (elemType.isF16())
|
||||
kDim = 16;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 8;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 16;
|
||||
}
|
||||
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
|
||||
assert(mfmaVersion == 3);
|
||||
kDim = 32;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
if (mfmaVersion == 3) {
|
||||
kDim = 32;
|
||||
}
|
||||
else {
|
||||
kDim = 16;
|
||||
}
|
||||
}
|
||||
break;
|
||||
case 4:
|
||||
if (elemType.isF32())
|
||||
kDim = 16;
|
||||
if (elemType.isF16())
|
||||
kDim = 64;
|
||||
if (elemType.isBF16()) {
|
||||
if (mfmaVersion == 1)
|
||||
kDim = 32;
|
||||
if (mfmaVersion >= 2)
|
||||
kDim = 64;
|
||||
}
|
||||
if (elemType.isInteger(8)) {
|
||||
kDim = 64;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
|
||||
}
|
||||
assert(kDim != -1);
|
||||
assert(nonKDim != -1);
|
||||
|
||||
auto maybeMfmaInsn =
|
||||
MfmaInsn::selectMfma(nonKDim, dataTypeA, dataTypeB, mfmaVersion);
|
||||
if (failed(maybeMfmaInsn))
|
||||
llvm::report_fatal_error("No match found in MFMA database\n");
|
||||
else
|
||||
kDim = (*maybeMfmaInsn).getKDim();
|
||||
assert(kDim != 0);
|
||||
assert(nonKDim != 0);
|
||||
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
|
||||
assert(opType.getShape()[1] % kDim == 0);
|
||||
return {nonKDim, kDim};
|
||||
@@ -268,8 +211,10 @@ public:
|
||||
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
|
||||
warpsPerTile, isTransposed);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(
|
||||
oldRetType.getContext(),
|
||||
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
|
||||
/*instrShape*/ nonKDim, nonKDim, isTransposed);
|
||||
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
|
||||
|
||||
@@ -635,4 +635,180 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
|
||||
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
|
||||
}
|
||||
|
||||
// mfma instruction selection logic
|
||||
static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) {
|
||||
if (dataTypeA.isF32() && dataTypeB.isF32()) {
|
||||
return MfmaTypeId::Fp32TyId;
|
||||
}
|
||||
if (dataTypeA.isF16() && dataTypeB.isF16()) {
|
||||
return MfmaTypeId::Fp16TyId;
|
||||
}
|
||||
if (dataTypeA.isBF16() && dataTypeB.isBF16()) {
|
||||
return MfmaTypeId::Bf16TyId;
|
||||
}
|
||||
if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
|
||||
return MfmaTypeId::I8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
|
||||
return MfmaTypeId::Fp8Fp8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E4M3FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
|
||||
return MfmaTypeId::Fp8Bf8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E4M3FNUZ()) {
|
||||
return MfmaTypeId::Bf8Fp8TyId;
|
||||
}
|
||||
if (dataTypeA.isFloat8E5M2FNUZ() && dataTypeB.isFloat8E5M2FNUZ()) {
|
||||
return MfmaTypeId::Bf8Bf8TyId;
|
||||
}
|
||||
llvm_unreachable("Unsupported input argument type.");
|
||||
}
|
||||
|
||||
using MfmaInsnGroupMap = llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnAttr,
|
||||
MfmaInsnGroupSelectKeyInfo>;
|
||||
|
||||
auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & {
|
||||
static MfmaInsnGroupMap MfmaInsnMap{
|
||||
// f32
|
||||
// mfma_f32_32x32x2f32
|
||||
{{32, MfmaTypeId::Fp32TyId, 1},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp32TyId, 2},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp32TyId, 3},
|
||||
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
|
||||
// mfma_f32_16x16x4f32
|
||||
{{16, MfmaTypeId::Fp32TyId, 1},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp32TyId, 2},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp32TyId, 3},
|
||||
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
|
||||
// mfma_f32_4x4x1f32
|
||||
{{4, MfmaTypeId::Fp32TyId, 1},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp32TyId, 2},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
// mfma_f32_4x4x1_16B_f32
|
||||
{{4, MfmaTypeId::Fp32TyId, 3},
|
||||
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
|
||||
// f16
|
||||
// mfma_f32_32x32x8f16
|
||||
{{32, MfmaTypeId::Fp16TyId, 1},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp16TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
{{32, MfmaTypeId::Fp16TyId, 3},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
|
||||
// mfma_f32_16x16x16xf16
|
||||
{{16, MfmaTypeId::Fp16TyId, 1},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp16TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
{{16, MfmaTypeId::Fp16TyId, 3},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
|
||||
// mfma_f32_4x4x4f16
|
||||
{{4, MfmaTypeId::Fp16TyId, 1},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp16TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
{{4, MfmaTypeId::Fp16TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
|
||||
// bf16
|
||||
// mfma_f32_32x32x4_bf16
|
||||
{{32, MfmaTypeId::Bf16TyId, 1},
|
||||
{32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}},
|
||||
// mfma_f32_32x32x8_bf16_1K
|
||||
{{32, MfmaTypeId::Bf16TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
|
||||
{{32, MfmaTypeId::Bf16TyId, 3},
|
||||
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
|
||||
// mfma_f32_16x16x8_bf16
|
||||
{{16, MfmaTypeId::Bf16TyId, 1},
|
||||
{16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}},
|
||||
// mfma_f32_16x16x16_bf16_1K
|
||||
{{16, MfmaTypeId::Bf16TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
|
||||
{{16, MfmaTypeId::Bf16TyId, 3},
|
||||
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
|
||||
// mfma_f32_4x4x2_bf16
|
||||
{{4, MfmaTypeId::Bf16TyId, 1},
|
||||
{4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}},
|
||||
// mfma_f32_4x4x4_bf16_1K
|
||||
{{4, MfmaTypeId::Bf16TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
|
||||
{{4, MfmaTypeId::Bf16TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
|
||||
// int8
|
||||
// mfma_f32_32x32x8i8
|
||||
{{32, MfmaTypeId::I8TyId, 1},
|
||||
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
|
||||
{{32, MfmaTypeId::I8TyId, 2},
|
||||
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
|
||||
// mfma_f32_32x32x16i8
|
||||
{{32, MfmaTypeId::I8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}},
|
||||
// mfma_f32_16x16x16i8
|
||||
{{16, MfmaTypeId::I8TyId, 1},
|
||||
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
|
||||
{{16, MfmaTypeId::I8TyId, 2},
|
||||
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
|
||||
// mfma_f32_16x16x32i8
|
||||
{{16, MfmaTypeId::I8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}},
|
||||
// mfma_f32_4x4x4i8
|
||||
{{4, MfmaTypeId::I8TyId, 1},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
{{4, MfmaTypeId::I8TyId, 2},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
{{4, MfmaTypeId::I8TyId, 3},
|
||||
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
|
||||
// fp8 * pf8
|
||||
// mfma_f32_32x32x16_FP8_FP8
|
||||
{{32, MfmaTypeId::Fp8Fp8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_FP8_FP8
|
||||
{{16, MfmaTypeId::Fp8Fp8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_FP8_BF8
|
||||
{{32, MfmaTypeId::Fp8Bf8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_FP8_BF8
|
||||
{{16, MfmaTypeId::Fp8Bf8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_BF8_FP8
|
||||
{{32, MfmaTypeId::Bf8Fp8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_BF8_FP8
|
||||
{{16, MfmaTypeId::Bf8Fp8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}},
|
||||
// mfma_f32_32x32x16_BF8_BF8
|
||||
{{32, MfmaTypeId::Bf8Bf8TyId, 3},
|
||||
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}},
|
||||
// mfma_f32_16x16x32_BF8_BF8
|
||||
{{16, MfmaTypeId::Bf8Bf8TyId, 3},
|
||||
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}};
|
||||
return MfmaInsnMap;
|
||||
};
|
||||
|
||||
FailureOr<MfmaInsn> MfmaInsn::selectMfma(unsigned nonKDim, Type elementTypeA,
|
||||
Type elementTypeB, int mfmaVersion) {
|
||||
auto mfmaInsnAttrMap = getMfmaInsnGroupAttrMap();
|
||||
MfmaInsnGroupSelectKey key = {
|
||||
nonKDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion};
|
||||
auto it = mfmaInsnAttrMap.find(key);
|
||||
if (it == mfmaInsnAttrMap.end())
|
||||
return failure();
|
||||
return MfmaInsn(elementTypeA, elementTypeB, (*it).second);
|
||||
}
|
||||
|
||||
MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB,
|
||||
const MfmaInsnAttr &attr)
|
||||
: elementTypeA(elementTypeA), elementTypeB(elementTypeB), attr(attr) {}
|
||||
|
||||
unsigned MfmaInsn::getKDim() { return attr.k; }
|
||||
unsigned MfmaInsn::getMDim() { return attr.m; }
|
||||
unsigned MfmaInsn::getNDim() { return attr.n; }
|
||||
StringRef MfmaInsn::getInsnName() { return attr.insn; }
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user