diff --git a/test/null/test_tensor_uop_mixin.py b/test/null/test_tensor_uop_mixin.py
index 2e5ce01ea5..2f05a6f3a3 100644
--- a/test/null/test_tensor_uop_mixin.py
+++ b/test/null/test_tensor_uop_mixin.py
@@ -58,5 +58,17 @@ class TestTensorUOpCumalu(unittest.TestCase):
   def test_cumsum_large(self): _check(self, _t(600), lambda x: x.cumsum()) # exercises _split_cumalu
   def test_cumprod(self): _check(self, _t(4), lambda x: x.cumprod(0))
 
+class TestTensorUOpCat(unittest.TestCase):
+  def test_cat_dim0(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=0))
+  def test_cat_dim1(self): _check(self, _t(2, 3), lambda x: x.cat(x, dim=1))
+  def test_cat_3tensors(self): _check(self, _t(2, 3), lambda x: x.cat(x, x, dim=0))
+  def test_cat_neg_dim(self): _check(self, _t(2, 3, 4), lambda x: x.cat(x, dim=-1))
+
+class TestTensorUOpStack(unittest.TestCase):
+  def test_stack_dim0(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=0))
+  def test_stack_dim1(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=1))
+  def test_stack_3tensors(self): _check(self, _t(2, 3), lambda x: x.stack(x, x, dim=0))
+  def test_stack_new_last(self): _check(self, _t(2, 3), lambda x: x.stack(x, dim=-1))
+
 if __name__ == "__main__":
   unittest.main()
diff --git a/tinygrad/mixin/__init__.py b/tinygrad/mixin/__init__.py
index efd645e4af..8aa9e38c0f 100644
--- a/tinygrad/mixin/__init__.py
+++ b/tinygrad/mixin/__init__.py
@@ -1,4 +1,5 @@
-import functools
+import functools, itertools
 from typing import Self, Sequence, Literal, get_args
+from tinygrad.helpers import argfix
 from tinygrad.mixin.elementwise import ElementwiseMixin
 from tinygrad.mixin.movement import MovementMixin
@@ -260,6 +261,42 @@ class OpMixin(ElementwiseMixin, ReduceMixin):
     m = self.max(axis=axis, keepdim=True)
     return (self - m).exp().sum(axis=axis, keepdim=keepdim).log() + (m if keepdim else m.squeeze(axis))
 
+  def cat(self, *args:Self, dim:int=0) -> Self:
+    """
+    Concatenates self with other tensors in `args` along an axis specified by `dim`.
+    All tensors must have the same shape except in the concatenating dimension.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
+    print(t0.cat(t1, t2, dim=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t0.cat(t1, t2, dim=1).numpy())
+    ```
+    """
+    dim = self._resolve_dim(dim)
+    for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
+    tensors = [self, *args]
+    dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
+    padded = [t.pad(tuple((dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim))) for i,t in enumerate(tensors)]
+    return padded[0].usum(*padded[1:])
+
+  def stack(self, *args:Self, dim:int=0) -> Self:
+    """
+    Concatenates self with other tensors in `args` along a new dimension specified by `dim`.
+
+    ```python exec="true" source="above" session="tensor" result="python"
+    t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
+    print(t0.stack(t1, t2, dim=0).numpy())
+    ```
+    ```python exec="true" source="above" session="tensor" result="python"
+    print(t0.stack(t1, t2, dim=1).numpy())
+    ```
+    """
+    # checks for shapes and number of dimensions delegated to cat
+    unsqueezed = [t.unsqueeze(dim) for t in argfix(self, *args)]
+    return unsqueezed[0].cat(*unsqueezed[1:], dim=dim)
+
   def _cumalu(self, axis:int, op:Ops) -> Self:
     assert self.shape[axis] != 0 and op in (Ops.ADD, Ops.MAX, Ops.MUL)
     pads = (None,)*(self.ndim-1) + ((self.shape[axis]-1, 0),)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 8961e98bd7..53ed081ef2 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1298,41 +1298,6 @@ class Tensor(OpMixin):
     x = self.shrink_to(tuple(i if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim)
     return (index.unsqueeze(-1)._one_hot_along_dim(self.shape[dim]).where(x, 0)).sum(-1, dtype=self.dtype)
 
-  def cat(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
-    """
-    Concatenates self with other `Tensor` in `args` along an axis specified by `dim`.
-    All tensors must have the same shape except in the concatenating dimension.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t0, t1, t2 = Tensor([[1, 2]]), Tensor([[3, 4]]), Tensor([[5, 6]])
-    print(t0.cat(t1, t2, dim=0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t0.cat(t1, t2, dim=1).numpy())
-    ```
-    """
-    dim = self._resolve_dim(dim)
-    for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
-    tensors = [self, *args]
-    dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
-    for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
-    return Tensor.usum(*tensors)
-
-  def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
-    """
-    Concatenates self with other `Tensor` in `args` along a new dimension specified by `dim`.
-
-    ```python exec="true" source="above" session="tensor" result="python"
-    t0, t1, t2 = Tensor([1, 2]), Tensor([3, 4]), Tensor([5, 6])
-    print(t0.stack(t1, t2, dim=0).numpy())
-    ```
-    ```python exec="true" source="above" session="tensor" result="python"
-    print(t0.stack(t1, t2, dim=1).numpy())
-    ```
-    """
-    # checks for shapes and number of dimensions delegated to cat
-    return Tensor.cat(*[t.unsqueeze(dim) for t in argfix(self, *args)], dim=dim)
-
   def masked_select(self, mask):
     """
     Selects elements from `self` based on the boolean `mask`.