From bcfe42937fc14897a2bfb15a69b202d3dfb5f161 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:14:15 -0800 Subject: [PATCH] move permute/flip/shrink to mixins (#13113) * move permute to mixins * move more stuff * two more * fix local mypy * fix tests * fix shrink --- extra/sqtt/roc.py | 2 +- test/unit/test_indexing.py | 8 +- tinygrad/gradient.py | 2 +- tinygrad/mixin/movement.py | 192 +++++++++++++++++++++++++++++++++++-- tinygrad/schedule/multi.py | 2 +- tinygrad/tensor.py | 175 +-------------------------------- tinygrad/uop/ops.py | 9 +- 7 files changed, 198 insertions(+), 192 deletions(-) diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index aac9a6194c..989e5d9594 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -57,7 +57,7 @@ class _ROCParseCtx: self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {} for prog in prog_evs: - arch = "gfx%d%x%x" % ((trgt:=dev_evs[prog.device].props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100) + arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100) for addr, info in llvm_disasm(arch, unwrap(prog.lib)).items(): self.disasms[unwrap(prog.base) + addr] = info self.addr2prg[unwrap(prog.base) + addr] = prog diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index 32bda7a415..9fad1cd381 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -894,7 +894,7 @@ class TestNumpy(unittest.TestCase): a = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - self.assertIsNot(a[...], a) + self.assertIs(a[...], a) numpy_testing_assert_equal_helper(a[...], a) # `a[...]` was `a` in numpy <1.9. #numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a)) @@ -1037,9 +1037,9 @@ class TestNumpy(unittest.TestCase): # Before `...` would return a itself. a = Tensor([5]) - self.assertIsNot(a, a[()]) - self.assertIsNot(a, a[...]) - self.assertIsNot(a, a[:]) + self.assertIs(a, a[()]) + self.assertIs(a, a[...]) + self.assertIs(a, a[:]) def test_broaderrors_indexing(self): a = Tensor.zeros(5, 5) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 9117b3ac17..23e6e5dce0 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)), (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)), - (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)), + (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)), (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src), # NOTE: this is only correct when the KERNEL has a single output (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)), diff --git a/tinygrad/mixin/movement.py b/tinygrad/mixin/movement.py index c6b9fba19b..faecfee2a1 100644 --- a/tinygrad/mixin/movement.py +++ b/tinygrad/mixin/movement.py @@ -1,7 +1,8 @@ # mixins add syntactic sugar to Tensor and UOp +import functools from typing import TypeAlias, TYPE_CHECKING, Self from tinygrad.uop import Ops -from tinygrad.helpers import prod, argfix +from tinygrad.helpers import prod, argfix, flatten, dedup if TYPE_CHECKING: from tinygrad.uop.ops import UOp sint:TypeAlias = UOp|int @@ -41,10 +42,6 @@ class MovementMixin: if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") return dim + total if dim < 0 else dim - def view(self, shape, *args) -> Self: - """`.view` is an alias for `.reshape`.""" - return self.reshape(shape, *args) - def reshape(self, shape, *args) -> Self: """ Returns a tensor with the same data as the original tensor but with a different shape. @@ -61,7 +58,131 @@ class MovementMixin: if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}") if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape]) if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})") - return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self + ret = self._mop(Ops.RESHAPE, arg=new_shape) + return self if ret.shape == self.shape else ret + + def shrink(self, arg:tuple[tuple["sint", "sint"]|None, ...]) -> Self: + """ + Returns a tensor that shrinks the each axis based on input arg. + `arg` must have the same length as `self.ndim`. + For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(9).reshape(3, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.shrink(((None, (1, 3)))).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.shrink((((0, 2), (0, 2)))).numpy()) + ``` + """ + if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}") + ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) + return self if ret.shape == self.shape else ret + + def permute(self, order, *args) -> Self: + """ + Returns a tensor that is a permutation of the original tensor. + The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. + `order` can be passed as a tuple or as separate arguments. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.empty(2, 3, 5) + print(t.shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.permute(2, 0, 1).shape) + ``` + """ + order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) + if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") + return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self + + def flip(self, axis, *args) -> Self: + """ + Returns a tensor that reverses the order of the original tensor along given `axis`. + `axis` can be passed as a tuple or as separate arguments. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(6).reshape(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.flip(0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.flip((0, 1)).numpy()) + ``` + """ + axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) + if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") + flip_arg = tuple([i in axis_arg for i in range(len(self.shape))]) + return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self + + # **** high level **** + + def view(self, shape, *args) -> Self: + """`.view` is an alias for `.reshape`.""" + return self.reshape(shape, *args) + + def squeeze(self, dim:int|None=None) -> Self: + """ + Returns a tensor with specified dimensions of input of size 1 removed. + If `dim` is not specified, all dimensions with size 1 are removed. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.zeros(2, 1, 2, 1, 2) + print(t.squeeze().shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.squeeze(0).shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.squeeze(1).shape) + ``` + """ + if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1)) + dim = self._resolve_dim(dim) + return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:]) + + def unsqueeze(self, dim:int) -> Self: + """ + Returns a tensor with a new dimension of size 1 inserted at the specified `dim`. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor([1, 2, 3, 4]) + print(t.unsqueeze(0).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.unsqueeze(1).numpy()) + ``` + """ + dim = self._resolve_dim(dim, extra=True) + return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) + + @property + def T(self) -> Self: + """`.T` is an alias for `.transpose()`.""" + return self.transpose() + + def transpose(self, dim0=1, dim1=0) -> Self: + """ + Returns a tensor that is a transposed version of the original tensor. + The given dimensions `dim0` and `dim1` are swapped. + + ```python exec="true" source="above" session="tensor" result="python" + t = Tensor.arange(6).reshape(2, 3) + print(t.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.transpose(0, 1).numpy()) + ``` + """ + order = list(range(self.ndim)) + order[dim0], order[dim1] = order[dim1], order[dim0] + return self.permute(order) def flatten(self, start_dim=0, end_dim=-1) -> Self: """ @@ -77,4 +198,61 @@ class MovementMixin: ``` """ start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim) - return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) \ No newline at end of file + return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:]) + + def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self: + """ + Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function. + + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape) + ``` + """ + dim = self._resolve_dim(dim) + return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) + + def rearrange(self, formula:str, **sizes) -> Self: + """ + Rearranges input according to formula + + See: https://einops.rocks/api/rearrange/ + + ```python exec="true" source="above" session="tensor" result="python" + x = Tensor([[1, 2], [3, 4]]) + print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy()) + ``` + """ + def parse_formula(formula: str): + tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() + lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")")) + pairs = list(zip(lparens, rparens)) + assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" + return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] + + assert formula.count("->") == 1, 'need exactly one "->" in formula' + + (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->")) + + for name in sizes: assert name in lhs, f"axis {name} is not used in transform" + assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}" + for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" + assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}" + assert lhs.count("...") <= 1, f"too many ellipses in {formula}" + + # resolve ellipsis + if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims) + lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs)) + unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims] + flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims] + + # apply movement ops in order unflatten -> permute -> flatten/unsqueeze + t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self) + for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect" + t = t.permute([lhs.index(name) for name in rhs]) + return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] UOp|None: # allgather copied_chunks = [] for i,c in enumerate(reduced_chunks): - this_chunk = [None] * len(buf.device) + this_chunk: list[UOp|None] = [None] * len(buf.device) this_chunk[(i+len(buf.device)-1)%n_lbs] = c for step in range(n_lbs-1): dest = (i+step)%n_lbs diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index c49c510f21..da422b4d34 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -5,7 +5,7 @@ from contextlib import ContextDecorator from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype -from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup +from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC from tinygrad.helpers import suppress_finalizing from tinygrad.gradient import compute_gradient @@ -1055,65 +1055,6 @@ class Tensor(OpMixin): new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args))))) return self._broadcast_to(new_shape) - def permute(self, order, *args) -> Tensor: - """ - Returns a tensor that is a permutation of the original tensor. - The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified. - `order` can be passed as a tuple or as separate arguments. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.empty(2, 3, 5) - print(t.shape) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.permute(2, 0, 1).shape) - ``` - """ - order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args)) - if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}") - return self._apply_uop(UOp.permute, arg=order_arg) - - def flip(self, axis, *args) -> Tensor: - """ - Returns a tensor that reverses the order of the original tensor along given `axis`. - `axis` can be passed as a tuple or as separate arguments. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(6).reshape(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.flip(0).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.flip((0, 1)).numpy()) - ``` - """ - axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args)) - if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}") - return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))])) - - def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor: - """ - Returns a tensor that shrinks the each axis based on input arg. - `arg` must have the same length as `self.ndim`. - For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(9).reshape(3, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.shrink(((None, (1, 3)))).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.shrink((((0, 2), (0, 2)))).numpy()) - ``` - """ - if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}") - if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self - return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg)) - def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor: """ Returns a tensor with padding applied based on the input `padding`. @@ -1535,80 +1476,6 @@ class Tensor(OpMixin): output_shape = _broadcast_shape(*(t.shape for t in tensors)) return tuple(t._broadcast_to(output_shape) for t in tensors) - def squeeze(self, dim:int|None=None) -> Tensor: - """ - Returns a tensor with specified dimensions of input of size 1 removed. - If `dim` is not specified, all dimensions with size 1 are removed. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.zeros(2, 1, 2, 1, 2) - print(t.squeeze().shape) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.squeeze(0).shape) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.squeeze(1).shape) - ``` - """ - if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1)) - dim = self._resolve_dim(dim) - return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:]) - - def unsqueeze(self, dim:int) -> Tensor: - """ - Returns a tensor with a new dimension of size 1 inserted at the specified `dim`. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor([1, 2, 3, 4]) - print(t.unsqueeze(0).numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.unsqueeze(1).numpy()) - ``` - """ - dim = self._resolve_dim(dim, extra=True) - return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) - - @property - def T(self) -> Tensor: - """`.T` is an alias for `.transpose()`.""" - return self.transpose() - - def transpose(self, dim0=1, dim1=0) -> Tensor: - """ - Returns a tensor that is a transposed version of the original tensor. - The given dimensions `dim0` and `dim1` are swapped. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(6).reshape(2, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.transpose(0, 1).numpy()) - ``` - """ - order = list(range(self.ndim)) - order[dim0], order[dim1] = order[dim1], order[dim0] - return self.permute(order) - - def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor: - """ - Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function. - - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape) - ``` - """ - dim = self._resolve_dim(dim) - return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:]) - def diag(self) -> Tensor: """ Returns a 2-D square tensor with the elements of input as the main diagonal. @@ -1654,46 +1521,6 @@ class Tensor(OpMixin): for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim]) return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices] - def rearrange(self, formula:str, **sizes) -> Tensor: - """ - Rearranges input according to formula - - See: https://einops.rocks/api/rearrange/ - - ```python exec="true" source="above" session="tensor" result="python" - x = Tensor([[1, 2], [3, 4]]) - print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy()) - ``` - """ - def parse_formula(formula: str): - tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(" ", " ").replace(" 1 ", " ( ) ").split() - lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")")) - pairs = list(zip(lparens, rparens)) - assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" - return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] - - assert formula.count("->") == 1, 'need exactly one "->" in formula' - - (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->")) - - for name in sizes: assert name in lhs, f"axis {name} is not used in transform" - assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}" - for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}" - assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}" - assert lhs.count("...") <= 1, f"too many ellipses in {formula}" - - # resolve ellipsis - if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims) - lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs)) - unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims] - flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims] - - # apply movement ops in order unflatten -> permute -> flatten/unsqueeze - t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self) - for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect" - t = t.permute([lhs.index(name) for name in rhs]) - return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]