[BACKEND] Fix topological sort and add new test cases (#1132)

Previous https://github.com/openai/triton/pull/1113 forgot to consider
that a node may have multiple parents, visiting the instruction before
any parent violates the semantic of topological sort.

The fixed implementation exhaustively add all operations into a
candidate subgraph and move an operation to the "ready" queue once all
of its operands have been visited.
This commit is contained in:
Keren Zhou
2023-01-31 23:41:20 -08:00
committed by GitHub
parent fc846e5e1e
commit 5dd8ce3745
6 changed files with 632 additions and 12 deletions

View File

@@ -2,6 +2,7 @@
#include "mlir/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include <deque>
namespace mlir {
@@ -164,4 +165,126 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
dotOperandLayout.getParent() == mmaLayout;
}
namespace {
/// A data structure similar to SetVector but maintains
/// a deque instead of a vector to allow for efficient
/// push_back and pop_front operations.
/// Using SetVector doesn't suffice our needs because
/// it only pushes and pops from the back.
/// For example, if we have a queue like this:
/// 0->4 1->2->3
/// ^--------
/// where 3 depends on 4, once we pop 3, we found
/// 4 is not ready, so we check 2 and push 3 back
/// to the queue.
struct DFSSubgraphState {
DFSSubgraphState() : set(), deque() {}
DenseSet<Operation *> set;
std::deque<Operation *> deque;
bool push_back(Operation *op) {
if (set.insert(op).second) {
deque.push_back(op);
return true;
}
return false;
}
Operation *pop_front() {
Operation *op = deque.front();
deque.pop_front();
set.erase(op);
return op;
}
bool empty() { return deque.empty(); }
};
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all operations but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
/// We mark each op as ready if all its operands are seen. If an op is ready,
/// we add it to the queue. Otherwise, we keep adding its operands to the
/// ancestors set.
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
for (Value operand : op->getOperands()) {
auto def = operand.getDefiningOp();
if (def && !seen.count(def)) {
subGraph.push_back(def);
ready = false;
}
}
if (ready)
readyQueue.push_back(op);
}
};
void dfsPostorder(Operation *root, DFSState *state) {
DFSSubgraphState subGraph;
subGraph.push_back(root);
SmallVector<Operation *> ops;
while (!subGraph.empty()) {
// Nodes in the ready queue are ready to be processed.
// Meaning that either their operands are all seen or they have null
// operands.
SmallVector<Operation *, 4> readyQueue;
auto *current = subGraph.pop_front();
state->addToReadyQueue(current, subGraph, readyQueue);
while (!readyQueue.empty()) {
Operation *current = readyQueue.pop_back_val();
if (!state->seen.insert(current).second)
continue;
ops.push_back(current);
for (Value result : current->getResults()) {
for (Operation *op : result.getUsers())
state->addToReadyQueue(op, subGraph, readyQueue);
}
for (Region &region : current->getRegions()) {
for (Operation &op : region.getOps())
state->addToReadyQueue(&op, subGraph, readyQueue);
}
}
}
for (Operation *op : llvm::reverse(ops)) {
if (state->toSort.count(op) > 0)
state->topologicalCounts.push_back(op);
}
}
} // namespace
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
if (toSort.empty()) {
return toSort;
}
// Run from each root with global count and `seen` set.
DFSState state(toSort);
for (auto *s : toSort) {
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
dfsPostorder(s, &state);
}
// Reorder and return.
SetVector<Operation *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {
res.insert(*it);
}
return res;
}
} // namespace mlir