Rebase Triton to LLVM-15. (#1070)

This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are
mechanical, except for the analysis framework changes.
This commit is contained in:
Christian Sigg
2023-02-16 15:40:53 +01:00
committed by GitHub
parent f21e76affe
commit fc7a8e3581
78 changed files with 807 additions and 741 deletions

View File

@@ -1,5 +1,8 @@
#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <deque>
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
return multiRootTopologicalSort(slice);
}
namespace {
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
// interacts with constant propagation, but SparseConstantPropagation
// doesn't seem to be sufficient.
struct ConstantAnalysis : public DataFlowAnalysis {
using DataFlowAnalysis::DataFlowAnalysis;
LogicalResult initialize(Operation *top) override {
WalkResult result = top->walk([&](Operation *op) {
if (failed(visit(op)))
return WalkResult::interrupt();
return WalkResult::advance();
});
return success(!result.wasInterrupted());
}
LogicalResult visit(ProgramPoint point) override {
Operation *op = point.get<Operation *>();
Attribute value;
if (matchPattern(op, m_Constant(&value))) {
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
op->getResult(0));
propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
value, op->getDialect())));
return success();
}
setAllToUnknownConstants(op->getResults());
for (Region &region : op->getRegions())
setAllToUnknownConstants(region.getArguments());
return success();
}
/// Set all given values as not constants.
void setAllToUnknownConstants(ValueRange values) {
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
for (Value value : values) {
auto *constant =
getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
propagateIfChanged(constant, constant->join(unknownConstant));
}
}
};
} // namespace
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
auto solver = std::make_unique<DataFlowSolver>();
solver->load<dataflow::DeadCodeAnalysis>();
solver->load<ConstantAnalysis>();
return solver;
}
} // namespace mlir