ROCM IFU: Add new CTALayout parameter to mfma layout

This commit is contained in:
Aleksandr Efimov
2023-09-25 17:26:29 +00:00
committed by Jason Furmanek
parent e5d7bb4fae
commit bae0e4527c
3 changed files with 51 additions and 5 deletions

View File

@@ -718,7 +718,8 @@ The data will be distributed between threads as follows:
ins
"unsigned":$nonKDim,
ArrayRefParameter<"unsigned">:$warpsPerCTA,
"bool":$isTransposed
"bool":$isTransposed,
"CTALayoutAttr":$CTALayout
);
let hasCustomAssemblyFormat = 1;

View File

@@ -500,6 +500,10 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
#ifdef USE_ROCM
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
ref = mfmaLayout.getCTALayout().getCTAsPerCGA();
#endif
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getCTAsPerCGA(dotLayout.getParent());
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
@@ -521,6 +525,11 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
res.assign(mmaLayout.getCTALayout().getCTASplitNum().begin(),
mmaLayout.getCTALayout().getCTASplitNum().end());
#ifdef USE_ROCM
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(),
mfmaLayout.getCTALayout().getCTASplitNum().end());
#endif
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
res = getCTASplitNum(dotLayout.getParent());
assert(res.size() == 2 && "Invalid dotLayout");
@@ -546,6 +555,10 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
return eraseOrder(parentCTAOrder, sliceLayout.getDim());
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
ref = mmaLayout.getCTALayout().getCTAOrder();
#ifdef USE_ROCM
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
ref = mfmaLayout.getCTALayout().getCTAOrder();
#endif
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
return getCTAOrder(dotLayout.getParent());
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
@@ -597,6 +610,10 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
return getNumWarpsPerCTA(sliceLayout.getParent());
else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
warpsPerCTA = mmaLayout.getWarpsPerCTA();
#ifdef USE_ROCM
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
warpsPerCTA = mfmaLayout.getWarpsPerCTA();
#endif
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getNumWarpsPerCTA(dotLayout.getParent());
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
@@ -614,6 +631,10 @@ unsigned getNumCTAs(Attribute layout) {
return getNumCTAs(sliceLayout.getParent());
else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA();
#ifdef USE_ROCM
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA();
#endif
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
return getNumCTAs(dotLayout.getParent());
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
@@ -1240,6 +1261,10 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
unsigned nonKDim = 0;
SmallVector<unsigned> warpsPerCTA;
bool isTransposed;
SmallVector<unsigned> CTAsPerCGA;
SmallVector<unsigned> CTASplitNum;
SmallVector<unsigned> CTAOrder;
for (const NamedAttribute &attr : dict) {
if (attr.getName() == "nonKDim") {
@@ -1253,17 +1278,35 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
return {};
}
if (attr.getName() == "CTAsPerCGA") {
if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed())
return {};
}
if (attr.getName() == "CTASplitNum") {
if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed())
return {};
}
if (attr.getName() == "CTAOrder") {
if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed())
return {};
}
}
auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA,
CTASplitNum, CTAOrder);
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
warpsPerCTA, isTransposed);
warpsPerCTA, isTransposed, CTALayout);
}
// Writes this MFMA encoding to `printer` in its custom textual form: a
// `<{ ... }>` wrapper around comma-separated `key = value` entries, starting
// with nonKDim, then warpsPerCTA (in square brackets) and isTransposed.
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
printer << "<{"
<< "nonKDim = " << getNonKDim() << ", "
// NOTE(review): the two lines below and the five lines after them both emit
// warpsPerCTA and isTransposed, and the first pair already closes the
// statement with "}>". This looks like the old and new versions of the same
// statement shown together (no +/- diff markers are visible here) -- confirm
// which lines are live before changing anything.
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
<< ", isTransposed = " << getIsTransposed() << "}>";
// This form also emits the three fields read from getCTALayout()
// (CTAsPerCGA, CTASplitNum, CTAOrder), each as a bracketed list, and only
// then closes the `}>` wrapper.
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
<< "isTransposed = " << getIsTransposed() << ", "
<< "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], "
<< "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], "
<< "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>";
}
//===----------------------------------------------------------------------===//

View File

@@ -196,6 +196,8 @@ public:
if (!supportMFMA(dotOp))
return failure();
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
// get MFMA encoding for the given number of warps
auto retShape = oldRetType.getShape();
auto mod = op->getParentOfType<mlir::ModuleOp>();
@@ -216,7 +218,7 @@ public:
bool isTransposed = isChainDot(dotOp);
mfmaEnc = ttg::MfmaEncodingAttr::get(
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed, CTALayout);
auto newRetType =
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);