ROCM IFU: Resolve merge conflicts in RemoveLayoutConversions.cpp

fix merge error

fix dot

fix make_range

additional fix
This commit is contained in:
Jason Furmanek
2023-11-07 03:10:04 +00:00
parent c3132eeda8
commit 39e8901d7a
3 changed files with 35 additions and 52 deletions

View File

@@ -1003,22 +1003,11 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
if (targetType.getEncoding().isa<triton::gpu::DotOperandEncodingAttr>())
return;
<<<<<<< HEAD
#ifndef USE_ROCM
auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(op);
=======
auto isExtOrBroadcastOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp, triton::ExpandDimsOp>(op);
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
};
#else
auto isExtOp = [](Operation *op) {
return isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp,
triton::BroadcastOp, triton::ExpandDimsOp>(op);
};
#endif
// 1. Take a backward slice of all the tensor dependencies.
SetVector<Value> slice;
DenseMap<Value, Attribute> layout;
@@ -1064,22 +1053,12 @@ static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) {
inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]);
if (!srcEncoding)
return;
std::optional<Attribute> srcEncoding =
inferSrcEncoding(extOp, layout[extOp->getResult(0)]);
// Move the convert before the ext op and rewrite the slice.
<<<<<<< HEAD
OpBuilder builder(extOp);
auto tensorType = extOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType =
RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(),
*srcEncoding);
=======
OpBuilder builder(extOrBroadcatOp);
auto tensorType =
extOrBroadcatOp->getOperand(0).getType().cast<RankedTensorType>();
auto newType = RankedTensorType::get(
tensorType.getShape(), tensorType.getElementType(), *srcEncoding);
>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33
auto newConvertOp = builder.create<ConvertLayoutOp>(
convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0));
IRMapping mapping;