[FRONTEND] Adding static range (#1130)

Included: Revert "[BACKEND] Replace `mlir::topologicalSort` with a
custom implementation (#1113)"
This commit is contained in:
Philippe Tillet
2023-01-31 18:04:19 -08:00
committed by GitHub
parent be3da96919
commit 8fea1fb478
7 changed files with 52 additions and 89 deletions

View File

@@ -164,65 +164,4 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
dotOperandLayout.getParent() == mmaLayout;
}
namespace {
/// DFS post-order implementation that maintains a global count to work across
/// multiple invocations, to help implement topological sort on multi-root DAGs.
/// We traverse all operations but only record the ones that appear in
/// `toSort` for the final result.
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
};
void dfsPostorder(Operation *root, DFSState *state) {
SmallVector<Operation *> queue(1, root);
std::vector<Operation *> ops;
while (!queue.empty()) {
Operation *current = queue.pop_back_val();
if (!state->seen.insert(current).second)
continue;
ops.push_back(current);
for (Value result : current->getResults()) {
for (Operation *op : result.getUsers())
queue.push_back(op);
}
for (Region &region : current->getRegions()) {
for (Operation &op : region.getOps())
queue.push_back(&op);
}
}
for (Operation *op : llvm::reverse(ops)) {
if (state->toSort.count(op) > 0)
state->topologicalCounts.push_back(op);
}
}
} // namespace
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
if (toSort.empty()) {
return toSort;
}
// Run from each root with global count and `seen` set.
DFSState state(toSort);
for (auto *s : toSort) {
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
dfsPostorder(s, &state);
}
// Reorder and return.
SetVector<Operation *> res;
for (auto it = state.topologicalCounts.rbegin(),
eit = state.topologicalCounts.rend();
it != eit; ++it) {
res.insert(*it);
}
return res;
}
} // namespace mlir