mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
cat the stack to mixin (#15715)
This commit is contained in:
@@ -58,5 +58,17 @@ class TestTensorUOpCumalu(unittest.TestCase):
|
||||
def test_cumsum_large(self): _check(self, _t(600), lambda x: x.cumsum()) # exercises _split_cumalu
|
||||
def test_cumprod(self): _check(self, _t(4), lambda x: x.cumprod(0))
|
||||
|
||||
class TestTensorUOpCat(unittest.TestCase):
|
||||
def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
|
||||
def test_cat_dim1(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=1))
|
||||
def test_cat_3tensors(self): _check(self, _t(2, 3), lambda x: x.cat(x, x, dim=0))
|
||||
def test_cat_neg_dim(self): _check(self, _t(2, 3, 4), lambda x: x.cat(x, dim=-1))
|
||||
|
||||
class TestTensorUOpStack(unittest.TestCase):
|
||||
def test_stack_dim0(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=0))
|
||||
def test_stack_dim1(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=1))
|
||||
def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
|
||||
def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import functools
|
||||
import functools, itertools
|
||||
from typing import Self, Sequence, Literal, get_args
|
||||
from tinygrad.mixin.elementwise import ElementwiseMixin
|
||||
from tinygrad.mixin.movement import MovementMixin
|
||||
@@ -260,6 +260,42 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
|
||||
m = self.max(axis=axis, keepdim=True)
|
||||
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
|
||||
|
||||
def cat(self, *args:Self, dim:int=0) -> Self:
|
||||
"""
|
||||
Concatenates self with other tensors in `args` along an axis specified by `dim`.
|
||||
All tensors must have the same shape except in the concatenating dimension.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
|
||||
print(t0.cat(t1, t2, dim=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t0.cat(t1, t2, dim=1).numpy())
|
||||
```
|
||||
"""
|
||||
dim = self._resolve_dim(dim)
|
||||
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
||||
tensors = [self, *args]
|
||||
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
|
||||
padded = [t.pad(tuple((dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim))) for i,t in enumerate(tensors)]
|
||||
return padded[0].usum(*padded[1:])
|
||||
|
||||
def stack(self, *args:Self, dim:int=0) -> Self:
|
||||
"""
|
||||
Concatenates self with other tensors in `args` along a new dimension specified by `dim`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
|
||||
print(t0.stack(t1, t2, dim=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t0.stack(t1, t2, dim=1).numpy())
|
||||
```
|
||||
"""
|
||||
# checks for shapes and number of dimensions delegated to cat
|
||||
unsqueezed = [t.unsqueeze(dim) for t in argfix(self, *args)]
|
||||
return unsqueezed[0].cat(*unsqueezed[1:], dim=dim)
|
||||
|
||||
def _cumalu(self, axis:int, op:Ops) -> Self:
|
||||
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
|
||||
pads = (None,)*(self.ndim-1) + ((self.shape[axis]-1, 0),)
|
||||
|
||||
@@ -1298,41 +1298,6 @@ class Tensor(OpMixin):
|
||||
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
|
||||
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
|
||||
|
||||
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||
"""
|
||||
Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
|
||||
All tensors must have the same shape except in the concatenating dimension.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
|
||||
print(t0.cat(t1, t2, dim=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t0.cat(t1, t2, dim=1).numpy())
|
||||
```
|
||||
"""
|
||||
dim = self._resolve_dim(dim)
|
||||
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
|
||||
tensors = [self, *args]
|
||||
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
|
||||
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
|
||||
return Tensor.usum(*tensors)
|
||||
|
||||
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
|
||||
"""
|
||||
Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
|
||||
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
|
||||
print(t0.stack(t1, t2, dim=0).numpy())
|
||||
```
|
||||
```python exec="true" source="above" session="tensor" result="python"
|
||||
print(t0.stack(t1, t2, dim=1).numpy())
|
||||
```
|
||||
"""
|
||||
# checks for shapes and number of dimensions delegated to cat
|
||||
return Tensor.cat(*[t.unsqueeze(dim) for t in argfix(self, *args)], dim=dim)
|
||||
|
||||
def masked_select(self, mask):
|
||||
"""
|
||||
Selects elements from `self` based on the boolean `mask`.
|
||||
|
||||
Reference in New Issue
Block a user