From 9cfc4f68c82cd8fa7b405c7e7dd8ddcd9d2c618c Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 14 Nov 2024 13:46:02 -0500 Subject: [PATCH] clean up Tensor.cat (#7701) --- tinygrad/tensor.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 9ec0957755..99f484f269 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1249,13 +1249,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ dim = self._resolve_dim(dim) - assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args) - catargs = [self, *args] - cat_dims = [s.shape[dim] for s in catargs] - cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)] - slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs] - for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d) - return functools.reduce(Tensor.add, [arg.pad(s) for arg,s in zip(catargs, slc)]) + for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim) + tensors = [self, *args] + dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0)) + for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)]) + return functools.reduce(Tensor.add, tensors) def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ @@ -1270,7 +1268,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ # checks for shapes and number of dimensions delegated to cat - return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim) + return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim) def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor: """