Small Tensor.cat optimization and reformatting (#1347)

This commit is contained in:
Umut Zengin
2023-07-27 01:01:12 +03:00
committed by GitHub
parent 4056f97187
commit d4ebadf2da
3 changed files with 14 additions and 9 deletions

View File

@@ -322,13 +322,14 @@ class Tensor:
def cat(self, *args, dim=0):
  """Concatenate `self` with `*args` along dimension `dim`.

  All tensors must have the same rank and matching sizes on every
  dimension except `dim`. Implemented by padding each tensor with
  zeros so it occupies its own span of the output along `dim`, then
  summing the padded tensors elementwise.

  Args:
    *args: tensors to concatenate after `self`.
    dim: dimension to concatenate along; negative values count from the end.

  Returns:
    A new tensor whose size along `dim` is the sum of the inputs' sizes there.
  """
  # Normalize a negative dim to its positive equivalent.
  dim = (dim + len(self.shape)) if dim < 0 else dim
  assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
  catargs = [self, *args]
  # An empty shape tuple is falsy, so this rejects 0-d tensors.
  assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
  shapes = [s.shape[dim] for s in catargs]
  # Running offsets of each tensor's start position along dim: [0, s0, s0+s1, ...].
  shape_cumsum = [0, *accumulate(shapes)]
  # One (pad_before, pad_after) pair per dimension per tensor; only dim gets nonzero padding.
  slc = [[(0, 0) for _ in self.shape] for _ in catargs]
  for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
    # Pad k zeros before and (total - k - own_size) zeros after along dim,
    # so each tensor lands in its own disjoint slot of the output.
    s[dim] = (k, shape_cumsum[-1] - k - shp)
  # The padded tensors overlap nowhere, so elementwise sum reassembles the concatenation.
  return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
@staticmethod
def stack(tensors, dim=0):