Rebase to LLVM-head. (#1200)

Rebase to
37b7a60cd7
This commit is contained in:
Christian Sigg
2023-02-17 22:16:11 +01:00
committed by GitHub
parent 3b72ebd199
commit 9ef4b5d773
56 changed files with 535 additions and 516 deletions

View File

@@ -111,6 +111,11 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
// The gcd of both arguments for each dimension
AxisInfo AxisInfo::join(const AxisInfo &lhs, const AxisInfo &rhs) {
// If one argument is not initialized, return the other.
if (lhs.getRank() == 0)
return rhs;
if (rhs.getRank() == 0)
return lhs;
DimVectorT contiguity;
DimVectorT divisibility;
DimVectorT constancy;
@@ -151,8 +156,8 @@ public:
AxisInfo
getAxisInfo(triton::MakeRangeOp op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto start = op.start();
auto end = op.end();
auto start = op.getStart();
auto end = op.getEnd();
return AxisInfo(/*contiguity=*/{end - start},
/*divisibility=*/{highestPowOf2Divisor(start)},
/*constancy=*/{1});
@@ -450,9 +455,9 @@ public:
AxisInfo::DimVectorT contiguity = opInfo.getContiguity();
AxisInfo::DimVectorT divisibility = opInfo.getDivisibility();
AxisInfo::DimVectorT constancy = opInfo.getConstancy();
contiguity.insert(contiguity.begin() + op.axis(), 1);
divisibility.insert(divisibility.begin() + op.axis(), 1);
constancy.insert(constancy.begin() + op.axis(), 1);
contiguity.insert(contiguity.begin() + op.getAxis(), 1);
divisibility.insert(divisibility.begin() + op.getAxis(), 1);
constancy.insert(constancy.begin() + op.getAxis(), 1);
return AxisInfo(contiguity, divisibility, constancy,
operands[0]->getValue().getConstantValue());
}
@@ -551,7 +556,7 @@ public:
private:
static arith::CmpIPredicate getPredicate(triton::gpu::CmpIOp op) {
return op.predicate();
return op.getPredicate();
}
static arith::CmpIPredicate getPredicate(arith::CmpIOp op) {
@@ -843,7 +848,7 @@ void AxisInfoAnalysis::visitOperation(
ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
AxisInfo curr = visitors.apply(op, operands);
if (curr.getRank() == 0) {
return markAllPessimisticFixpoint(results);
return setAllToEntryStates(results);
}
// override with hint
auto newContiguity = curr.getContiguity();
@@ -892,7 +897,7 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
if (!latticeElement || latticeElement->isUninitialized())
if (!latticeElement)
return 1;
auto axisInfo = latticeElement->getValue();
auto layout = tensorTy.getEncoding();
@@ -911,7 +916,7 @@ unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
if (!latticeElement || latticeElement->isUninitialized())
if (!latticeElement)
return 1;
auto maskAxis = latticeElement->getValue();
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());