[OPTIMIZER] Propagate mma layout when the transitive use has dot_operand encoding (#2482)

This commit is contained in:
Thomas Raoux
2023-10-12 16:57:40 -07:00
committed by GitHub
parent 03af50b040
commit a777e1d8db

View File

@@ -163,14 +163,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
getForwardSlice(currentValue, &forwardSlice);
for (Operation *op : forwardSlice) {
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
Attribute dstEncoding = convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding();
if (auto mmaLayout =
convertOp.getResult()
.getType()
.cast<RankedTensorType>()
.getEncoding()
.dyn_cast_or_null<triton::gpu::MmaEncodingAttr>())
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
return (mmaLayout.getVersionMajor() > 1) ? true
: mmaLayout == encoding;
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
return encoding.cast<triton::gpu::MmaEncodingAttr>()
.getVersionMajor() > 1;
}
auto yield = dyn_cast<scf::YieldOp>(op);
if (!yield)