mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Rebase Triton to LLVM-15. (#1070)
This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are mechanical, except for the analysis framework changes.
This commit is contained in:
@@ -1,25 +1,15 @@
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
|
||||
struct TestAxisInfoPass
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<FuncOp>> {
|
||||
: public PassWrapper<TestAxisInfoPass, OperationPass<func::FuncOp>> {
|
||||
|
||||
// LLVM15+
|
||||
// MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAlignmentPass);
|
||||
|
||||
void print(const std::string &name, raw_ostream &os, ArrayRef<int64_t> vals) {
|
||||
os << name << ": [";
|
||||
for (size_t d = 0; d < vals.size(); d++) {
|
||||
if (d != 0)
|
||||
os << ", ";
|
||||
os << vals[d];
|
||||
}
|
||||
os << "]";
|
||||
}
|
||||
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
|
||||
|
||||
StringRef getArgument() const final { return "test-print-alignment"; }
|
||||
StringRef getDescription() const final {
|
||||
@@ -30,38 +20,19 @@ struct TestAxisInfoPass
|
||||
Operation *operation = getOperation();
|
||||
auto &os = llvm::errs();
|
||||
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
|
||||
os << opName << "\n";
|
||||
AxisInfoAnalysis analysis(&getContext());
|
||||
analysis.run(operation);
|
||||
os << "@" << opName << "\n";
|
||||
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(operation)))
|
||||
return signalPassFailure();
|
||||
operation->walk([&](Operation *op) {
|
||||
if (op->getNumResults() < 1)
|
||||
return;
|
||||
for (Value result : op->getResults()) {
|
||||
// std::ostringstream oss;
|
||||
// result.print(oss);
|
||||
// os << " => ";
|
||||
LatticeElement<AxisInfo> *latticeElement =
|
||||
analysis.lookupLatticeElement(result);
|
||||
if (!latticeElement) {
|
||||
os << "None\n";
|
||||
return;
|
||||
}
|
||||
AxisInfo &info = latticeElement->getValue();
|
||||
print("Contiguity", os, info.getContiguity());
|
||||
os << " ; ";
|
||||
print("Divisibility", os, info.getDivisibility());
|
||||
os << " ; ";
|
||||
print("Constancy", os, info.getConstancy());
|
||||
os << " ; ";
|
||||
auto constantValue = info.getConstantValue();
|
||||
os << "ConstantValue: [";
|
||||
if (constantValue.has_value())
|
||||
os << constantValue.value();
|
||||
else
|
||||
os << "None";
|
||||
os << "] ( ";
|
||||
result.print(os);
|
||||
os << " ) ";
|
||||
os << " => ";
|
||||
analysis->getLatticeElement(result)->getValue().print(os);
|
||||
os << "\n";
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user