mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
@@ -111,6 +111,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
|
||||
|
||||
// The gcd of both arguments for each dimension
|
||||
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
|
||||
// If one argument is not initialized, return the other.
|
||||
if (lhs.getRank() == 0)
|
||||
return rhs;
|
||||
if (rhs.getRank() == 0)
|
||||
return lhs;
|
||||
DimVectorT contiguity;
|
||||
DimVectorT divisibility;
|
||||
DimVectorT constancy;
|
||||
@@ -151,8 +156,8 @@ public:
|
||||
AxisInfo
|
||||
getAxisInfo(triton::MakeRangeOp op,
|
||||
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
|
||||
auto start = op.start();
|
||||
auto end = op.end();
|
||||
auto start = op.getStart();
|
||||
auto end = op.getEnd();
|
||||
return AxisInfo(/*contiguity=*/{end - start},
|
||||
/*divisibility=*/{highestPowOf2Divisor(start)},
|
||||
/*constancy=*/{1});
|
||||
@@ -450,9 +455,9 @@ public:
|
||||
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
|
||||
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
|
||||
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
|
||||
contiguity.insert(contiguity.begin() + op.axis(), 1);
|
||||
divisibility.insert(divisibility.begin() + op.axis(), 1);
|
||||
constancy.insert(constancy.begin() + op.axis(), 1);
|
||||
contiguity.insert(contiguity.begin() + op.getAxis(), 1);
|
||||
divisibility.insert(divisibility.begin() + op.getAxis(), 1);
|
||||
constancy.insert(constancy.begin() + op.getAxis(), 1);
|
||||
return AxisInfo(contiguity, divisibility, constancy,
|
||||
operands[0]->getValue().getConstantValue());
|
||||
}
|
||||
@@ -551,7 +556,7 @@ public:
|
||||
|
||||
private:
|
||||
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
|
||||
return op.predicate();
|
||||
return op.getPredicate();
|
||||
}
|
||||
|
||||
static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
|
||||
@@ -843,7 +848,7 @@ void AxisInfoAnalysis::visitOperation(
|
||||
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
|
||||
AxisInfo curr = visitors.apply(op, operands);
|
||||
if (curr.getRank() == 0) {
|
||||
return markAllPessimisticFixpoint(results);
|
||||
return setAllToEntryStates(results);
|
||||
}
|
||||
// override with hint
|
||||
auto newContiguity = curr.getContiguity();
|
||||
@@ -892,7 +897,7 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
|
||||
if (!latticeElement || latticeElement->isUninitialized())
|
||||
if (!latticeElement)
|
||||
return 1;
|
||||
auto axisInfo = latticeElement->getValue();
|
||||
auto layout = tensorTy.getEncoding();
|
||||
@@ -911,7 +916,7 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
|
||||
if (!tensorTy)
|
||||
return 1;
|
||||
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
|
||||
if (!latticeElement || latticeElement->isUninitialized())
|
||||
if (!latticeElement)
|
||||
return 1;
|
||||
auto maskAxis = latticeElement->getValue();
|
||||
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
|
||||
|
||||
Reference in New Issue
Block a user