Small Tensor.cat optimization and reformating (#1347)
@@ -322,13 +322,14 @@ class Tensor:
|
||||
def cat(self, *args, dim=0):
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
||||
catargs = [self] + list(args)
|
||||
assert all(len(t.shape) != 0 for t in catargs), "zero-dimensional tensor cannot be concatenated"
|
||||
shape_cumsum = [0, *accumulate([y.shape[dim] for y in catargs])]
|
||||
slc = [[(0, s) for s in self.shape] for _ in catargs]
|
||||
for s,k in zip(slc, shape_cumsum):
|
||||
s[dim] = (-k, shape_cumsum[-1]-k)
|
||||
return reduce(Tensor.__add__, [arg.slice(s) for arg,s in zip(catargs, slc)])
|
||||
catargs = [self, *args]
|
||||
assert all(t.shape for t in catargs), "zero-dimensional tensor cannot be concatenated"
|
||||
shapes = [s.shape[dim] for s in catargs]
|
||||
shape_cumsum = [0, *accumulate(shapes)]
|
||||
slc = [[(0, 0) for _ in self.shape] for _ in catargs]
|
||||
for shp,k,s in zip(shapes, shape_cumsum[:-1], slc):
|
||||
s[dim] = (k, shape_cumsum[-1] - k - shp)
|
||||
return reduce(Tensor.__add__, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)])
|
||||
|
||||
@staticmethod
|
||||
def stack(tensors, dim=0):
|
||||
|
||||
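The rewrite swaps slice calls with negative start offsets for explicit pad calls, and hoists the per-tensor sizes along dim into a precomputed shapes list. The underlying trick is unchanged: zero-pad each input so its data lands at its own offset along the concatenation axis, then sum the padded tensors. Below is a minimal standalone sketch of that pad-and-sum idea; NumPy stands in for tinygrad's Tensor purely for illustration, and cat_via_pad is a hypothetical name, not part of either library.

# Hypothetical illustration of the pad-and-sum concatenation used above,
# written against NumPy rather than tinygrad.
from itertools import accumulate
from functools import reduce
import numpy as np

def cat_via_pad(*tensors, dim=0):
  shapes = [t.shape[dim] for t in tensors]
  cumsum = [0, *accumulate(shapes)]        # offsets: [0, s0, s0+s1, ...]
  total = cumsum[-1]                       # final size along dim
  padded = []
  for t, off, shp in zip(tensors, cumsum[:-1], shapes):
    pad = [(0, 0)] * t.ndim
    pad[dim] = (off, total - off - shp)    # zeros before and after the data
    padded.append(np.pad(t, pad))          # every padded tensor has the full shape
  return reduce(np.add, padded)            # disjoint supports, so the sum is a cat

a, b = np.ones((2, 3)), 2 * np.ones((4, 3))
assert np.array_equal(cat_via_pad(a, b, dim=0), np.concatenate([a, b], axis=0))

Because the padded regions of the inputs never overlap, adding the padded tensors reproduces concatenation exactly, which is why the diff can implement cat with pad and __add__ alone.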