mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[OPTIMIZER] Propagate mma layout when the transitive use has dot_operand encoding (#2482)
This commit is contained in:
@@ -163,14 +163,17 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) {
|
||||
getForwardSlice(currentValue, &forwardSlice);
|
||||
for (Operation *op : forwardSlice) {
|
||||
if (auto convertOp = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
|
||||
Attribute dstEncoding = convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding();
|
||||
if (auto mmaLayout =
|
||||
convertOp.getResult()
|
||||
.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.dyn_cast_or_null<triton::gpu::MmaEncodingAttr>())
|
||||
dstEncoding.dyn_cast<triton::gpu::MmaEncodingAttr>())
|
||||
return (mmaLayout.getVersionMajor() > 1) ? true
|
||||
: mmaLayout == encoding;
|
||||
if (dstEncoding.isa<triton::gpu::DotOperandEncodingAttr>())
|
||||
return encoding.cast<triton::gpu::MmaEncodingAttr>()
|
||||
.getVersionMajor() > 1;
|
||||
}
|
||||
auto yield = dyn_cast<scf::YieldOp>(op);
|
||||
if (!yield)
|
||||
|
||||
Reference in New Issue
Block a user