move `cat` and `stack` to the mixin (#15715)

This commit is contained in:
chenyu
2026-04-13 18:44:39 -04:00
committed by GitHub
parent 355e2729d3
commit 70883a6950
3 changed files with 49 additions and 36 deletions

View File

@@ -58,5 +58,17 @@ class TestTensorUOpCumalu(unittest.TestCase):
# 600-element input — per the inline note this is sized to take the _split_cumalu path
def test_cumsum_large(self): _check(self, _t(600), lambda x: x.cumsum()) # exercises _split_cumalu
# cumulative product along axis 0 on a small 1-D tensor
def test_cumprod(self): _check(self, _t(4), lambda x: x.cumprod(0))
class TestTensorUOpCat(unittest.TestCase):
  """Checks Tensor.cat against the reference implementation via _check."""
  def test_cat_dim0(self):
    _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
  def test_cat_dim1(self):
    _check(self, _t(2, 3), lambda x: x.cat(x, dim=1))
  def test_cat_3tensors(self):
    # three inputs concatenated at once, not just the two-tensor case
    _check(self, _t(2, 3), lambda x: x.cat(x, x, dim=0))
  def test_cat_neg_dim(self):
    # negative dim must resolve to the last axis
    _check(self, _t(2, 3, 4), lambda x: x.cat(x, dim=-1))
class TestTensorUOpStack(unittest.TestCase):
  """Checks Tensor.stack against the reference implementation via _check."""
  def test_stack_dim0(self):
    _check(self, _t(2, 3), lambda x: x.stack(x, dim=0))
  def test_stack_dim1(self):
    _check(self, _t(2, 3), lambda x: x.stack(x, dim=1))
  def test_stack_3tensors(self):
    # three inputs stacked at once, not just the two-tensor case
    _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
  def test_stack_new_last(self):
    # dim=-1 creates a brand-new trailing axis
    _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
# Entry point: run every TestCase in this module when executed directly.
if __name__ == "__main__":
  unittest.main()

View File

@@ -1,4 +1,4 @@
import functools
import functools, itertools
from typing import Self, Sequence, Literal, get_args
from tinygrad.mixin.elementwise import ElementwiseMixin
from tinygrad.mixin.movement import MovementMixin
@@ -260,6 +260,42 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
m = self.max(axis=axis, keepdim=True)
return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
def cat(self, *args:Self, dim:int=0) -> Self:
  """
  Concatenates self with other tensors in `args` along an axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension.
  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  tensors = [self, *args]
  for t in tensors[1:]:
    assert t.ndim == self.ndim and all(s == o for axis,(s,o) in enumerate(zip(self.shape, t.shape)) if axis != dim)
  # offsets[k] is where tensor k begins along `dim`; offsets[-1] is the total concatenated size
  offsets = list(itertools.accumulate((t.shape[dim] for t in tensors), initial=0))
  # zero-pad each tensor so it occupies only its own slot along `dim`, then add them all up
  padded = []
  for k, t in enumerate(tensors):
    pads = tuple((offsets[k], offsets[-1]-offsets[k+1]) if axis == dim else None for axis in range(t.ndim))
    padded.append(t.pad(pads))
  return padded[0].usum(*padded[1:])
def stack(self, *args:Self, dim:int=0) -> Self:
  """
  Concatenates self with other tensors in `args` along a new dimension specified by `dim`.
  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # insert the new axis on every tensor, then let cat (which validates ndim/shapes) do the join
  first, *rest = [t.unsqueeze(dim) for t in argfix(self, *args)]
  return first.cat(*rest, dim=dim)
def _cumalu(self, axis:int, op:Ops) -> Self:
assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
pads = (None,)*(self.ndim-1) + ((self.shape[axis]-1, 0),)

View File

@@ -1298,41 +1298,6 @@ class Tensor(OpMixin):
x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
  All tensors must have the same shape except in the concatenating dimension.
  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
  print(t0.cat(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.cat(t1, t2, dim=1).numpy())
  ```
  """
  dim = self._resolve_dim(dim)
  for other in args:
    assert other.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, other.shape)) if i!=dim)
  tensors = [self, *args]
  # starts[k] is each tensor's offset along `dim`; starts[-1] is the final concatenated size
  starts = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
  # zero-pad each tensor into its own slice of the output, then sum them all together
  shifted = [t.pad([(starts[i], starts[-1]-starts[i+1]) if j==dim else None for j in range(t.ndim)]) for i,t in enumerate(tensors)]
  return Tensor.usum(*shifted)
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
  """
  Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
  ```python exec="true" source="above" session="tensor" result="python"
  t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
  print(t0.stack(t1, t2, dim=0).numpy())
  ```
  ```python exec="true" source="above" session="tensor" result="python"
  print(t0.stack(t1, t2, dim=1).numpy())
  ```
  """
  # add the new axis to every input first; cat performs all shape/ndim validation
  expanded = [t.unsqueeze(dim) for t in argfix(self, *args)]
  return Tensor.cat(*expanded, dim=dim)
def masked_select(self, mask):
"""
Selects elements from `self` based on the boolean `mask`.