[Backend] Refactor mfma selection (#441)

* Select mfma dimensions and instruction from static table

* Extend mfmaLayout to include version and instrShape

* Simplify generateMFMAOp by searching the mfma instruction in the table

* Fix getNonKDim() and non_k_dim

* Break instrShape into MDim and NDim
This commit is contained in:
Lixun Zhang
2024-01-16 21:05:35 -06:00
committed by GitHub
parent d2f8bc1740
commit 02a2f24dd5
15 changed files with 457 additions and 411 deletions

View File

@@ -104,7 +104,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
}
if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
if (mfmaLayout.getNonKDim() == 32) {
int mfmaMDim = mfmaLayout.getMDim();
if (32 == mfmaMDim) {
cols = 2;
rows = 32;
} else {
@@ -240,7 +241,7 @@ SmallVector<unsigned> getSizePerThread(Attribute layout) {
}
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
unsigned rows, cols;
switch (mfmaLayout.getNonKDim()) {
switch (mfmaLayout.getMDim()) {
case 32:
rows = 16;
cols = 1;
@@ -349,13 +350,19 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
} else
llvm::report_fatal_error("Unimplemented usage of MmaEncodingAttr");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
if (mfmaLayout.getNonKDim() == 32) {
threads = {32 * mfmaLayout.getWarpsPerCTA()[0],
2 * mfmaLayout.getWarpsPerCTA()[1]};
int mfmaMDim = mfmaLayout.getMDim();
SmallVector<unsigned> threadsPerWarp;
if (32 == mfmaMDim) {
threadsPerWarp = {2, 32};
} else {
threads = {16 * mfmaLayout.getWarpsPerCTA()[0],
4 * mfmaLayout.getWarpsPerCTA()[1]};
threadsPerWarp = {4, 16};
}
if (mfmaLayout.getIsTransposed())
threads = {threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[0],
threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[1]};
else
threads = {threadsPerWarp[0] * mfmaLayout.getWarpsPerCTA()[0],
threadsPerWarp[1] * mfmaLayout.getWarpsPerCTA()[1]};
} else {
llvm::report_fatal_error("Unimplemented usage of getThreadsPerCTA");
}
@@ -393,9 +400,10 @@ SmallVector<unsigned> getShapePerCTATile(Attribute layout,
}
llvm::report_fatal_error("Unexpected MMA layout version found");
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
auto nonKDim = mfmaLayout.getNonKDim();
return {nonKDim * mfmaLayout.getWarpsPerCTA()[0],
nonKDim * mfmaLayout.getWarpsPerCTA()[1]};
auto mfmaMDim = mfmaLayout.getMDim();
auto mfmaNDim = mfmaLayout.getNDim();
return {mfmaMDim * mfmaLayout.getWarpsPerCTA()[0],
mfmaNDim * mfmaLayout.getWarpsPerCTA()[1]};
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
auto parentLayout = dotLayout.getParent();
assert(parentLayout && "DotOperandEncodingAttr must have a parent");
@@ -719,11 +727,11 @@ bool sameBlockedEncodings(BlockedEncodingAttr blockedA,
}
bool sameMfmaEncodings(MfmaEncodingAttr mfmaA, MfmaEncodingAttr mfmaB) {
auto nonKDimA = mfmaA.getNonKDim();
auto nonKDimA = mfmaA.getMDim();
auto warpsPerCTAA = mfmaA.getWarpsPerCTA();
auto isTransposedA = mfmaA.getIsTransposed();
auto nonKDimB = mfmaB.getNonKDim();
auto nonKDimB = mfmaB.getMDim();
auto warpsPerCTAB = mfmaB.getWarpsPerCTA();
auto isTransposedB = mfmaB.getIsTransposed();
@@ -913,19 +921,22 @@ MfmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
assert(rank == 2 && "Unexpected rank of mfma layout");
SmallVector<unsigned> elemsPerThread(rank);
auto nonKDim = getNonKDim();
auto elemsPerThreadPerTile = (nonKDim == 16 ? 4 : 16);
auto mfmaMDim = getMDim();
auto mfmaNDim = getNDim();
auto elemsPerThreadPerTile = (mfmaMDim == 32 ? 16 : 4);
if (getIsTransposed()) {
unsigned elemsCol =
ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]) *
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]) *
elemsPerThreadPerTile;
unsigned elemsRow = ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]);
unsigned elemsRow =
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]);
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
} else {
unsigned elemsCol = ceil<unsigned>(shape[1], nonKDim * getWarpsPerCTA()[1]);
unsigned elemsCol =
ceil<unsigned>(shape[1], mfmaNDim * getWarpsPerCTA()[1]);
unsigned elemsRow =
ceil<unsigned>(shape[0], nonKDim * getWarpsPerCTA()[0]) *
ceil<unsigned>(shape[0], mfmaMDim * getWarpsPerCTA()[0]) *
elemsPerThreadPerTile;
elemsPerThread[0] = elemsRow;
elemsPerThread[1] = elemsCol;
@@ -1053,7 +1064,7 @@ DotOperandEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
SmallVector<int64_t>
DotOperandEncodingAttr::getMFMAElemsPerInstr() const {
auto mfmaEncoding = getParent().cast<MfmaEncodingAttr>();
int64_t nonKDim = mfmaEncoding.getNonKDim();
int64_t nonKDim = mfmaEncoding.getMDim();
assert(nonKDim == 32 || nonKDim == 16 || nonKDim == 4);
int64_t kWidth = getKWidth();
constexpr int waveSize = 64; // MFMA is used on wave64 architectures only
@@ -1367,32 +1378,46 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parser.parseGreater().failed())
return {};
unsigned nonKDim = 0;
unsigned versionMajor = 0;
unsigned versionMinor = 0;
SmallVector<unsigned> warpsPerCTA;
SmallVector<unsigned> instrShape;
bool isTransposed;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "nonKDim") {
if (parseUInt(parser, attr, nonKDim, "nonKDim").failed())
if (attr.getName() == "versionMajor") {
if (parseUInt(parser, attr, versionMajor, "versionMajor").failed())
return {};
}
if (attr.getName() == "versionMinor") {
if (parseUInt(parser, attr, versionMinor, "versionMinor").failed())
return {};
}
if (attr.getName() == "warpsPerCTA") {
if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
return {};
} else if (attr.getName() == "isTransposed") {
}
if (attr.getName() == "instrShape") {
if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
return {};
}
if (attr.getName() == "isTransposed") {
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
return {};
}
}
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
warpsPerCTA, isTransposed);
return parser.getChecked<MfmaEncodingAttr>(
parser.getContext(), versionMajor, versionMinor, warpsPerCTA,
instrShape[0], instrShape[1], isTransposed);
}
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "nonKDim = " << getNonKDim() << ", "
<< "version = " << getVersionMajor() << "." << getVersionMinor()
<< ", "
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
<< "instrShape = [" << getMDim() << ", " << getNDim() << "], "
<< "isTransposed = " << getIsTransposed() << "}>";
}

