[ANALYSIS] propagate divisibility through tl.where for all types (#2023)

This commit is contained in:
Luca Wehrstedt
2023-08-15 05:26:31 +02:00
committed by GitHub
parent 0312ed3473
commit 8fa11a75d3

View File

@@ -667,14 +667,10 @@ public:
AxisInfo
getAxisInfo(OpTy op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
if (!resTy)
return AxisInfo();
auto shape = resTy.getShape();
auto rank = shape.size();
auto condConstancy = operands[0]->getValue().getConstancy();
auto lhsInfo = operands[1]->getValue();
auto rhsInfo = operands[2]->getValue();
auto rank = lhsInfo.getRank();
AxisInfo::DimVectorT contiguity, divisibility, constancy;
std::optional<int64_t> constantValue;