[FRONTEND] Adding static range (#1130)

Included: Revert "[BACKEND] Replace `mlir::topologicalSort` with a
custom implementation (#1113)"
This commit is contained in:
Philippe Tillet
2023-01-31 18:04:19 -08:00
committed by GitHub
parent be3da96919
commit 8fea1fb478
7 changed files with 52 additions and 89 deletions

View File

@@ -65,7 +65,7 @@ template <typename Int> Int product(llvm::ArrayRef<Int> arr) {
/// Integer division of `m` by `n` rounded up, evaluated as (m + n - 1) / n.
/// This matches the mathematical ceiling for non-negative `m` and positive `n`.
template <typename Int> Int ceil(Int m, Int n) {
  // `auto` keeps whatever type the arithmetic promotes to, so the division
  // sees the same operand as the single-expression form would.
  const auto biased = m + n - 1;
  return biased / n;
}
/// output[i] = input[order[i]]
// output[i] = input[order[i]]
template <typename T, typename RES_T = T>
SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
size_t rank = order.size();
@@ -80,14 +80,6 @@ SmallVector<RES_T> reorder(ArrayRef<T> input, ArrayRef<unsigned> order) {
bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
triton::gpu::DotOperandEncodingAttr &dotOperandLayout);
/// Multi-root DAG topological sort.
/// Performs a topological sort of the Operations in the `toSort` SetVector,
/// treating every element of `toSort` as a potential root of the traversal.
/// Returns a new SetVector holding the sorted operations; `toSort` itself is
/// not modified.
/// It is faster than mlir::topologicalSort because it prunes nodes that have
/// been visited before.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H

View File

@@ -164,65 +164,4 @@ bool isMmaToDotShortcut(triton::gpu::MmaEncodingAttr &mmaLayout,
dotOperandLayout.getParent() == mmaLayout;
}
namespace {
/// Book-keeping shared by successive `dfsPostorder` calls so that one sort can
/// be assembled from several roots of a multi-root DAG.
/// Every reachable operation is traversed, but only the operations that belong
/// to `toSort` are recorded for the final result.
struct DFSState {
  DFSState(const SetVector<Operation *> &set) : toSort(set) {}

  /// Operations the caller asked to have sorted. Held by reference, so the
  /// referenced set must outlive this state object.
  const SetVector<Operation *> &toSort;
  /// Members of `toSort` in the order the traversals emitted them.
  SmallVector<Operation *, 16> topologicalCounts;
  /// Operations already visited by any traversal using this state.
  DenseSet<Operation *> seen;
};
/// Traverses, with an explicit stack, everything reachable from `root` through
/// result users and through operations nested in regions. Operations already
/// present in `state->seen` are skipped. The operations newly visited by this
/// call are then walked in reverse visit order, and those that are members of
/// `state->toSort` are appended to `state->topologicalCounts`.
///
/// NOTE(review): the recorded order is the order in which operations are
/// popped from the stack, not a true post-order -- confirm that this is
/// sufficient for every DAG shape the callers pass in.
void dfsPostorder(Operation *root, DFSState *state) {
  SmallVector<Operation *> worklist(1, root);
  std::vector<Operation *> visited;

  while (!worklist.empty()) {
    Operation *op = worklist.pop_back_val();

    // `insert` reports whether the operation was new to the shared `seen` set.
    const bool firstVisit = state->seen.insert(op).second;
    if (!firstVisit)
      continue;
    visited.push_back(op);

    // Follow data-flow edges: every user of every result.
    for (Value result : op->getResults())
      for (Operation *user : result.getUsers())
        worklist.push_back(user);

    // Descend into the operations held by this operation's regions.
    for (Region &region : op->getRegions())
      for (Operation &nested : region.getOps())
        worklist.push_back(&nested);
  }

  // Emit in reverse visit order, keeping only what the caller asked to sort.
  for (Operation *op : llvm::reverse(visited))
    if (state->toSort.count(op) != 0)
      state->topologicalCounts.push_back(op);
}
} // namespace
/// Sorts the operations in `toSort` by running `dfsPostorder` from each element
/// in turn against one shared `DFSState`, then reversing the accumulated
/// sequence. An empty input is returned unchanged.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;

  // One state object for all roots: the `seen` set carries over, so an
  // operation reached from an earlier root is not traversed again.
  DFSState state(toSort);
  for (Operation *root : toSort) {
    assert(toSort.count(root) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(root, &state);
  }

  // The traversals appended in reverse visit order; undo that on the way out.
  SetVector<Operation *> sorted;
  sorted.insert(state.topologicalCounts.rbegin(),
                state.topologicalCounts.rend());
  return sorted;
}
} // namespace mlir

View File

@@ -652,7 +652,7 @@ public:
else
sortedValues.push_back(v);
}
tmp = mlir::multiRootTopologicalSort(tmp);
tmp = mlir::topologicalSort(tmp);
for (Operation *op : tmp)
sortedValues.push_back(op->getResult(0));

View File

@@ -625,29 +625,29 @@ class CodeGenerator(ast.NodeVisitor):
return [self.visit(dim) for dim in node.dims]
def visit_For(self, node):
iterator = self.visit(node.iter.func)
if iterator != self.builtins['range']:
raise RuntimeError('Only `range` iterator currently supported')
IteratorClass = self.visit(node.iter.func)
iter_args = [self.visit(arg) for arg in node.iter.args]
if IteratorClass == triton.language.static_range:
iterator = IteratorClass(*iter_args)
static_range = range(iterator.start.value,
iterator.end.value,
iterator.step.value)
for i in static_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
if IteratorClass != self.builtins['range']:
raise RuntimeError('Only `range` and `static_range` iterators are currently supported')
# visit iterator arguments
# note: only `range` iterator is supported now
iter_args = [self.visit(arg) for arg in node.iter.args]
# collect lower bound (lb), upper bound (ub), and step
lb = iter_args[0] if len(iter_args) > 1 else self.visit(ast.Num(0))
ub = iter_args[1] if len(iter_args) > 1 else self.visit(node.iter.args[0])
step = iter_args[2] if len(iter_args) > 2 else self.visit(ast.Num(1))
# static for loops: all iterator arguments are constexpr
if isinstance(lb, triton.language.constexpr) and \
isinstance(ub, triton.language.constexpr) and \
isinstance(step, triton.language.constexpr):
sta_range = iterator(lb.value, ub.value, step.value)
static_unrolling = os.environ.get('TRITON_STATIC_LOOP_UNROLLING', False)
if static_unrolling and len(sta_range) <= 10:
for i in sta_range:
self.lscope[node.target.id] = triton.language.constexpr(i)
self.visit_compound_statement(node.body)
for stmt in node.orelse:
ast.NodeVisitor.generic_visit(self, stmt)
return
# handle negative constant step (not supported by scf.for in MLIR)
negative_step = False
if isinstance(step, triton.language.constexpr) and step.value < 0:

View File

@@ -65,6 +65,7 @@ from .core import (
store,
sum,
swizzle2d,
static_range,
tensor,
trans,
triton,
@@ -162,6 +163,7 @@ __all__ = [
"sin",
"softmax",
"sqrt",
"static_range",
"store",
"sum",
"swizzle2d",

View File

@@ -1307,3 +1307,33 @@ def printf(prefix, *args, _builder=None):
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
# -----------------------
# Iterators
# -----------------------
class static_range:
    """Compile-time loop range for use inside `@triton.jit` functions.

    Mirrors the call shapes of the builtin `range`:
    `static_range(end)`, `static_range(start, end)` and
    `static_range(start, end, step)`. Every argument must be a `constexpr`.
    The object only stores the bounds as `start`, `end` and `step`; it cannot
    be iterated from regular Python code.

    :param arg1: the end bound when `arg2` is omitted, otherwise the start bound
    :param arg2: the end bound, or None for the single-bound form
    :param step: the increment, or None for a step of `constexpr(1)`
    :raises AssertionError: if any supplied argument is not a `constexpr`
    """

    def __init__(self, arg1, arg2=None, step=None):
        assert isinstance(arg1, constexpr), \
            "static_range bounds must be constexpr"
        if step is None:
            self.step = constexpr(1)
        else:
            assert isinstance(step, constexpr), \
                "static_range step must be constexpr"
            self.step = step
        if arg2 is None:
            # Single-bound form: count from zero up to `arg1`.
            self.start = constexpr(0)
            self.end = arg1
        else:
            assert isinstance(arg2, constexpr), \
                "static_range bounds must be constexpr"
            self.start = arg1
            self.end = arg2

    def __iter__(self):
        raise RuntimeError("static_range can only be used in @triton.jit'd functions")

    def __next__(self):
        raise RuntimeError("static_range can only be used in @triton.jit'd functions")

View File

@@ -17,7 +17,7 @@ def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds: tl.constexpr = N_ROUNDS_DEFAUL
"""
Run `n_rounds` rounds of Philox for state (c0, c1, c2, c3) and key (k0, k1).
"""
for _ in range(n_rounds):
for _ in tl.static_range(n_rounds):
# update random state
A = PHILOX_ROUND_A
B = PHILOX_ROUND_B