View File

@@ -133,20 +133,23 @@ public:
/// @brief Choose MFMA instruction parameters
/// @param dot target dot operation
/// @return pair {nonKDim, kDim} sizes of one MFMA instruction arguments
std::pair<int64_t, int64_t> chooseMfmaDimensions(tt::DotOp dot) const {
std::pair<unsigned, unsigned> chooseMfmaDimensions(tt::DotOp dot) const {
// number of matrix elements along k dim per one MFMA instruction
int64_t kDim = -1;
unsigned kDim = 0;
auto opType = dot.getA().getType().cast<RankedTensorType>();
auto elemType = opType.getElementType();
auto dataTypeA = opType.getElementType();
auto dataTypeB =
dot.getB().getType().cast<RankedTensorType>().getElementType();
auto resType = dot.getD().getType().cast<RankedTensorType>();
auto resShape = resType.getShape();
int64_t nonKDim = -1;
unsigned nonKDim = 0;
if (enforcedNonKDim != 0) {
nonKDim = enforcedNonKDim;
} else {
nonKDim = -1;
nonKDim = 0;
int minSize = std::min(resShape[0], resShape[1]);
if (minSize >= 32)
nonKDim = 32;
@@ -154,77 +157,17 @@ public:
nonKDim = 16;
if (minSize < 16)
nonKDim = 4;
assert(nonKDim != -1);
assert(nonKDim != 0);
}
switch (nonKDim) {
case 32:
if (elemType.isF32())
kDim = 2;
if (elemType.isF16())
kDim = 8;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 4;
if (mfmaVersion >= 2)
kDim = 8;
}
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
assert(mfmaVersion == 3);
kDim = 16;
}
if (elemType.isInteger(8)) {
if (mfmaVersion == 3) {
kDim = 16;
}
else {
kDim = 8;
}
}
break;
case 16:
if (elemType.isF32())
kDim = 4;
if (elemType.isF16())
kDim = 16;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 8;
if (mfmaVersion >= 2)
kDim = 16;
}
if (elemType.isFloat8E4M3FNUZ() || elemType.isFloat8E5M2FNUZ()) {
assert(mfmaVersion == 3);
kDim = 32;
}
if (elemType.isInteger(8)) {
if (mfmaVersion == 3) {
kDim = 32;
}
else {
kDim = 16;
}
}
break;
case 4:
if (elemType.isF32())
kDim = 16;
if (elemType.isF16())
kDim = 64;
if (elemType.isBF16()) {
if (mfmaVersion == 1)
kDim = 32;
if (mfmaVersion >= 2)
kDim = 64;
}
if (elemType.isInteger(8)) {
kDim = 64;
}
break;
default:
llvm::report_fatal_error("unsupported nonKDim size in MFMA dot");
}
assert(kDim != -1);
assert(nonKDim != -1);
auto maybeMfmaInsn =
MfmaInsn::selectMfma(nonKDim, dataTypeA, dataTypeB, mfmaVersion);
if (failed(maybeMfmaInsn))
llvm::report_fatal_error("No match found in MFMA database\n");
else
kDim = (*maybeMfmaInsn).getKDim();
assert(kDim != 0);
assert(nonKDim != 0);
assert(resShape[0] % nonKDim == 0 && resShape[1] % nonKDim == 0);
assert(opType.getShape()[1] % kDim == 0);
return {nonKDim, kDim};
@@ -268,8 +211,10 @@ public:
warpsPerTileMFMA(dotOp, retShape, numWarps, {nonKDim, nonKDim});
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(oldRetType.getContext(), nonKDim,
warpsPerTile, isTransposed);
mfmaEnc = ttg::MfmaEncodingAttr::get(
oldRetType.getContext(),
/*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile,
/*instrShape*/ nonKDim, nonKDim, isTransposed);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);

