mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[BACKEND] Fix topological sort and add new test cases (#1132)
The previous change, https://github.com/openai/triton/pull/1113, did not consider that a node may have multiple parents; visiting an operation before all of its parents have been visited violates the semantics of topological sort. The fixed implementation exhaustively adds all operations to a candidate subgraph and moves an operation to the "ready" queue once all of its operands have been visited.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include <deque>
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@@ -164,4 +165,126 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
/// A data structure similar to SetVector but maintains
|
||||
/// a deque instead of a vector to allow for efficient
|
||||
/// push_back and pop_front operations.
|
||||
/// Using SetVector doesn't suffice our needs because
|
||||
/// it only pushes and pops from the back.
|
||||
/// For example, if we have a queue like this:
|
||||
/// 0->4 1->2->3
|
||||
/// ^--------
|
||||
/// where 3 depends on 4, once we pop 3, we found
|
||||
/// 4 is not ready, so we check 2 and push 3 back
|
||||
/// to the queue.
|
||||
struct DFSSubgraphState {
|
||||
DFSSubgraphState() : set(), deque() {}
|
||||
DenseSet<Operation *> set;
|
||||
std::deque<Operation *> deque;
|
||||
|
||||
bool push_back(Operation *op) {
|
||||
if (set.insert(op).second) {
|
||||
deque.push_back(op);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Operation *pop_front() {
|
||||
Operation *op = deque.front();
|
||||
deque.pop_front();
|
||||
set.erase(op);
|
||||
return op;
|
||||
}
|
||||
|
||||
bool empty() { return deque.empty(); }
|
||||
};
|
||||
|
||||
/// DFS post-order implementation that maintains a global count to work across
|
||||
/// multiple invocations, to help implement topological sort on multi-root DAGs.
|
||||
/// We traverse all operations but only record the ones that appear in
|
||||
/// `toSort` for the final result.
|
||||
struct DFSState {
|
||||
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
|
||||
const SetVector<Operation *> &toSort;
|
||||
SmallVector<Operation *, 16> topologicalCounts;
|
||||
DenseSet<Operation *> seen;
|
||||
|
||||
/// We mark each op as ready if all its operands are seen. If an op is ready,
|
||||
/// we add it to the queue. Otherwise, we keep adding its operands to the
|
||||
/// ancestors set.
|
||||
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
|
||||
SmallVector<Operation *, 4> &readyQueue) {
|
||||
bool ready = true;
|
||||
for (Value operand : op->getOperands()) {
|
||||
auto def = operand.getDefiningOp();
|
||||
if (def && !seen.count(def)) {
|
||||
subGraph.push_back(def);
|
||||
ready = false;
|
||||
}
|
||||
}
|
||||
if (ready)
|
||||
readyQueue.push_back(op);
|
||||
}
|
||||
};
|
||||
|
||||
/// Traverses the operations connected to `root` and appends those that are
/// members of `state->toSort` to `state->topologicalCounts`.
///
/// An operation is only visited once every operand that has a defining
/// operation has had that operation visited; unvisited defining operations
/// are pulled into the candidate subgraph first. Operations visited during
/// this call are appended in the reverse of their visitation order, and
/// operations already present in `state->seen` are skipped, which lets the
/// function be called repeatedly with different roots.
///
/// \param root  operation at which this traversal starts.
/// \param state traversal state shared across invocations; its `seen` and
///              `topologicalCounts` members are updated.
void dfsPostorder(Operation *root, DFSState *state) {
  DFSSubgraphState subGraph;
  subGraph.push_back(root);
  // Operations in the order in which they were marked as seen by this call.
  SmallVector<Operation *> ops;
  while (!subGraph.empty()) {
    // Nodes in the ready queue are ready to be processed, meaning that
    // either their operands are all seen or they have null operands.
    SmallVector<Operation *, 4> readyQueue;
    Operation *candidate = subGraph.pop_front();
    // If `candidate` is not ready yet, this enqueues its unseen defining ops
    // on `subGraph`; `candidate` is reconsidered when they are processed,
    // because it is one of their users.
    state->addToReadyQueue(candidate, subGraph, readyQueue);
    while (!readyQueue.empty()) {
      Operation *ready = readyQueue.pop_back_val();
      // An op can be queued more than once; only the first visit counts.
      if (!state->seen.insert(ready).second)
        continue;
      ops.push_back(ready);
      // Users of the results may have become ready now.
      for (Value result : ready->getResults()) {
        for (Operation *user : result.getUsers())
          state->addToReadyQueue(user, subGraph, readyQueue);
      }
      // Operations nested in the regions of `ready` are traversed as well.
      for (Region &region : ready->getRegions()) {
        for (Operation &nested : region.getOps())
          state->addToReadyQueue(&nested, subGraph, readyQueue);
      }
    }
  }

  // Record the requested operations in reverse visitation order; the caller
  // reverses `topologicalCounts` once more when building the final result.
  for (Operation *op : llvm::reverse(ops)) {
    if (state->toSort.count(op) > 0)
      state->topologicalCounts.push_back(op);
  }
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Returns the operations of `toSort` ordered by a traversal that is started
/// from every member of `toSort` in turn. An empty input is returned
/// unchanged.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;

  // A single state object is shared by all traversals so that operations
  // seen from an earlier root are not visited again.
  DFSState traversal(toSort);
  for (Operation *root : toSort) {
    assert(toSort.count(root) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(root, &traversal);
  }

  // The traversals accumulate their output back to front; flip it.
  SetVector<Operation *> sorted;
  for (Operation *op : llvm::reverse(traversal.topologicalCounts))
    sorted.insert(op);
  return sorted;
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
Reference in New Issue
Block a user