mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
ROCM IFU: Add new CTALayout parameter to mfma layout
This commit is contained in:
committed by
Jason Furmanek
parent
e5d7bb4fae
commit
bae0e4527c
@@ -718,7 +718,8 @@ The data will be distributed between threads as follows:
|
||||
ins
|
||||
"unsigned":$nonKDim,
|
||||
ArrayRefParameter<"unsigned">:$warpsPerCTA,
|
||||
"bool":$isTransposed
|
||||
"bool":$isTransposed,
|
||||
"CTALayoutAttr":$CTALayout
|
||||
);
|
||||
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
|
||||
@@ -500,6 +500,10 @@ SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
|
||||
assert(0 && "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
|
||||
ref = mmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#ifdef USE_ROCM
|
||||
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
|
||||
ref = mfmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#endif
|
||||
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
|
||||
return getCTAsPerCGA(dotLayout.getParent());
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
@@ -521,6 +525,11 @@ SmallVector<unsigned> getCTASplitNum(Attribute layout) {
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
res.assign(mmaLayout.getCTALayout().getCTASplitNum().begin(),
|
||||
mmaLayout.getCTALayout().getCTASplitNum().end());
|
||||
#ifdef USE_ROCM
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
res.assign(mfmaLayout.getCTALayout().getCTASplitNum().begin(),
|
||||
mfmaLayout.getCTALayout().getCTASplitNum().end());
|
||||
#endif
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
res = getCTASplitNum(dotLayout.getParent());
|
||||
assert(res.size() == 2 && "Invalid dotLayout");
|
||||
@@ -546,6 +555,10 @@ SmallVector<unsigned> getCTAOrder(Attribute layout) {
|
||||
return eraseOrder(parentCTAOrder, sliceLayout.getDim());
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
ref = mmaLayout.getCTALayout().getCTAOrder();
|
||||
#ifdef USE_ROCM
|
||||
} else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>()) {
|
||||
ref = mfmaLayout.getCTALayout().getCTAOrder();
|
||||
#endif
|
||||
} else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
|
||||
return getCTAOrder(dotLayout.getParent());
|
||||
} else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
|
||||
@@ -597,6 +610,10 @@ unsigned getNumWarpsPerCTA(Attribute layout) {
|
||||
return getNumWarpsPerCTA(sliceLayout.getParent());
|
||||
else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
|
||||
warpsPerCTA = mmaLayout.getWarpsPerCTA();
|
||||
#ifdef USE_ROCM
|
||||
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
|
||||
warpsPerCTA = mfmaLayout.getWarpsPerCTA();
|
||||
#endif
|
||||
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
|
||||
return getNumWarpsPerCTA(dotLayout.getParent());
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
@@ -614,6 +631,10 @@ unsigned getNumCTAs(Attribute layout) {
|
||||
return getNumCTAs(sliceLayout.getParent());
|
||||
else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>())
|
||||
CTAsPerCGA = mmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#ifdef USE_ROCM
|
||||
else if (auto mfmaLayout = layout.dyn_cast<MfmaEncodingAttr>())
|
||||
CTAsPerCGA = mfmaLayout.getCTALayout().getCTAsPerCGA();
|
||||
#endif
|
||||
else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>())
|
||||
return getNumCTAs(dotLayout.getParent());
|
||||
else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>())
|
||||
@@ -1240,6 +1261,10 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
unsigned nonKDim = 0;
|
||||
SmallVector<unsigned> warpsPerCTA;
|
||||
bool isTransposed;
|
||||
SmallVector<unsigned> CTAsPerCGA;
|
||||
SmallVector<unsigned> CTASplitNum;
|
||||
SmallVector<unsigned> CTAOrder;
|
||||
|
||||
|
||||
for (const NamedAttribute &attr : dict) {
|
||||
if (attr.getName() == "nonKDim") {
|
||||
@@ -1253,17 +1278,35 @@ Attribute MfmaEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "CTAsPerCGA") {
|
||||
if (parseIntArrayAttr(parser, attr, CTAsPerCGA, "CTAsPerCGA").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "CTASplitNum") {
|
||||
if (parseIntArrayAttr(parser, attr, CTASplitNum, "CTASplitNum").failed())
|
||||
return {};
|
||||
}
|
||||
if (attr.getName() == "CTAOrder") {
|
||||
if (parseIntArrayAttr(parser, attr, CTAOrder, "CTAOrder").failed())
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
auto CTALayout = CTALayoutAttr::get(parser.getContext(), CTAsPerCGA,
|
||||
CTASplitNum, CTAOrder);
|
||||
|
||||
return parser.getChecked<MfmaEncodingAttr>(parser.getContext(), nonKDim,
|
||||
warpsPerCTA, isTransposed);
|
||||
warpsPerCTA, isTransposed, CTALayout);
|
||||
}
|
||||
|
||||
void MfmaEncodingAttr::print(AsmPrinter &printer) const {
|
||||
printer << "<{"
|
||||
<< "nonKDim = " << getNonKDim() << ", "
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "]"
|
||||
<< ", isTransposed = " << getIsTransposed() << "}>";
|
||||
<< "warpsPerCTA = [" << getWarpsPerCTA() << "], "
|
||||
<< "isTransposed = " << getIsTransposed() << ", "
|
||||
<< "CTAsPerCGA = [" << getCTALayout().getCTAsPerCGA() << "], "
|
||||
<< "CTASplitNum = [" << getCTALayout().getCTASplitNum() << "], "
|
||||
<< "CTAOrder = [" << getCTALayout().getCTAOrder() << "]}>";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
@@ -196,6 +196,8 @@ public:
|
||||
if (!supportMFMA(dotOp))
|
||||
return failure();
|
||||
|
||||
auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
|
||||
|
||||
// get MFMA encoding for the given number of warps
|
||||
auto retShape = oldRetType.getShape();
|
||||
auto mod = op->getParentOfType<mlir::ModuleOp>();
|
||||
@@ -216,7 +218,7 @@ public:
|
||||
|
||||
bool isTransposed = isChainDot(dotOp);
|
||||
mfmaEnc = ttg::MfmaEncodingAttr::get(
|
||||
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed);
|
||||
oldRetType.getContext(), nonKDim, warpsPerTile, isTransposed, CTALayout);
|
||||
|
||||
auto newRetType =
|
||||
RankedTensorType::get(retShape, oldRetType.getElementType(), mfmaEnc);
|
||||
|
||||
Reference in New Issue
Block a user