View File

@@ -635,4 +635,180 @@ void populateForOpDeadArgumentElimination(RewritePatternSet &patterns) {
patterns.add<ForOpDeadArgElimination>(patterns.getContext());
}
// mfma instruction selection logic
/// @brief Translate the element types of the two dot operands into the
///        MfmaTypeId used as part of the MFMA table lookup key.
/// @param dataTypeA element type of operand A
/// @param dataTypeB element type of operand B
/// @return the MfmaTypeId for the (A, B) combination; every combination that
///         is not handled below ends in llvm_unreachable.
static MfmaTypeId convertTypesToId(mlir::Type dataTypeA, mlir::Type dataTypeB) {
  // True when the same type predicate holds for both operands.
  auto bothAre = [&](auto isKind) {
    return isKind(dataTypeA) && isKind(dataTypeB);
  };

  if (bothAre([](mlir::Type ty) { return ty.isF32(); }))
    return MfmaTypeId::Fp32TyId;
  if (bothAre([](mlir::Type ty) { return ty.isF16(); }))
    return MfmaTypeId::Fp16TyId;
  if (bothAre([](mlir::Type ty) { return ty.isBF16(); }))
    return MfmaTypeId::Bf16TyId;
  if (bothAre([](mlir::Type ty) { return ty.isInteger(8); }))
    return MfmaTypeId::I8TyId;

  // 8-bit float operands may be mixed: E4M3FNUZ maps to the "Fp8" half of the
  // id and E5M2FNUZ to the "Bf8" half, A first and B second.
  const bool rhsIsE4M3 = dataTypeB.isFloat8E4M3FNUZ();
  const bool rhsIsE5M2 = dataTypeB.isFloat8E5M2FNUZ();
  if (dataTypeA.isFloat8E4M3FNUZ()) {
    if (rhsIsE4M3)
      return MfmaTypeId::Fp8Fp8TyId;
    if (rhsIsE5M2)
      return MfmaTypeId::Fp8Bf8TyId;
  } else if (dataTypeA.isFloat8E5M2FNUZ()) {
    if (rhsIsE4M3)
      return MfmaTypeId::Bf8Fp8TyId;
    if (rhsIsE5M2)
      return MfmaTypeId::Bf8Bf8TyId;
  }

  llvm_unreachable("Unsupported input argument type.");
}
// Table type for MFMA instruction selection: maps an MfmaInsnGroupSelectKey to
// the MfmaInsnAttr describing the instruction, using
// MfmaInsnGroupSelectKeyInfo as the DenseMap key-info trait.
using MfmaInsnGroupMap = llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnAttr,
MfmaInsnGroupSelectKeyInfo>;
// Returns a const reference to the static MFMA instruction table. The table is
// a function-local static, so it is built once on the first call.
//
// Each entry has the form
//   {{nonKDim, type id, mfma version}, {m, n, k, <4th field>, ROCDL op name}}
// NOTE(review): the value field order (m, n, k) is inferred from the op names
// and the accessors on MfmaInsn; the meaning of the 4th numeric field is not
// visible here — confirm against the MfmaInsnAttr declaration.
//
// Callers should bind the result with `const auto &`; a plain `auto` copies
// the whole map.
auto getMfmaInsnGroupAttrMap = []() -> const MfmaInsnGroupMap & {
static MfmaInsnGroupMap MfmaInsnMap{
// f32
// mfma_f32_32x32x2f32
{{32, MfmaTypeId::Fp32TyId, 1},
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
{{32, MfmaTypeId::Fp32TyId, 2},
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
{{32, MfmaTypeId::Fp32TyId, 3},
{32, 32, 2, 1, ROCDL::mfma_f32_32x32x2f32::getOperationName()}},
// mfma_f32_16x16x4f32
{{16, MfmaTypeId::Fp32TyId, 1},
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
{{16, MfmaTypeId::Fp32TyId, 2},
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
{{16, MfmaTypeId::Fp32TyId, 3},
{16, 16, 4, 1, ROCDL::mfma_f32_16x16x4f32::getOperationName()}},
// mfma_f32_4x4x1f32
{{4, MfmaTypeId::Fp32TyId, 1},
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
{{4, MfmaTypeId::Fp32TyId, 2},
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
// Originally labelled mfma_f32_4x4x1_16B_f32, but the entry uses the same
// mfma_f32_4x4x1f32 op as versions 1 and 2.
// NOTE(review): confirm which op is intended for version 3.
{{4, MfmaTypeId::Fp32TyId, 3},
{4, 4, 16, 1, ROCDL::mfma_f32_4x4x1f32::getOperationName()}},
// f16
// mfma_f32_32x32x8f16
{{32, MfmaTypeId::Fp16TyId, 1},
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
{{32, MfmaTypeId::Fp16TyId, 2},
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
{{32, MfmaTypeId::Fp16TyId, 3},
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8f16::getOperationName()}},
// mfma_f32_16x16x16f16
{{16, MfmaTypeId::Fp16TyId, 1},
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
{{16, MfmaTypeId::Fp16TyId, 2},
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
{{16, MfmaTypeId::Fp16TyId, 3},
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16f16::getOperationName()}},
// mfma_f32_4x4x4f16
{{4, MfmaTypeId::Fp16TyId, 1},
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
{{4, MfmaTypeId::Fp16TyId, 2},
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
{{4, MfmaTypeId::Fp16TyId, 3},
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4f16::getOperationName()}},
// bf16 (version 1 uses the non-1k ops, versions 2 and 3 use the _1k ops)
// mfma_f32_32x32x4_bf16
{{32, MfmaTypeId::Bf16TyId, 1},
{32, 32, 4, 2, ROCDL::mfma_f32_32x32x4bf16::getOperationName()}},
// mfma_f32_32x32x8_bf16_1K
{{32, MfmaTypeId::Bf16TyId, 2},
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
{{32, MfmaTypeId::Bf16TyId, 3},
{32, 32, 8, 4, ROCDL::mfma_f32_32x32x8bf16_1k::getOperationName()}},
// mfma_f32_16x16x8_bf16
{{16, MfmaTypeId::Bf16TyId, 1},
{16, 16, 8, 2, ROCDL::mfma_f32_16x16x8bf16::getOperationName()}},
// mfma_f32_16x16x16_bf16_1K
{{16, MfmaTypeId::Bf16TyId, 2},
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
{{16, MfmaTypeId::Bf16TyId, 3},
{16, 16, 16, 4, ROCDL::mfma_f32_16x16x16bf16_1k::getOperationName()}},
// mfma_f32_4x4x2_bf16
{{4, MfmaTypeId::Bf16TyId, 1},
{4, 4, 32, 2, ROCDL::mfma_f32_4x4x2bf16::getOperationName()}},
// mfma_f32_4x4x4_bf16_1K
{{4, MfmaTypeId::Bf16TyId, 2},
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
{{4, MfmaTypeId::Bf16TyId, 3},
{4, 4, 64, 4, ROCDL::mfma_f32_4x4x4bf16_1k::getOperationName()}},
// int8 (these entries use the mfma_i32_* ops)
// mfma_i32_32x32x8i8
{{32, MfmaTypeId::I8TyId, 1},
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
{{32, MfmaTypeId::I8TyId, 2},
{32, 32, 8, 4, ROCDL::mfma_i32_32x32x8i8::getOperationName()}},
// mfma_i32_32x32x16_i8
{{32, MfmaTypeId::I8TyId, 3},
{32, 32, 16, 8, ROCDL::mfma_i32_32x32x16_i8::getOperationName()}},
// mfma_i32_16x16x16i8
{{16, MfmaTypeId::I8TyId, 1},
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
{{16, MfmaTypeId::I8TyId, 2},
{16, 16, 16, 4, ROCDL::mfma_i32_16x16x16i8::getOperationName()}},
// mfma_i32_16x16x32_i8
{{16, MfmaTypeId::I8TyId, 3},
{16, 16, 32, 8, ROCDL::mfma_i32_16x16x32_i8::getOperationName()}},
// mfma_i32_4x4x4i8
{{4, MfmaTypeId::I8TyId, 1},
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
{{4, MfmaTypeId::I8TyId, 2},
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
{{4, MfmaTypeId::I8TyId, 3},
{4, 4, 64, 4, ROCDL::mfma_i32_4x4x4i8::getOperationName()}},
// fp8 / bf8 operand combinations (only version 3 and nonKDim 32 / 16 have
// entries in this table)
// mfma_f32_32x32x16_FP8_FP8
{{32, MfmaTypeId::Fp8Fp8TyId, 3},
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName()}},
// mfma_f32_16x16x32_FP8_FP8
{{16, MfmaTypeId::Fp8Fp8TyId, 3},
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_fp8::getOperationName()}},
// mfma_f32_32x32x16_FP8_BF8
{{32, MfmaTypeId::Fp8Bf8TyId, 3},
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_fp8_bf8::getOperationName()}},
// mfma_f32_16x16x32_FP8_BF8
{{16, MfmaTypeId::Fp8Bf8TyId, 3},
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_fp8_bf8::getOperationName()}},
// mfma_f32_32x32x16_BF8_FP8
{{32, MfmaTypeId::Bf8Fp8TyId, 3},
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_fp8::getOperationName()}},
// mfma_f32_16x16x32_BF8_FP8
{{16, MfmaTypeId::Bf8Fp8TyId, 3},
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_fp8::getOperationName()}},
// mfma_f32_32x32x16_BF8_BF8
{{32, MfmaTypeId::Bf8Bf8TyId, 3},
{32, 32, 16, 8, ROCDL::mfma_f32_32x32x16_bf8_bf8::getOperationName()}},
// mfma_f32_16x16x32_BF8_BF8
{{16, MfmaTypeId::Bf8Bf8TyId, 3},
{16, 16, 32, 8, ROCDL::mfma_f32_16x16x32_bf8_bf8::getOperationName()}}};
return MfmaInsnMap;
};
/// @brief Look up the MFMA instruction for a tile size, operand element types
///        and MFMA version in the static instruction table.
/// @param nonKDim first component of the lookup key (tile size along M/N)
/// @param elementTypeA element type of operand A
/// @param elementTypeB element type of operand B
/// @param mfmaVersion last component of the lookup key
/// @return the matching MfmaInsn, or failure() when the table has no entry
///         for the key. Unsupported element type pairs do not yield
///         failure(): convertTypesToId ends in llvm_unreachable for them.
FailureOr<MfmaInsn> MfmaInsn::selectMfma(unsigned nonKDim, Type elementTypeA,
                                         Type elementTypeB, int mfmaVersion) {
  // Bind by const reference. The getter returns a reference to a static
  // table; a plain `auto` would drop the reference and deep-copy the whole
  // DenseMap on every selection.
  const auto &mfmaInsnAttrMap = getMfmaInsnGroupAttrMap();
  MfmaInsnGroupSelectKey key = {
      nonKDim, convertTypesToId(elementTypeA, elementTypeB), mfmaVersion};
  auto it = mfmaInsnAttrMap.find(key);
  if (it == mfmaInsnAttrMap.end())
    return failure();
  return MfmaInsn(elementTypeA, elementTypeB, it->second);
}
/// @brief Build an instruction descriptor from the operand element types and
///        an instruction attribute entry.
/// @param elementTypeA element type of operand A, stored as-is
/// @param elementTypeB element type of operand B, stored as-is
/// @param attr attribute entry backing the getKDim/getMDim/getNDim/getInsnName
///        accessors
MfmaInsn::MfmaInsn(Type elementTypeA, Type elementTypeB,
const MfmaInsnAttr &attr)
: elementTypeA(elementTypeA), elementTypeB(elementTypeB), attr(attr) {}
/// @return the `k` field of the stored instruction attribute.
unsigned MfmaInsn::getKDim() { return attr.k; }
/// @return the `m` field of the stored instruction attribute.
unsigned MfmaInsn::getMDim() { return attr.m; }
/// @return the `n` field of the stored instruction attribute.
unsigned MfmaInsn::getNDim() { return attr.n; }
/// @return the `insn` field of the stored instruction attribute (the table
///         fills it with a ROCDL operation name).
StringRef MfmaInsn::getInsnName() { return attr.insn; }
} // namespace mlir