#include "triton/Analysis/Utility.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include <deque>

namespace mlir {

namespace {

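// Maps `axis` in a (possibly nested) slice layout back to the corresponding
// axis of the outermost parent layout.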
int getParentAxis(Attribute layout, int axis) {
  if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    axis = axis < sliceEncoding.getDim() ? axis : axis + 1;
    return getParentAxis(sliceEncoding.getParent(), axis);
  }
  return axis;
}

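// Returns the dimension order of the outermost non-slice parent layout.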
SmallVector<unsigned> getParentOrder(Attribute layout) {
  if (auto sliceEncoding = layout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    return getParentOrder(sliceEncoding.getParent());
  }
  return triton::gpu::getOrder(layout);
}

} // namespace

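// A reduction is "fast" when the reduced axis is the fastest-varying
// (most contiguous) axis of the parent layout.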
bool ReduceOpHelper::isFastReduction() {
  // Disable fast reduction only for debugging purposes.
  if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION"))
    return false;
  return getParentAxis(getSrcLayout(), axis) ==
         getParentOrder(getSrcLayout())[0];
}

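// Number of warps that cooperate along the reduced axis.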
unsigned ReduceOpHelper::getInterWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSize();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTA(getSrcLayout())[axis]);
}

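// Number of threads within a warp that cooperate along the reduced axis.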
unsigned ReduceOpHelper::getIntraWarpSize() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  return std::min(srcReduceDimSize,
                  triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

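// Same as getInterWarpSize, but only counts warps that hold unique data along
// the reduced axis (warps that merely replicate other warps are excluded).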
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
  return std::min(srcReduceDimSize / sizeIntraWarps,
                  triton::gpu::getWarpsPerCTAWithUniqueData(
                      getSrcLayout(), getSrcShape())[axis]);
}

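// Same as getIntraWarpSize, but only counts threads that hold unique data
// along the reduced axis; the number of elements each thread owns along that
// axis is factored out of the dimension size first.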
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
  auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
  unsigned elementPerThreads = triton::gpu::getUniqueContigPerThread(
      getSrcLayout(), getSrcShape())[axis];
  return std::min(srcReduceDimSize / elementPerThreads,
                  triton::gpu::getThreadsPerWarpWithUniqueData(
                      getSrcLayout(), getSrcShape())[axis]);
}

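// Total number of threads across the CTA that hold unique data along the
// reduced axis.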
unsigned ReduceOpHelper::getThreadsReductionAxis() {
  auto srcLayout = getSrcLayout();
  auto srcShape = getSrcShape();
  return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
                                                      srcShape)[axis] *
         triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}

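// Shared-memory shape for the basic (non-fast) reduction path: the source
// shape with the reduced axis clamped to the number of participating threads.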
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
  auto smemShape = convertType<unsigned>(getSrcShape());
  smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis());
  return smemShape;
}

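// Shared-memory shapes for the fast reduction path: block 0 holds the
// per-warp partial results along the reduced axis; block 1 appears to be
// scratch for the final cross-warp combination step.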
SmallVector<SmallVector<unsigned>> ReduceOpHelper::getScratchConfigsFast() {
  SmallVector<SmallVector<unsigned>> smemShapes(3);

  auto argLayout = getSrcLayout();
  auto argLayoutMma = argLayout.dyn_cast<triton::gpu::MmaEncodingAttr>();
  // if (argLayoutMma && argLayoutMma.getVersionMajor() == 2 &&
  //     triton::gpu::getWarpsPerCTA(argLayout)[axis] == 1)
  //   return {{1, 1}, {1, 1}};

  /// shared memory block0
  smemShapes[0] = convertType<unsigned>(getSrcShape());
  smemShapes[0][axis] = getInterWarpSize();

  /// FIXME(Qingyi): This size is actually larger than required.
  /// shared memory block1:
  auto mod = op->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  unsigned threadsPerWarp =
      triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
  smemShapes[1].push_back(numWarps * threadsPerWarp);

  return smemShapes;
}

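// Total scratch size in bytes for the reduction, accumulated over all operand
// element types.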
unsigned ReduceOpHelper::getScratchSizeInBytes() {
  unsigned elems = 0;
  if (isFastReduction()) {
    auto smemShapes = getScratchConfigsFast();
    for (const auto &smemShape : smemShapes)
      elems = std::max(elems, product<unsigned>(smemShape));
  } else {
    auto smemShape = getScratchConfigBasic();
    elems = product<unsigned>(smemShape);
  }

  unsigned bytesPerElem = 0;
  for (const auto &ty : srcElementTypes) {
    bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
  }
  return bytesPerElem * elems;
}

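// Reductions are currently lowered only for blocked, Ampere MMA, and slice
// source layouts.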
bool ReduceOpHelper::isSupportedLayout() {
  auto srcLayout = getSrcLayout();
  if (srcLayout.isa<triton::gpu::BlockedEncodingAttr>()) {
    return true;
  }
  if (auto mmaLayout = srcLayout.dyn_cast<triton::gpu::MmaEncodingAttr>()) {
    if (mmaLayout.isAmpere()) {
      return true;
    }
  }
  if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
    return true;
  }
  return false;
}

unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
  return getEncoding().getSizePerThread()[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
  SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
  sizePerThreads[getAxis()] = 1;
  return product<unsigned>(sizePerThreads);
}

Region &ScanLoweringHelper::getCombineOp() { return scanOp.getCombineOp(); }

unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
  return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()];
}

unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
  threadsPerWarp[getAxis()] = 1;
  return product<unsigned>(threadsPerWarp);
}

// Return the flat number of threads computing independent scan results.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
  unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding());
  warpsPerCTA[getAxis()] = 1;
  unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
  return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
}

unsigned ScanLoweringHelper::getAxisNumWarps() {
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  return warpsPerCTA[getAxis()];
}

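// Number of times the layout tile must be repeated along the scan axis to
// cover the tensor shape.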
unsigned ScanLoweringHelper::getAxisNumBlocks() {
  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  unsigned axis = getAxis();
  return ceil<unsigned>(
      type.getShape()[axis],
      (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}

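// Product of the layout-tile repetition counts over all non-scan dimensions.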
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  unsigned axis = getAxis();
  unsigned numBlocks = 1;
  for (unsigned i = 0; i < sizePerThreads.size(); i++) {
    if (i == axis)
      continue;
    numBlocks *= ceil<unsigned>(
        type.getShape()[i],
        (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
  }
  return numBlocks;
}

bool ScanLoweringHelper::isSupported() {
  // TODO: Support the following cases:
  // 1. Scan on non-blocked encodings
  // 2. Scan with multiple operands
  if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
    return false;
  if (scanOp.getNumOperands() != 1)
    return false;
  return true;
}

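// Shared-memory scratch sized to exchange partial scan results across warps
// and layout-tile repetitions.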
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
  unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8;
  auto mod = scanOp->getParentOfType<ModuleOp>();
  unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  unsigned numNonAxisElementsPerWarp =
      getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
  unsigned numElements = numWarps * numNonAxisElementsPerWarp *
                         getAxisNumBlocks() * getNonAxisNumBlocks();
  return elementSizeInBytes * numElements;
}

triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
  return srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
}

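// Stride, in registers within a thread, between two consecutive elements along
// the scan axis.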
unsigned ScanLoweringHelper::getAxisElementStride() {
  auto order = triton::gpu::getOrder(srcEncoding);
  unsigned stride = 1;
  for (unsigned dim : order) {
    if (dim == getAxis())
      return stride;
    stride *= getContigPerThread(getEncoding())[dim];
  }
  llvm_unreachable("Axis not found in order");
}

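// Stride, in lanes within a warp, between two threads that are adjacent along
// the scan axis.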
unsigned ScanLoweringHelper::getAxisThreadStride() {
  auto order = triton::gpu::getOrder(srcEncoding);
  unsigned stride = 1;
  for (unsigned dim : order) {
    if (dim == getAxis())
      return stride;
    stride *= getEncoding().getThreadsPerWarp()[dim];
  }
  llvm_unreachable("Axis not found in order");
}

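// Stride, in layout-tile repetitions, between two consecutive blocks along the
// scan axis.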
unsigned ScanLoweringHelper::getAxisBlockStride() {
  auto order = triton::gpu::getOrder(srcEncoding);
  unsigned stride = 1;
  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
  for (unsigned dim : order) {
    if (dim == getAxis())
      return stride;
    stride *= type.getShape()[dim] /
              (sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]);
  }
  llvm_unreachable("Axis not found in order");
}

bool maybeSharedAllocationOp(Operation *op) {
  // TODO(Keren): This function can be replaced by adding
  // MemoryEffectOpInterface. We can then use the MemoryEffectOpInterface to
  // query the memory effects of the op.
  auto *dialect = op->getDialect();
  return dialect &&
         (dialect->getTypeID() ==
              mlir::TypeID::get<triton::gpu::TritonGPUDialect>() ||
          dialect->getTypeID() == mlir::TypeID::get<triton::TritonDialect>() ||
          dialect->getTypeID() == mlir::TypeID::get<arith::ArithDialect>() ||
          dialect->getTypeID() == mlir::TypeID::get<tensor::TensorDialect>());
}

bool maybeAliasOp(Operation *op) {
  return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
         isa<triton::gpu::InsertSliceAsyncOp>(op) ||
         isa<tensor::InsertSliceOp>(op);
}

bool supportMMA(triton::DotOp op, int version) {
  // Refer to the mma section of
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
  // for the data types supported by the Volta and Hopper Tensor Cores.
  auto aElemTy = op.getA().getType().cast<RankedTensorType>().getElementType();
  auto bElemTy = op.getB().getType().cast<RankedTensorType>().getElementType();
  if (aElemTy.isF32() && bElemTy.isF32()) {
    return op.getAllowTF32() && version >= 2;
  }
  return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}

bool supportMMA(Value value, int version) {
  // Tells whether a DotOp supports HMMA based on the type of a single operand
  // (either $a or $b). We cannot get both operand types (in the
  // TypeConverter), so we assume the types of both operands are identical.
  assert((version == 1 || version == 2) &&
         "Unexpected MMA layout version found");
  auto elemTy = value.getType().cast<RankedTensorType>().getElementType();
  return elemTy.isF16() || elemTy.isBF16() ||
         (elemTy.isF32() && version >= 2) ||
         (elemTy.isInteger(8) && version >= 2);
}

Type getElementType(Value value) {
  auto type = value.getType();
  if (auto tensorType = type.dyn_cast<RankedTensorType>())
    return tensorType.getElementType();
  return type;
}

bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  // dot_op<opIdx=0, parent=#mma> = #mma
  // when #mma = MmaEncoding<version=2, warpsPerCTA=[..., 1]>
  auto srcLayout = srcTy.getEncoding();
  auto dstLayout = dstTy.getEncoding();
  auto mmaLayout = srcLayout.cast<triton::gpu::MmaEncodingAttr>();
  auto dotOperandLayout = dstLayout.cast<triton::gpu::DotOperandEncodingAttr>();
  return mmaLayout.getVersionMajor() == 2 &&
         mmaLayout.getWarpsPerCTA()[1] == 1 &&
         dotOperandLayout.getOpIdx() == 0 &&
         dotOperandLayout.getParent() == mmaLayout &&
         !srcTy.getElementType().isF32();
}

bool isSingleValue(Value value) {
  // Don't consider a load expensive if it is loading a scalar.
  if (auto tensorTy = value.getType().dyn_cast<RankedTensorType>())
    return tensorTy.getNumElements() == 1;
  // TODO: Handle other cases.
  // For example, when ptr is a tensor of a single value, ptr is the result of
  // a broadcast or of a chain of broadcasts and other operations;
  // rematerializing it without considering the contiguous memory access
  // pattern is fine.
  return true;
}

namespace {

/// A data structure similar to SetVector but maintains
/// a deque instead of a vector to allow for efficient
/// push_back and pop_front operations.
/// Using SetVector doesn't meet our needs because
/// it only pushes and pops from the back.
/// For example, if we have a queue like this:
/// 0->4 1->2->3
///    ^--------
/// where 3 depends on 4, once we pop 3, we find that
/// 4 is not ready, so we check 2 and push 3 back
/// to the queue.
struct DFSSubgraphState {
  DFSSubgraphState() : set(), deque() {}
  DenseSet<Operation *> set;
  std::deque<Operation *> deque;

  bool push_back(Operation *op) {
    if (set.insert(op).second) {
      deque.push_back(op);
      return true;
    }
    return false;
  }

  Operation *pop_front() {
    Operation *op = deque.front();
    deque.pop_front();
    set.erase(op);
    return op;
  }

  bool empty() { return deque.empty(); }
};

/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all operations but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
  DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
  const SetVector<Operation *> &toSort;
  SmallVector<Operation *, 16> topologicalCounts;
  DenseSet<Operation *> seen;

  /// We mark each op as ready if all its operands are seen. If an op is ready,
  /// we add it to the queue. Otherwise, we keep adding its operands to the
  /// ancestors set.
  void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
                       SmallVector<Operation *, 4> &readyQueue) {
    bool ready = true;
    for (Value operand : op->getOperands()) {
      auto def = operand.getDefiningOp();
      if (def && !seen.count(def)) {
        subGraph.push_back(def);
        ready = false;
      }
    }
    if (ready)
      readyQueue.push_back(op);
  }
};

void dfsPostorder(Operation *root, DFSState *state) {
  DFSSubgraphState subGraph;
  subGraph.push_back(root);
  SmallVector<Operation *> ops;
  while (!subGraph.empty()) {
    // Nodes in the ready queue are ready to be processed, meaning that either
    // all of their operands have been seen or their operands have no defining
    // ops.
    SmallVector<Operation *, 4> readyQueue;
    auto *current = subGraph.pop_front();
    state->addToReadyQueue(current, subGraph, readyQueue);
    while (!readyQueue.empty()) {
      Operation *current = readyQueue.pop_back_val();
      if (!state->seen.insert(current).second)
        continue;
      ops.push_back(current);
      for (Value result : current->getResults()) {
        for (Operation *op : result.getUsers())
          state->addToReadyQueue(op, subGraph, readyQueue);
      }
      for (Region &region : current->getRegions()) {
        for (Operation &op : region.getOps())
          state->addToReadyQueue(&op, subGraph, readyQueue);
      }
    }
  }

  for (Operation *op : llvm::reverse(ops)) {
    if (state->toSort.count(op) > 0)
      state->topologicalCounts.push_back(op);
  }
}

} // namespace

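// Multi-root topological sort: runs the DFS post-order above from every root
// while sharing the `seen` set across invocations, then collects the ops of
// `toSort` in topological order.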
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty()) {
    return toSort;
  }

  // Run from each root with global count and `seen` set.
  DFSState state(toSort);
  for (auto *s : toSort) {
    assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(s, &state);
  }

  // Reorder and return.
  SetVector<Operation *> res;
  for (auto it = state.topologicalCounts.rbegin(),
            eit = state.topologicalCounts.rend();
       it != eit; ++it) {
    res.insert(*it);
  }
  return res;
}

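// Collects every op transitively reachable from `op` through both defs
// (backward slice) and uses (forward slice), subject to the given filters, and
// returns the result in topological order.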
SetVector<Operation *> multiRootGetSlice(Operation *op,
                                         TransitiveFilter backwardFilter,
                                         TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);

  unsigned currentIndex = 0;
  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  while (currentIndex != slice.size()) {
    auto *currentOp = (slice)[currentIndex];
    // Compute and insert the backwardSlice starting from currentOp.
    backwardSlice.clear();
    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
    slice.insert(backwardSlice.begin(), backwardSlice.end());

    // Compute and insert the forwardSlice starting from currentOp.
    forwardSlice.clear();
    getForwardSlice(currentOp, &forwardSlice, forwardFilter);
    slice.insert(forwardSlice.begin(), forwardSlice.end());
    ++currentIndex;
  }
  return multiRootTopologicalSort(slice);
}

namespace {
// Copied from TestDeadCodeAnalysis.cpp, because some dead code analysis
// interacts with constant propagation, but SparseConstantPropagation
// doesn't seem to be sufficient.
class ConstantAnalysis : public DataFlowAnalysis {
public:
  using DataFlowAnalysis::DataFlowAnalysis;

  LogicalResult initialize(Operation *top) override {
    WalkResult result = top->walk([&](Operation *op) {
      if (failed(visit(op)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return success(!result.wasInterrupted());
  }

  LogicalResult visit(ProgramPoint point) override {
    Operation *op = point.get<Operation *>();
    Attribute value;
    if (matchPattern(op, m_Constant(&value))) {
      auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
          op->getResult(0));
      propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
                                       value, op->getDialect())));
      return success();
    }
    // Dead code analysis requires that every operand has an initialized
    // ConstantValue state before it is visited:
    // https://github.com/llvm/llvm-project/blob/2ec1aba2b69faa1de5f71832a48e25aa3b5d5314/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp#L322
    // That's why we need to set all other values to unknown constants.
    setAllToUnknownConstants(op->getResults());
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks())
        setAllToUnknownConstants(block.getArguments());
    }
    return success();
  }

private:
  /// Set all given values to the unknown constant.
  void setAllToUnknownConstants(ValueRange values) {
    dataflow::ConstantValue unknownConstant(nullptr, nullptr);
    for (Value value : values) {
      auto *constant =
          getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
      propagateIfChanged(constant, constant->join(unknownConstant));
    }
  }
};
} // namespace

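// Creates a solver preloaded with dead-code analysis and the constant analysis
// defined above.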
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
  auto solver = std::make_unique<DataFlowSolver>();
  solver->load<dataflow::DeadCodeAnalysis>();
  solver->load<ConstantAnalysis>();
  return solver;
}

} // namespace mlir