mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Rebase Triton to LLVM-15. (#1070)
This PR rebases Triton from LLVM-14 to LLVM-15. Most changes are mechanical, except for the analysis framework changes.
This commit is contained in:
@@ -1,5 +1,8 @@
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
|
||||
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/Matchers.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <deque>
|
||||
@@ -325,4 +328,55 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
|
||||
return multiRootTopologicalSort(slice);
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
|
||||
// interacts with constant propagation, but SparseConstantPropagation
|
||||
// doesn't seem to be sufficient.
|
||||
struct ConstantAnalysis : public DataFlowAnalysis {
|
||||
using DataFlowAnalysis::DataFlowAnalysis;
|
||||
|
||||
LogicalResult initialize(Operation *top) override {
|
||||
WalkResult result = top->walk([&](Operation *op) {
|
||||
if (failed(visit(op)))
|
||||
return WalkResult::interrupt();
|
||||
return WalkResult::advance();
|
||||
});
|
||||
return success(!result.wasInterrupted());
|
||||
}
|
||||
|
||||
LogicalResult visit(ProgramPoint point) override {
|
||||
Operation *op = point.get<Operation *>();
|
||||
Attribute value;
|
||||
if (matchPattern(op, m_Constant(&value))) {
|
||||
auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
|
||||
op->getResult(0));
|
||||
propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
|
||||
value, op->getDialect())));
|
||||
return success();
|
||||
}
|
||||
setAllToUnknownConstants(op->getResults());
|
||||
for (Region ®ion : op->getRegions())
|
||||
setAllToUnknownConstants(region.getArguments());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Set all given values as not constants.
|
||||
void setAllToUnknownConstants(ValueRange values) {
|
||||
dataflow::ConstantValue unknownConstant(nullptr, nullptr);
|
||||
for (Value value : values) {
|
||||
auto *constant =
|
||||
getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
|
||||
propagateIfChanged(constant, constant->join(unknownConstant));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
|
||||
auto solver = std::make_unique<DataFlowSolver>();
|
||||
solver->load<dataflow::DeadCodeAnalysis>();
|
||||
solver->load<ConstantAnalysis>();
|
||||
return solver;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user