clean up Tensor.cat (#7701)

This commit is contained in:
chenyu
2024-11-14 13:46:02 -05:00
committed by GitHub
parent 888fcb3643
commit 9cfc4f68c8

View File

@@ -1249,13 +1249,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
```
"""
dim = self._resolve_dim(dim)
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
catargs = [self, *args]
cat_dims = [s.shape[dim] for s in catargs]
cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)]
slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs]
for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d)
return functools.reduce(Tensor.add, [arg.pad(s) for arg,s in zip(catargs, slc)])
for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim)
tensors = [self, *args]
dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0))
for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)])
return functools.reduce(Tensor.add, tensors)
def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor:
"""
@@ -1270,7 +1268,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method
```
"""
# checks for shapes and number of dimensions delegated to cat
return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim)
return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim)
def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor:
"""