mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Adding static range (#1130)
Included: Revert "[BACKEND] Replace `mlir::topologicalSort` with a custom implementation (#1113)"
This commit is contained in:
@@ -65,7 +65,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
|
||||
|
||||
template <typename Int> Int ceil(Int m, Int n) { return (m + n - 1) / n; }
|
||||
|
||||
/// output[i] = input[order[i]]
|
||||
// output[i] = input[order[i]]
|
||||
template <typename T, typename RES_T = T>
|
||||
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
size_t rank = order.size();
|
||||
@@ -80,14 +80,6 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
|
||||
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);
|
||||
|
||||
/// Multi-root DAG topological sort.
|
||||
/// Performs a topological sort of the Operation in the `toSort` SetVector.
|
||||
/// Returns a topologically sorted SetVector.
|
||||
/// It is faster than mlir::topologicalSort because it prunes nodes that have
|
||||
/// been visited before.
|
||||
SetVector<Operation *>
|
||||
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_ANALYSIS_UTILITY_H
|
||||
|
||||
@@ -164,65 +164,4 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
|
||||
dotOperandLayout.getParent() == mmaLayout;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// DFS post-order implementation that maintains a global count to work across
|
||||
/// multiple invocations, to help implement topological sort on multi-root DAGs.
|
||||
/// We traverse all operations but only record the ones that appear in
|
||||
/// `toSort` for the final result.
|
||||
struct DFSState {
|
||||
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
|
||||
const SetVector<Operation *> &toSort;
|
||||
SmallVector<Operation *, 16> topologicalCounts;
|
||||
DenseSet<Operation *> seen;
|
||||
};
|
||||
|
||||
void dfsPostorder(Operation *root, DFSState *state) {
|
||||
SmallVector<Operation *> queue(1, root);
|
||||
std::vector<Operation *> ops;
|
||||
while (!queue.empty()) {
|
||||
Operation *current = queue.pop_back_val();
|
||||
if (!state->seen.insert(current).second)
|
||||
continue;
|
||||
ops.push_back(current);
|
||||
for (Value result : current->getResults()) {
|
||||
for (Operation *op : result.getUsers())
|
||||
queue.push_back(op);
|
||||
}
|
||||
for (Region ®ion : current->getRegions()) {
|
||||
for (Operation &op : region.getOps())
|
||||
queue.push_back(&op);
|
||||
}
|
||||
}
|
||||
|
||||
for (Operation *op : llvm::reverse(ops)) {
|
||||
if (state->toSort.count(op) > 0)
|
||||
state->topologicalCounts.push_back(op);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SetVector<Operation *>
|
||||
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
|
||||
if (toSort.empty()) {
|
||||
return toSort;
|
||||
}
|
||||
|
||||
// Run from each root with global count and `seen` set.
|
||||
DFSState state(toSort);
|
||||
for (auto *s : toSort) {
|
||||
assert(toSort.count(s) == 1 && "NYI: multi-sets not supported");
|
||||
dfsPostorder(s, &state);
|
||||
}
|
||||
|
||||
// Reorder and return.
|
||||
SetVector<Operation *> res;
|
||||
for (auto it = state.topologicalCounts.rbegin(),
|
||||
eit = state.topologicalCounts.rend();
|
||||
it != eit; ++it) {
|
||||
res.insert(*it);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
@@ -652,7 +652,7 @@ public:
|
||||
else
|
||||
sortedValues.push_back(v);
|
||||
}
|
||||
tmp = mlir::multiRootTopologicalSort(tmp);
|
||||
tmp = mlir::topologicalSort(tmp);
|
||||
for (Operation *op : tmp)
|
||||
sortedValues.push_back(op->getResult(0));
|
||||
|
||||
|
||||
@@ -625,29 +625,29 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return [self.visit(dim) for dim in node.dims]
|
||||
|
||||
def visit_For(self, node):
|
||||
iterator = self.visit(node.iter.func)
|
||||
if iterator != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` iterator currently supported')
|
||||
IteratorClass = self.visit(node.iter.func)
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
if IteratorClass == triton.language.static_range:
|
||||
iterator = IteratorClass(*iter_args)
|
||||
static_range = range(iterator.start.value,
|
||||
iterator.end.value,
|
||||
iterator.step.value)
|
||||
for i in static_range:
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
|
||||
if IteratorClass != self.builtins['range']:
|
||||
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
|
||||
|
||||
# visit iterator arguments
|
||||
# note: only `range` iterator is supported now
|
||||
iter_args = [self.visit(arg) for arg in node.iter.args]
|
||||
# collect lower bound (lb), upper bound (ub), and step
|
||||
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
|
||||
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
|
||||
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
|
||||
# static for loops: all iterator arguments are constexpr
|
||||
if isinstance(lb, triton.language.constexpr) and \
|
||||
isinstance(ub, triton.language.constexpr) and \
|
||||
isinstance(step, triton.language.constexpr):
|
||||
sta_range = iterator(lb.value, ub.value, step.value)
|
||||
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
|
||||
if static_unrolling and len(sta_range) <= 10:
|
||||
for i in sta_range:
|
||||
self.lscope[node.target.id] = triton.language.constexpr(i)
|
||||
self.visit_compound_statement(node.body)
|
||||
for stmt in node.orelse:
|
||||
ast.NodeVisitor.generic_visit(self, stmt)
|
||||
return
|
||||
# handle negative constant step (not supported by scf.for in MLIR)
|
||||
negative_step = False
|
||||
if isinstance(step, triton.language.constexpr) and step.value < 0:
|
||||
|
||||
@@ -65,6 +65,7 @@ from .core import (
|
||||
store,
|
||||
sum,
|
||||
swizzle2d,
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
triton,
|
||||
@@ -162,6 +163,7 @@ __all__ = [
|
||||
"sin",
|
||||
"softmax",
|
||||
"sqrt",
|
||||
"static_range",
|
||||
"store",
|
||||
"sum",
|
||||
"swizzle2d",
|
||||
|
||||
@@ -1307,3 +1307,33 @@ def printf(prefix, *args, _builder=None):
|
||||
for arg in args:
|
||||
new_args.append(_to_tensor(arg, _builder))
|
||||
return semantic.printf(new_prefix, new_args, _builder)
|
||||
|
||||
# -----------------------
|
||||
# Iterators
|
||||
# -----------------------
|
||||
|
||||
|
||||
class static_range:
|
||||
|
||||
"""Iterator that counts upward forever."""
|
||||
|
||||
def __init__(self, arg1, arg2=None, step=None):
|
||||
assert isinstance(arg1, constexpr)
|
||||
if step is None:
|
||||
self.step = constexpr(1)
|
||||
else:
|
||||
assert isinstance(step, constexpr)
|
||||
self.step = step
|
||||
if arg2 is None:
|
||||
self.start = constexpr(0)
|
||||
self.end = arg1
|
||||
else:
|
||||
assert isinstance(arg2, constexpr)
|
||||
self.start = arg1
|
||||
self.end = arg2
|
||||
|
||||
def __iter__(self):
|
||||
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
|
||||
|
||||
def __next__(self):
|
||||
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
|
||||
|
||||
@@ -17,7 +17,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
|
||||
"""
|
||||
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
|
||||
"""
|
||||
for _ in range(n_rounds):
|
||||
for _ in tl.static_range(n_rounds):
|
||||
# update random state
|
||||
A = PHILOX_ROUND_A
|
||||
B = PHILOX_ROUND_B
|
||||
|
||||
Reference in New Issue
Block a user