Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
move permute/flip/shrink to mixins (#13113)
* move permute to mixins
* move more stuff
* two more
* fix local mypy
* fix tests
* fix shrink
@@ -57,7 +57,7 @@ class _ROCParseCtx:
     self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {}

     for prog in prog_evs:
-      arch = "gfx%d%x%x" % ((trgt:=dev_evs[prog.device].props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
+      arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
       for addr, info in llvm_disasm(arch, unwrap(prog.lib)).items():
         self.disasms[unwrap(prog.base) + addr] = info
         self.addr2prg[unwrap(prog.base) + addr] = prog
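Note: the only change in this hunk is wrapping `props` in `unwrap` for the type checker; the arch-string math is untouched. As a sanity check, the format string can be evaluated standalone (the target-version values below are illustrative, not taken from this diff):

```python
# standalone sketch of the "gfx%d%x%x" encoding above
for trgt in (110000, 90010):  # e.g. gfx1100 and gfx90a target versions (illustrative)
    arch = "gfx%d%x%x" % (trgt // 10000, (trgt // 100) % 100, trgt % 100)
    print(trgt, "->", arch)  # 110000 -> gfx1100, 90010 -> gfx90a (minor/step render as hex)
```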
@@ -894,7 +894,7 @@ class TestNumpy(unittest.TestCase):
     a = Tensor([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
-    self.assertIsNot(a[...], a)
+    self.assertIs(a[...], a)
     numpy_testing_assert_equal_helper(a[...], a)
     # `a[...]` was `a` in numpy <1.9.
     #numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
@@ -1037,9 +1037,9 @@ class TestNumpy(unittest.TestCase):
     # Before `...` would return a itself.
     a = Tensor([5])

-    self.assertIsNot(a, a[()])
-    self.assertIsNot(a, a[...])
-    self.assertIsNot(a, a[:])
+    self.assertIs(a, a[()])
+    self.assertIs(a, a[...])
+    self.assertIs(a, a[:])

   def test_broaderrors_indexing(self):
     a = Tensor.zeros(5, 5)
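Note: both test hunks encode the same behavior change: full-view indexing (`a[...]`, `a[()]`, `a[:]`) now returns `self` instead of a fresh tensor, a consequence of the no-op checks in the movement mixin. A minimal check, assuming a local tinygrad install:

```python
from tinygrad import Tensor

a = Tensor([5])
# per the updated assertions, full-view indexing is now a no-op returning self
assert a[...] is a and a[()] is a and a[:] is a
```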
@@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
-  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
+  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
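Note: this rule changes because the mixin `flip` below takes axis indices, while the stored `marg` for `Ops.FLIP` is a per-dimension bool mask; the lambda converts the mask back to indices (gradient of a flip is a flip along the same axes). The conversion by itself:

```python
# mask -> axis indices, as done in the new gradient rule
marg = (False, True, True)  # FLIP marg: one bool per dimension
axes = [i for i, x in enumerate(marg) if x]
print(axes)  # [1, 2]
```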
@@ -1,7 +1,8 @@
 # mixins add syntactic sugar to Tensor and UOp
 import functools
 from typing import TypeAlias, TYPE_CHECKING, Self
-from tinygrad.helpers import prod, argfix
+from tinygrad.uop import Ops
+from tinygrad.helpers import prod, argfix, flatten, dedup
 if TYPE_CHECKING:
   from tinygrad.uop.ops import UOp
   sint:TypeAlias = UOp|int
@@ -41,10 +42,6 @@ class MovementMixin:
     if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
     return dim + total if dim < 0 else dim

-  def view(self, shape, *args) -> Self:
-    """`.view` is an alias for `.reshape`."""
-    return self.reshape(shape, *args)
-
   def reshape(self, shape, *args) -> Self:
     """
     Returns a tensor with the same data as the original tensor but with a different shape.
@@ -61,7 +58,131 @@ class MovementMixin:
     if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
     if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
-    return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
+    ret = self._mop(Ops.RESHAPE, arg=new_shape)
+    return self if ret.shape == self.shape else ret
+
+  def shrink(self, arg:tuple[tuple["sint", "sint"]|None, ...]) -> Self:
+    """
+    Returns a tensor that shrinks each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(3, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink(((None, (1, 3)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink((((0, 2), (0, 2)))).numpy())
+    ```
+    """
+    if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
+    ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)])
+    return self if ret.shape == self.shape else ret
+
+  def permute(self, order, *args) -> Self:
+    """
+    Returns a tensor that is a permutation of the original tensor.
+    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
+    `order` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3, 5)
+    print(t.shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.permute(2, 0, 1).shape)
+    ```
+    """
+    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
+    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+    return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
+
+  def flip(self, axis, *args) -> Self:
+    """
+    Returns a tensor that reverses the order of the original tensor along given `axis`.
+    `axis` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip((0, 1)).numpy())
+    ```
+    """
+    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
+    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
+    flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
+    return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
+
+  # **** high level ****
+
+  def view(self, shape, *args) -> Self:
+    """`.view` is an alias for `.reshape`."""
+    return self.reshape(shape, *args)
+
+  def squeeze(self, dim:int|None=None) -> Self:
+    """
+    Returns a tensor with specified dimensions of input of size 1 removed.
+    If `dim` is not specified, all dimensions with size 1 are removed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.zeros(2, 1, 2, 1, 2)
+    print(t.squeeze().shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(0).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(1).shape)
+    ```
+    """
+    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
+    dim = self._resolve_dim(dim)
+    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
+
+  def unsqueeze(self, dim:int) -> Self:
+    """
+    Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(t.unsqueeze(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.unsqueeze(1).numpy())
+    ```
+    """
+    dim = self._resolve_dim(dim, extra=True)
+    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
+
+  @property
+  def T(self) -> Self:
+    """`.T` is an alias for `.transpose()`."""
+    return self.transpose()
+
+  def transpose(self, dim0=1, dim1=0) -> Self:
+    """
+    Returns a tensor that is a transposed version of the original tensor.
+    The given dimensions `dim0` and `dim1` are swapped.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.transpose(0, 1).numpy())
+    ```
+    """
+    order = list(range(self.ndim))
+    order[dim0], order[dim1] = order[dim1], order[dim0]
+    return self.permute(order)
+
   def flatten(self, start_dim=0, end_dim=-1) -> Self:
     """
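Note: reshape's no-op check moves from comparing the requested arg against `self.shape` to comparing the *result* shape (same for shrink), while permute/flip keep explicit arg checks since their args aren't shapes. The observable behavior, as a minimal sketch assuming a local tinygrad install:

```python
from tinygrad import Tensor

t = Tensor.empty(2, 3)
assert t.reshape(2, 3) is t         # result shape unchanged -> returns self
assert t.permute(0, 1) is t         # identity permutation -> returns self
assert t.flip(()) is t              # no axes flipped -> returns self
assert t.shrink((None, None)) is t  # full-range shrink -> returns self
```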
@@ -77,4 +198,61 @@ class MovementMixin:
     ```
     """
     start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
     return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
+
+  def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self:
+    """
+    Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
+    ```
+    """
+    dim = self._resolve_dim(dim)
+    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
+
+  def rearrange(self, formula:str, **sizes) -> Self:
+    """
+    Rearranges input according to formula
+
+    See: https://einops.rocks/api/rearrange/
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    x = Tensor([[1, 2], [3, 4]])
+    print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
+    ```
+    """
+    def parse_formula(formula: str):
+      tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(",", " ").replace(" 1 ", " ( ) ").split()
+      lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
+      pairs = list(zip(lparens, rparens))
+      assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
+      return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
+
+    assert formula.count("->") == 1, 'need exactly one "->" in formula'
+
+    (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
+
+    for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
+    assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
+    for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
+    assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
+    assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
+
+    # resolve ellipsis
+    if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
+    lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
+    unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
+    flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
+
+    # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
+    t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
+    for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
+    t = t.permute([lhs.index(name) for name in rhs])
+    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
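Note: `parse_formula` tokenizes the einops-style formula and reports each paren group as a (start, end) span into the flat name list. A simplified standalone sketch (ellipsis and " 1 " handling dropped, `flatten` inlined from tinygrad.helpers):

```python
# simplified sketch of the parse_formula helper above
def flatten(l): return [item for sub in l for item in sub]

def parse_formula(formula: str):
  tokens = f" {formula} ".replace("(", " ( ").replace(")", " ) ").split()
  lparens, rparens = ([i for i, ch in enumerate(tokens) if ch == x] for x in ("(", ")"))
  pairs = list(zip(lparens, rparens))
  assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
  # names with parens stripped, plus (start, end) spans of each group in that name list
  return [n for n in tokens if n not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]

print(parse_formula("b (h w)"))  # (['b', 'h', 'w'], [(1, 3)]): names 1..3 form one group
```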
@@ -68,7 +68,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   # allgather
   copied_chunks = []
   for i,c in enumerate(reduced_chunks):
-    this_chunk = [None] * len(buf.device)
+    this_chunk: list[UOp|None] = [None] * len(buf.device)
     this_chunk[(i+len(buf.device)-1)%n_lbs] = c
     for step in range(n_lbs-1):
      dest = (i+step)%n_lbs
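Note: this matches the "fix local mypy" item in the commit message; `[None] * n` alone is inferred as `list[None]`, so assigning a `UOp` into a slot on the next line fails type checking. A minimal reproduction of the pattern:

```python
# minimal sketch of the mypy issue fixed above (int stands in for UOp)
this_chunk: list[int | None] = [None] * 4  # without the annotation, mypy infers list[None]
this_chunk[2] = 5                          # ok with the explicit element type
print(this_chunk)
```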
@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
 from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
-from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
+from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
@@ -1055,65 +1055,6 @@ class Tensor(OpMixin):
     new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
     return self._broadcast_to(new_shape)

-  def permute(self, order, *args) -> Tensor:
-    """
-    Returns a tensor that is a permutation of the original tensor.
-    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
-    `order` can be passed as a tuple or as separate arguments.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.empty(2, 3, 5)
-    print(t.shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.permute(2, 0, 1).shape)
-    ```
-    """
-    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
-    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
-    return self._apply_uop(UOp.permute, arg=order_arg)
-
-  def flip(self, axis, *args) -> Tensor:
-    """
-    Returns a tensor that reverses the order of the original tensor along given `axis`.
-    `axis` can be passed as a tuple or as separate arguments.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(6).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flip(0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flip((0, 1)).numpy())
-    ```
-    """
-    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
-    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
-    return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
-
-  def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
-    """
-    Returns a tensor that shrinks each axis based on input arg.
-    `arg` must have the same length as `self.ndim`.
-    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(9).reshape(3, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.shrink(((None, (1, 3)))).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.shrink((((0, 2), (0, 2)))).numpy())
-    ```
-    """
-    if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
-    if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
-    return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
-
   def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.
@@ -1535,80 +1476,6 @@ class Tensor(OpMixin):
     output_shape = _broadcast_shape(*(t.shape for t in tensors))
     return tuple(t._broadcast_to(output_shape) for t in tensors)

-  def squeeze(self, dim:int|None=None) -> Tensor:
-    """
-    Returns a tensor with specified dimensions of input of size 1 removed.
-    If `dim` is not specified, all dimensions with size 1 are removed.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.zeros(2, 1, 2, 1, 2)
-    print(t.squeeze().shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.squeeze(0).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.squeeze(1).shape)
-    ```
-    """
-    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
-    dim = self._resolve_dim(dim)
-    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
-
-  def unsqueeze(self, dim:int) -> Tensor:
-    """
-    Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([1, 2, 3, 4])
-    print(t.unsqueeze(0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.unsqueeze(1).numpy())
-    ```
-    """
-    dim = self._resolve_dim(dim, extra=True)
-    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
-
-  @property
-  def T(self) -> Tensor:
-    """`.T` is an alias for `.transpose()`."""
-    return self.transpose()
-
-  def transpose(self, dim0=1, dim1=0) -> Tensor:
-    """
-    Returns a tensor that is a transposed version of the original tensor.
-    The given dimensions `dim0` and `dim1` are swapped.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(6).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.transpose(0, 1).numpy())
-    ```
-    """
-    order = list(range(self.ndim))
-    order[dim0], order[dim1] = order[dim1], order[dim0]
-    return self.permute(order)
-
-  def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
-    """
-    Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
-    ```
-    """
-    dim = self._resolve_dim(dim)
-    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
-
   def diag(self) -> Tensor:
     """
     Returns a 2-D square tensor with the elements of input as the main diagonal.
@@ -1654,46 +1521,6 @@ class Tensor(OpMixin):
     for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
     return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]

-  def rearrange(self, formula:str, **sizes) -> Tensor:
-    """
-    Rearranges input according to formula
-
-    See: https://einops.rocks/api/rearrange/
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    x = Tensor([[1, 2], [3, 4]])
-    print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
-    ```
-    """
-    def parse_formula(formula: str):
-      tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(",", " ").replace(" 1 ", " ( ) ").split()
-      lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
-      pairs = list(zip(lparens, rparens))
-      assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
-      return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
-
-    assert formula.count("->") == 1, 'need exactly one "->" in formula'
-
-    (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
-
-    for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
-    assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
-    for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
-    assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
-    assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
-
-    # resolve ellipsis
-    if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
-    lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
-    unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
-    flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
-
-    # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
-    t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
-    for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
-    t = t.permute([lhs.index(name) for name in rhs])
-    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)

   def masked_select(self, mask):
     """
     Selects elements from `self` based on the boolean `mask`.
@@ -549,12 +549,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
   #def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
   def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
-  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
+  #def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
   def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)

   # in these two, we have custom logic to check if they are a no-op
-  def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
-  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
+  #def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
+  #def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self

 # *** uop UNIQUE ***
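Note: `shrink`, `permute`, and `flip` are commented out on `UOp` (like `reshape` before them) because `MovementMixin` now supplies them to both `Tensor` and `UOp`, with the no-op checks folded into the mixin versions shown earlier. A quick hedged check that neither class defines its own copy anymore:

```python
# sketch: both classes should now inherit these methods rather than define them
from tinygrad import Tensor
from tinygrad.uop.ops import UOp

print("permute" in vars(Tensor), "permute" in vars(UOp))  # expect: False False
```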
@@ -1322,7 +1322,8 @@ renderer_infer = PatternMatcher([

 def srcs(ctx, src): return f"({ctx[src[0]]},)" if len(src) == 1 else f"({', '.join([ctx[x] for x in src])})"
 def render_marg(ctx,x:UOp):
-  if x.op in {Ops.PERMUTE, Ops.FLIP}: return str(x.marg)
+  if x.op is Ops.PERMUTE: return str(x.marg)
+  if x.op is Ops.FLIP: return str(tuple([i for i,x in enumerate(x.marg) if x]))
   pieces = []
   if x.op in {Ops.RESHAPE, Ops.EXPAND}:
     pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]
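Note: mirroring the gradient change above, the renderer now prints FLIP's bool-mask `marg` as the flipped-axis tuple, matching the new mixin `flip` signature. The new branch evaluated standalone:

```python
# the new FLIP rendering, by itself
marg = (False, True, True)
print(str(tuple([i for i, x in enumerate(marg) if x])))  # (1, 2)
```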