move permute/flip/shrink to mixins (#13113)
* move permute to mixins
* move more stuff
* two more
* fix local mypy
* fix tests
* fix shrink
@@ -57,7 +57,7 @@ class _ROCParseCtx:
     self.inst_execs:dict[tuple[str, int, int, int], list[InstExec]] = {}

     for prog in prog_evs:
-      arch = "gfx%d%x%x" % ((trgt:=dev_evs[prog.device].props['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
+      arch = "gfx%d%x%x" % ((trgt:=unwrap(dev_evs[prog.device].props)['gfx_target_version']) // 10000, (trgt // 100) % 100, trgt % 100)
       for addr, info in llvm_disasm(arch, unwrap(prog.lib)).items():
         self.disasms[unwrap(prog.base) + addr] = info
         self.addr2prg[unwrap(prog.base) + addr] = prog
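The only change here is wrapping `dev_evs[prog.device].props` in `unwrap`, presumably because `props` is typed as Optional and mypy rejects subscripting it directly (part of the 'fix local mypy' item in the commit message). A minimal sketch of what an `unwrap` helper of this kind does; the exact signature of tinygrad's helper is an assumption here:

```python
from typing import TypeVar

T = TypeVar("T")

def unwrap(x: T | None) -> T:
  # narrow Optional[T] to T, failing loudly instead of propagating None
  assert x is not None
  return x

props: dict[str, int] | None = {"gfx_target_version": 110000}
print(unwrap(props)["gfx_target_version"])  # typechecks; props[...] alone would not
```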
@@ -894,7 +894,7 @@ class TestNumpy(unittest.TestCase):
     a = Tensor([[1, 2, 3],
                 [4, 5, 6],
                 [7, 8, 9]])
-    self.assertIsNot(a[...], a)
+    self.assertIs(a[...], a)
     numpy_testing_assert_equal_helper(a[...], a)
     # `a[...]` was `a` in numpy <1.9.
     #numpy_testing_assert_equal_helper(data_ptr(a[...]), data_ptr(a))
@@ -1037,9 +1037,9 @@ class TestNumpy(unittest.TestCase):
     # Before `...` would return a itself.
     a = Tensor([5])

-    self.assertIsNot(a, a[()])
-    self.assertIsNot(a, a[...])
-    self.assertIsNot(a, a[:])
+    self.assertIs(a, a[()])
+    self.assertIs(a, a[...])
+    self.assertIs(a, a[:])

   def test_broaderrors_indexing(self):
     a = Tensor.zeros(5, 5)
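Both test updates encode the new identity semantics: movement ops that change nothing now return `self` rather than an equal-but-distinct Tensor. What the flipped assertions check, in isolation (behavior taken directly from the tests above):

```python
from tinygrad import Tensor

a = Tensor([5])
# trivial indexing now returns the same Python object, not a copy of the view
assert a[...] is a and a[()] is a and a[:] is a
```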
@@ -36,7 +36,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.PAD, name="ret"), lambda ctx, ret: (ctx.shrink(tuple([(p[0], s+p[0]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
-  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip(ret.marg),)),
+  (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
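The gradient rule for `FLIP` has to adapt because a `FLIP` uop stores its `marg` as a per-axis boolean mask, while the mixin's new `flip()` (below) takes axis indices. The mask-to-indices conversion on its own:

```python
marg = (False, True, True)           # FLIP marg: one bool per axis
axes = [i for i, x in enumerate(marg) if x]
print(axes)                          # [1, 2] -> what ctx.flip() receives now
```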
@@ -1,7 +1,8 @@
 # mixins add syntactic sugar to Tensor and UOp
+import functools
 from typing import TypeAlias, TYPE_CHECKING, Self
 from tinygrad.uop import Ops
-from tinygrad.helpers import prod, argfix
+from tinygrad.helpers import prod, argfix, flatten, dedup
 if TYPE_CHECKING:
   from tinygrad.uop.ops import UOp
   sint:TypeAlias = UOp|int
@@ -41,10 +42,6 @@ class MovementMixin:
     if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
     return dim + total if dim < 0 else dim

-  def view(self, shape, *args) -> Self:
-    """`.view` is an alias for `.reshape`."""
-    return self.reshape(shape, *args)
-
   def reshape(self, shape, *args) -> Self:
     """
     Returns a tensor with the same data as the original tensor but with a different shape.
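The context lines show `_resolve_dim`, which every movement method below leans on for bounds checking and negative-axis handling. The same logic as a standalone function (the name `resolve_dim` is hypothetical; the body mirrors the lines above):

```python
def resolve_dim(dim: int, total: int) -> int:
  # validate the range, then map negative dims onto [0, total)
  if not -max(1, total) <= dim <= max(1, total)-1:
    raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}")
  return dim + total if dim < 0 else dim

assert resolve_dim(-1, total=3) == 2
assert resolve_dim(0, total=0) == 0   # 0-dim tensors still accept dim 0
```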
@@ -61,7 +58,131 @@ class MovementMixin:
     if (c := new_shape.count(-1)) > 1: raise RuntimeError(f"only one dimension can be inferred using -1, getting {new_shape}")
     if c: new_shape = tuple([-prod(self.shape) // prod(new_shape) if s == -1 else s for s in new_shape])
     if prod(self.shape) != prod(new_shape): raise ValueError(f"size mismatch, can't reshape ({self.shape}) -> ({new_shape})")
-    return self._mop(Ops.RESHAPE, arg=new_shape) if new_shape != self.shape else self
+    ret = self._mop(Ops.RESHAPE, arg=new_shape)
+    return self if ret.shape == self.shape else ret
+
+  def shrink(self, arg:tuple[tuple["sint", "sint"]|None, ...]) -> Self:
+    """
+    Returns a tensor that shrinks each axis based on input arg.
+    `arg` must have the same length as `self.ndim`.
+    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(9).reshape(3, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink(((None, (1, 3)))).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.shrink((((0, 2), (0, 2)))).numpy())
+    ```
+    """
+    if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
+    ret = self._mop(Ops.SHRINK, arg=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)])
+    return self if ret.shape == self.shape else ret
+
+  def permute(self, order, *args) -> Self:
+    """
+    Returns a tensor that is a permutation of the original tensor.
+    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
+    `order` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.empty(2, 3, 5)
+    print(t.shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.permute(2, 0, 1).shape)
+    ```
+    """
+    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
+    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
+    return self._mop(Ops.PERMUTE, arg=order_arg) if order_arg != tuple(range(self.ndim)) else self
+
+  def flip(self, axis, *args) -> Self:
+    """
+    Returns a tensor that reverses the order of the original tensor along given `axis`.
+    `axis` can be passed as a tuple or as separate arguments.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.flip((0, 1)).numpy())
+    ```
+    """
+    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
+    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
+    flip_arg = tuple([i in axis_arg for i in range(len(self.shape))])
+    return self._mop(Ops.FLIP, arg=flip_arg) if any(flip_arg) else self
+
+  # **** high level ****
+
+  def view(self, shape, *args) -> Self:
+    """`.view` is an alias for `.reshape`."""
+    return self.reshape(shape, *args)
+
+  def squeeze(self, dim:int|None=None) -> Self:
+    """
+    Returns a tensor with specified dimensions of input of size 1 removed.
+    If `dim` is not specified, all dimensions with size 1 are removed.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.zeros(2, 1, 2, 1, 2)
+    print(t.squeeze().shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(0).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.squeeze(1).shape)
+    ```
+    """
+    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
+    dim = self._resolve_dim(dim)
+    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
+
+  def unsqueeze(self, dim:int) -> Self:
+    """
+    Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor([1, 2, 3, 4])
+    print(t.unsqueeze(0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.unsqueeze(1).numpy())
+    ```
+    """
+    dim = self._resolve_dim(dim, extra=True)
+    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
+
+  @property
+  def T(self) -> Self:
+    """`.T` is an alias for `.transpose()`."""
+    return self.transpose()
+
+  def transpose(self, dim0=1, dim1=0) -> Self:
+    """
+    Returns a tensor that is a transposed version of the original tensor.
+    The given dimensions `dim0` and `dim1` are swapped.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t = Tensor.arange(6).reshape(2, 3)
+    print(t.numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t.transpose(0, 1).numpy())
+    ```
+    """
+    order = list(range(self.ndim))
+    order[dim0], order[dim1] = order[dim1], order[dim0]
+    return self.permute(order)

   def flatten(self, start_dim=0, end_dim=-1) -> Self:
     """
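Every method above follows one convention: construct the movement uop via `_mop`, but hand back `self` whenever the op would be an identity (reshape or shrink to the same shape, the identity permutation, a flip of no axes). This is what lets the numpy-compat tests switch from `assertIsNot` to `assertIs`. A sketch of the observable behavior, assuming `MovementMixin` is mixed into `Tensor`:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
assert t.reshape(2, 3) is t           # same shape -> self
assert t.permute(0, 1) is t           # identity permutation -> self
assert t.shrink((None, None)) is t    # nothing actually shrinks -> self
assert t.flip(0) is not t             # a real flip builds a new op
```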
@@ -77,4 +198,61 @@ class MovementMixin:
     ```
     """
     start_dim, end_dim = self._resolve_dim(start_dim), self._resolve_dim(end_dim)
     return self.reshape(self.shape[:start_dim] + (prod(self.shape[start_dim:end_dim+1]), ) + self.shape[end_dim+1:])
+
+  def unflatten(self, dim:int, sizes:tuple[int,...]) -> Self:
+    """
+    Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
+    ```
+    """
+    dim = self._resolve_dim(dim)
+    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
+
+  def rearrange(self, formula:str, **sizes) -> Self:
+    """
+    Rearranges input according to formula.
+
+    See: https://einops.rocks/api/rearrange/
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    x = Tensor([[1, 2], [3, 4]])
+    print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
+    ```
+    """
+    def parse_formula(formula: str):
+      tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(",", " ").replace(" 1 ", " ( ) ").split()
+      lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
+      pairs = list(zip(lparens, rparens))
+      assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
+      return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]

+    assert formula.count("->") == 1, 'need exactly one "->" in formula'
+
+    (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
+
+    for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
+    assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
+    for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
+    assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
+    assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
+
+    # resolve ellipsis
+    if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
+    lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
+    unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
+    flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
+
+    # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
+    t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
+    for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
+    t = t.permute([lhs.index(name) for name in rhs])
+    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
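`rearrange` reduces the einops formula to the primitives defined above, applied as unflatten → permute → flatten/unsqueeze. A usage sketch beyond the docstring example, including the `sizes` kwarg that pins one axis of a group (the other is inferred as -1):

```python
from tinygrad import Tensor

x = Tensor.arange(6)                    # shape (6,)
y = x.rearrange("(h w) -> h w", h=2)    # unflatten the (h w) group, h pinned
print(y.shape)                          # (2, 3)
print(y.rearrange("h w -> w h").shape)  # (3, 2): a pure permute
```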
@@ -68,7 +68,7 @@ def handle_allreduce(buf:UOp, red:UOp) -> UOp|None:
   # allgather
   copied_chunks = []
   for i,c in enumerate(reduced_chunks):
-    this_chunk = [None] * len(buf.device)
+    this_chunk: list[UOp|None] = [None] * len(buf.device)
     this_chunk[(i+len(buf.device)-1)%n_lbs] = c
     for step in range(n_lbs-1):
       dest = (i+step)%n_lbs
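Another mypy fix: `[None] * n` is inferred as `list[None]`, so assigning a `UOp` into the list one line later fails to typecheck without the explicit annotation. A minimal reproduction of the inference problem, with `int` standing in for `UOp`:

```python
chunks = [None] * 4          # mypy infers list[None]
chunks[0] = 7                # mypy: incompatible types in assignment

fixed: list[int | None] = [None] * 4
fixed[0] = 7                 # ok
```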
@@ -5,7 +5,7 @@ from contextlib import ContextDecorator
 from typing import Callable, ClassVar, Sequence, cast, get_args, Literal, SupportsIndex, ParamSpec, TypeVar, Generic
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
 from tinygrad.dtype import _from_np_dtype, _to_np_dtype
-from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
+from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten
 from tinygrad.helpers import IMAGE, WINO, Metadata, TRACEMETA, ceildiv, fetch, polyN, DEBUG, is_numpy_ndarray, SPEC
 from tinygrad.helpers import suppress_finalizing
 from tinygrad.gradient import compute_gradient
@@ -1055,65 +1055,6 @@ class Tensor(OpMixin):
     new_shape = tuple(from_ if to == -1 or to is None else to for from_, to in zip(*(_align_left(self.shape, argfix(shape, *args)))))
     return self._broadcast_to(new_shape)

-  def permute(self, order, *args) -> Tensor:
-    """
-    Returns a tensor that is a permutation of the original tensor.
-    The new tensor has the same data as the original tensor but with the dimensions permuted according to the order specified.
-    `order` can be passed as a tuple or as separate arguments.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.empty(2, 3, 5)
-    print(t.shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.permute(2, 0, 1).shape)
-    ```
-    """
-    order_arg = tuple(self._resolve_dim(x) for x in argfix(order, *args))
-    if sorted(order_arg) != list(range(self.ndim)): raise RuntimeError(f"order is not a valid permutation, getting {order_arg}")
-    return self._apply_uop(UOp.permute, arg=order_arg)
-
-  def flip(self, axis, *args) -> Tensor:
-    """
-    Returns a tensor that reverses the order of the original tensor along given `axis`.
-    `axis` can be passed as a tuple or as separate arguments.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(6).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flip(0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.flip((0, 1)).numpy())
-    ```
-    """
-    axis_arg = tuple(self._resolve_dim(x) for x in argfix(axis, *args))
-    if len(axis_arg) != len(dedup(axis_arg)): raise RuntimeError(f"dim can appear at most once, getting {axis_arg}")
-    return self._apply_uop(UOp.flip, arg=tuple([i in axis_arg for i in range(len(self.shape))]))
-
-  def shrink(self, arg:tuple[tuple[sint, sint]|None, ...]) -> Tensor:
-    """
-    Returns a tensor that shrinks each axis based on input arg.
-    `arg` must have the same length as `self.ndim`.
-    For each axis, it can be `None`, which means no shrink, or a tuple `(start, end)` that works the same as Python slice.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(9).reshape(3, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.shrink(((None, (1, 3)))).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.shrink((((0, 2), (0, 2)))).numpy())
-    ```
-    """
-    if self.ndim != len(arg): raise ValueError(f"{self.ndim=} != {len(arg)=}")
-    if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self
-    return self._apply_uop(UOp.shrink, arg=tuple(shrink_arg))
-
   def pad(self, padding:Sequence[sint]|Sequence[tuple[sint, sint]|None], mode:str="constant", value:float=0.0) -> Tensor:
     """
     Returns a tensor with padding applied based on the input `padding`.
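Nothing is lost from the public API: these `Tensor` methods are deleted because identical implementations now live on `MovementMixin` (shown earlier), which `Tensor` reaches through `OpMixin`. That the mixin chain wires `MovementMixin` into `OpMixin` is an assumption read off the class signatures in this diff; in any case, call sites are unchanged:

```python
from tinygrad import Tensor

t = Tensor.arange(6).reshape(2, 3)
print(t.permute(1, 0).shape)            # (3, 2)
print(t.flip(0).numpy())                # rows reversed
print(t.shrink(((0, 1), None)).shape)   # (1, 3)
```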
@@ -1535,80 +1476,6 @@ class Tensor(OpMixin):
     output_shape = _broadcast_shape(*(t.shape for t in tensors))
     return tuple(t._broadcast_to(output_shape) for t in tensors)

-  def squeeze(self, dim:int|None=None) -> Tensor:
-    """
-    Returns a tensor with specified dimensions of input of size 1 removed.
-    If `dim` is not specified, all dimensions with size 1 are removed.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.zeros(2, 1, 2, 1, 2)
-    print(t.squeeze().shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.squeeze(0).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.squeeze(1).shape)
-    ```
-    """
-    if dim is None: return self.reshape(tuple(dim for dim in self.shape if dim != 1))
-    dim = self._resolve_dim(dim)
-    return self if not self.ndim or self.shape[dim] != 1 else self.reshape(self.shape[:dim] + self.shape[dim+1:])
-
-  def unsqueeze(self, dim:int) -> Tensor:
-    """
-    Returns a tensor with a new dimension of size 1 inserted at the specified `dim`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor([1, 2, 3, 4])
-    print(t.unsqueeze(0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.unsqueeze(1).numpy())
-    ```
-    """
-    dim = self._resolve_dim(dim, extra=True)
-    return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:])
-
-  @property
-  def T(self) -> Tensor:
-    """`.T` is an alias for `.transpose()`."""
-    return self.transpose()
-
-  def transpose(self, dim0=1, dim1=0) -> Tensor:
-    """
-    Returns a tensor that is a transposed version of the original tensor.
-    The given dimensions `dim0` and `dim1` are swapped.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t = Tensor.arange(6).reshape(2, 3)
-    print(t.numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t.transpose(0, 1).numpy())
-    ```
-    """
-    order = list(range(self.ndim))
-    order[dim0], order[dim1] = order[dim1], order[dim0]
-    return self.permute(order)
-
-  def unflatten(self, dim:int, sizes:tuple[int,...]) -> Tensor:
-    """
-    Unflattens dimension `dim` of the tensor into multiple dimensions specified by `sizes`. `Tensor.flatten()` is the inverse of this function.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(3, 4, 1).unflatten(1, (2, 2)).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(3, 4, 1).unflatten(1, (-1, 2)).shape)
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(Tensor.ones(5, 12, 3).unflatten(-2, (2, 2, 3, 1, 1)).shape)
-    ```
-    """
-    dim = self._resolve_dim(dim)
-    return self.reshape(self.shape[:dim] + sizes + self.shape[dim+1:])
-
   def diag(self) -> Tensor:
     """
     Returns a 2-D square tensor with the elements of input as the main diagonal.
@@ -1654,46 +1521,6 @@ class Tensor(OpMixin):
     for dim, shift in zip(dims, shifts): slices[dim] = slice(delta:=self.shape[dim]-shift%self.shape[dim], delta+self.shape[dim])
     return self.repeat(*tuple(2 if i in dims else 1 for i in range(self.ndim)))[slices]

-  def rearrange(self, formula:str, **sizes) -> Tensor:
-    """
-    Rearranges input according to formula.
-
-    See: https://einops.rocks/api/rearrange/
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    x = Tensor([[1, 2], [3, 4]])
-    print(Tensor.rearrange(x, "batch channel -> (batch channel)").numpy())
-    ```
-    """
-    def parse_formula(formula: str):
-      tokens = f" {formula} ".replace("…", "...").replace("(", " ( ").replace(")", " ) ").replace(",", " ").replace(" 1 ", " ( ) ").split()
-      lparens, rparens = map(lambda x: [i for i, ch in enumerate(tokens) if ch == x], ("(", ")"))
-      pairs = list(zip(lparens, rparens))
-      assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
-      return [name for name in tokens if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
-
-    assert formula.count("->") == 1, 'need exactly one "->" in formula'
-
-    (lhs, unflatten_dims), (rhs, flatten_dims) = map(parse_formula, formula.split("->"))
-
-    for name in sizes: assert name in lhs, f"axis {name} is not used in transform"
-    assert sorted(lhs) == sorted(rhs) and len(lhs) == len(set(lhs)), f"name mismatch in {formula}"
-    for name in flatten((lhs, rhs)): assert name == "..." or (name.isidentifier() and "_" not in (name[0], name[-1])), f"invalid axis name {name}"
-    assert "..." not in flatten([lhs[s:e] for s, e in unflatten_dims]), f"cannot have collapsed ellipsis (...) in lhs of {formula}"
-    assert lhs.count("...") <= 1, f"too many ellipses in {formula}"
-
-    # resolve ellipsis
-    if "..." in lhs: ell_len = len(self.shape) - len(lhs) + 1 + sum(e - s - 1 for s, e in unflatten_dims)
-    lhs, rhs = map(lambda l: l[:(i:=l.index("..."))] + [f"...{j}" for j in range(ell_len)] + l[i + 1:] if "..." in l else l, (lhs, rhs))
-    unflatten_dims = [(s + (ell_len - 1 if "...0" in lhs[:s] else 0), e + (ell_len - 1 if "...0" in lhs[:e] else 0)) for s, e in unflatten_dims]
-    flatten_dims = [(s + (ell_len - 1 if "...0" in rhs[:s] else 0), e + (ell_len - 1 if "...0" in rhs[:e] else 0)) for s, e in flatten_dims]
-
-    # apply movement ops in order unflatten -> permute -> flatten/unsqueeze
-    t = functools.reduce(lambda x, dims: x.unflatten(dims[0], tuple(sizes.get(lhs[d], -1) for d in range(*dims))), unflatten_dims, self)
-    for i, name in enumerate(lhs): assert (name not in sizes) or sizes[name] == t.shape[i], f"size provided for dimension {name} incorrect"
-    t = t.permute([lhs.index(name) for name in rhs])
-    return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0]<dims[1] else x.unsqueeze(dims[0]), reversed(flatten_dims), t)
-
   def masked_select(self, mask):
     """
     Selects elements from `self` based on the boolean `mask`.
@@ -549,12 +549,12 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
   def forced_reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=False)
   #def reshape(self, arg:tuple[sint, ...]): return self._mop(Ops.RESHAPE, arg, same_shape_noop=True)
   def expand(self, arg:tuple[sint, ...]): return self._mop(Ops.EXPAND, arg, same_shape_noop=True)
-  def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
+  #def shrink(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.SHRINK, arg, same_shape_noop=True)
   def pad(self, arg:tuple[tuple[sint, sint], ...]): return self._mop(Ops.PAD, arg, same_shape_noop=True)

   # in these two, we have custom logic to check if they are a no-op
-  def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
-  def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self
+  #def permute(self, arg:tuple[int, ...]): return self._mop(Ops.PERMUTE, arg, same_shape_noop=False) if arg != tuple(range(len(self.shape))) else self
+  #def flip(self, arg:tuple[bool, ...]): return self._mop(Ops.FLIP, arg, same_shape_noop=False) if any(arg) and len(arg) == len(self.shape) else self

   # *** uop UNIQUE ***

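With these raw methods commented out, `UOp` also picks up `shrink`/`permute`/`flip` from `MovementMixin`, which preserves the no-op short-circuits the comment refers to. The two checks, extracted into standalone functions (hypothetical names, same logic as the commented-out lines):

```python
def permute_is_noop(order: tuple[int, ...]) -> bool:
  # identity permutation: (0, 1, ..., n-1)
  return order == tuple(range(len(order)))

def flip_is_noop(mask: tuple[bool, ...]) -> bool:
  # flipping no axes changes nothing
  return not any(mask)

assert permute_is_noop((0, 1, 2)) and not permute_is_noop((2, 0, 1))
assert flip_is_noop((False, False)) and not flip_is_noop((False, True))
```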
@@ -1322,7 +1322,8 @@ renderer_infer = PatternMatcher([

 def srcs(ctx, src): return f"({ctx[src[0]]},)" if len(src) == 1 else f"({', '.join([ctx[x] for x in src])})"
 def render_marg(ctx,x:UOp):
-  if x.op in {Ops.PERMUTE, Ops.FLIP}: return str(x.marg)
+  if x.op is Ops.PERMUTE: return str(x.marg)
+  if x.op is Ops.FLIP: return str(tuple([i for i,x in enumerate(x.marg) if x]))
   pieces = []
   if x.op in {Ops.RESHAPE, Ops.EXPAND}:
     pieces = [f"{ctx[a] if isinstance(a, UOp) else str(a)}" for a in x.marg]
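Same mask-versus-indices mismatch as in the gradient rules: the renderer must emit an argument the new `flip()` will accept, so a `FLIP` uop's boolean `marg` is converted back to axis indices before printing. The rendering step in isolation:

```python
marg = (True, False, True)                        # FLIP marg stored on the uop
rendered = str(tuple([i for i, x in enumerate(marg) if x]))
print(rendered)                                   # "(0, 2)" -> valid .flip() arg
```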