clean up rearrange (#13851)

This commit is contained in:
chenyu
2025-12-27 11:06:10 -05:00
committed by GitHub
parent f6c660f7fa
commit 0f74909ae9

View File

@@ -1,5 +1,4 @@
# mixins add syntactic sugar to Tensor and UOp
import functools
from typing import TypeAlias, TYPE_CHECKING, Self
from tinygrad.uop import Ops
from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv
@@ -282,38 +281,38 @@ class MovementMixin:
```
"""
def parse_formula(formula: str):
tokens = f" {formula} ".replace("", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
def parse_side(s: str) -> tuple[list[str], list[tuple[int, int]]]:
"""Parse one side of formula into (axis_names, dims) where dims are (start, end) index pairs for parens."""
tokens = f" {s} ".replace("", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split()
lparens, rparens = [i for i, tok in enumerate(tokens) if tok == "("], [i for i, tok in enumerate(tokens) if tok == ")"]
pairs = list(zip(lparens, rparens))
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)]
return [tok for tok in tokens if tok not in ("(", ")")], [(lp - 2*i, rp - 1 - 2*i) for i, (lp, rp) in enumerate(pairs)]
assert formula.count("->") == 1, 'need exactly one "->" in formula'
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_side, formula.split("->"))
(lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
for name in sizes:
assert name in lhs, f"axis {name} is not used in transform"
for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
for name in flatten((lhs, rhs)):
assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
for name in lhs+rhs: assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
# resolve ellipsis
if "..." in lhs:
ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs))
unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
lhs, rhs = map(lambda l: l[:(i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
def newdims(side, s, e): return (s + (ell_len - 1 if "...0" in side[:s] else 0), e + (ell_len - 1 if "...0" in side[:e] else 0))
unflatten_dims, flatten_dims = [newdims(lhs, s, e) for s, e in unflatten_dims], [newdims(rhs, s, e) for s, e in flatten_dims]
# apply movement ops in order unflatten -> permute -> flatten/unsqueeze
t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
# unflatten -> permute -> flatten
t = self
for start, end in unflatten_dims: t = t.unflatten(start, tuple(sizes.get(lhs[i], -1) for i in range(start, end)))
for i, name in enumerate(lhs):
assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
if name in sizes: assert sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
t = t.permute([lhs.index(name) for name in rhs])
return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
for start, end in reversed(flatten_dims): t = t.flatten(start, end - 1) if start < end else t.unsqueeze(start)
return t
# *** movement ops with expand ***