mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[ANALYSIS] propagate divisibility through tl.where for all types (#2023)
This commit is contained in:
@@ -667,14 +667,10 @@ public:
|
||||
AxisInfo
|
||||
getAxisInfo(OpTy op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto resTy = op.getResult().getType().template dyn_cast<RankedTensorType>();
|
||||
if (!resTy)
|
||||
return AxisInfo();
|
||||
auto shape = resTy.getShape();
|
||||
auto rank = shape.size();
|
||||
auto condConstancy = operands[0]->getValue().getConstancy();
|
||||
auto lhsInfo = operands[1]->getValue();
|
||||
auto rhsInfo = operands[2]->getValue();
|
||||
auto rank = lhsInfo.getRank();
|
||||
|
||||
AxisInfo::DimVectorT contiguity, divisibility, constancy;
|
||||
std::optional<int64_t> constantValue;
|
||||
|
||||
Reference in New Issue
Block a user