Rebase Triton to LLVM-15. (#1070)

This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are
mechanical, except for the analysis framework changes.
This commit is contained in:
Christian Sigg
2023-02-16 15:40:53 +01:00
committed by GitHub
parent f21e76affe
commit fc7a8e3581
78 changed files with 807 additions and 741 deletions

View File

@@ -1,25 +1,15 @@
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
using namespace mlir;
namespace {
struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
: public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
// LLVM15+
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
os << name << ": [";
for (size_t d = 0; d < vals.size(); d++) {
if (d != 0)
os << ", ";
os << vals[d];
}
os << "]";
}
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
StringRef getArgument() const final { return "test-print-alignment"; }
StringRef getDescription() const final {
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
Operation *operation = getOperation();
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
AxisInfoAnalysis analysis(&getContext());
analysis.run(operation);
os << "@" << opName << "\n";
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(operation)))
return signalPassFailure();
operation->walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
// std::ostringstream oss;
// result.print(oss);
// os << " => ";
LatticeElement<AxisInfo> *latticeElement =
analysis.lookupLatticeElement(result);
if (!latticeElement) {
os << "None\n";
return;
}
AxisInfo &info = latticeElement->getValue();
print("Contiguity", os, info.getContiguity());
os << " ; ";
print("Divisibility", os, info.getDivisibility());
os << " ; ";
print("Constancy", os, info.getConstancy());
os << " ; ";
auto constantValue = info.getConstantValue();
os << "ConstantValue: [";
if (constantValue.has_value())
os << constantValue.value();
else
os << "None";
os << "] ( ";
result.print(os);
os << " ) ";
os << " => ";
analysis->getLatticeElement(result)->getValue().print(os);
os << "\n";
}
});