diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index bb8bc96686..10dca3f24e 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -1,5 +1,4 @@ # mixins add syntactic sugar to Tensor and UOp -import functools from typing import TypeAlias, TYPE_CHECKING, Self from tinygrad.uop import Ops from tinygrad.helpers import prod, argfix, flatten, dedup, make_tuple, ceildiv @@ -282,38 +281,38 @@ class MovementMixin: ``` """ - def parse_formula(formula: str): - tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() - lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")")) + def parse_side(s: str) -> tuple[list[str], list[tuple[int, int]]]: + """Parse one side of formula into (axis_names, dims) where dims are (start, end) index pairs for parens.""" + tokens = f" {s} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() + lparens, rparens = [i for i, tok in enumerate(tokens) if tok == "("], [i for i, tok in enumerate(tokens) if tok == ")"] pairs = list(zip(lparens, rparens)) assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" - return [name for name in tokens if name not in ("(", ")")], [(s - 2 * i, e - 1 - 2 * i) for i, (s, e) in enumerate(pairs)] + return [tok for tok in tokens if tok not in ("(", ")")], [(lp - 2*i, rp - 1 - 2*i) for i, (lp, rp) in enumerate(pairs)] assert formula.count("->") == 1, 'need exactly one "->" in formula' + (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_side, formula.split("->")) - (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->")) - - for name in sizes: - assert name in lhs, f"axis {name} is not used in transform" + for name in sizes: assert name in lhs, f"axis {name} is not used in transform" assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}" - for name in flatten((lhs, rhs)): - assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" + for name in lhs+rhs: assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}" assert lhs.count("...") <= 1, f"too many ellipses in {formula}" # resolve ellipsis if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims) - lhs, rhs = map(lambda l: l[: (i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1 :] if "..." in l else l, (lhs, rhs)) - unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims] - flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims] + lhs, rhs = map(lambda l: l[:(i := l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs)) + def newdims(side, s, e): return (s + (ell_len - 1 if "...0" in side[:s] else 0), e + (ell_len - 1 if "...0" in side[:e] else 0)) + unflatten_dims, flatten_dims = [newdims(lhs, s, e) for s, e in unflatten_dims], [newdims(rhs, s, e) for s, e in flatten_dims] - # apply movement ops in order unflatten -> permute -> flatten/unsqueeze - t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self) + # unflatten -> permute -> flatten + t = self + for start, end in unflatten_dims: t = t.unflatten(start, tuple(sizes.get(lhs[i], -1) for i in range(start, end))) for i, name in enumerate(lhs): - assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect" + if name in sizes: assert sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect" t = t.permute([lhs.index(name) for name in rhs]) - return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] < dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t) + for start, end in reversed(flatten_dims): t = t.flatten(start, end - 1) if start < end else t.unsqueeze(start) + return t # *** movement ops with expand ***