mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
Gather (#1329)
This commit is contained in:
@@ -312,6 +312,13 @@ class Tensor:
|
||||
final_shape.append(1)
|
||||
return sliced_tensor.reshape(tuple(final_shape)) # Reshape
|
||||
|
||||
def gather(self, idx, dim):
|
||||
idx = (idx < 0).where(idx+self.shape[dim], idx) # Turn neg idx pos
|
||||
new_self = self.reshape(*self.shape[:dim+1], *[1]*idx.ndim, *self.shape[dim+1:])
|
||||
arange = Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim, self.shape[dim], *[1]*idx.ndim, *[1]*(self.ndim-dim-1))
|
||||
new_idx = idx.reshape(*[1]*dim, 1, *idx.shape, *[1]*(self.ndim-dim-1))
|
||||
return (new_self * (arange == new_idx)).sum(dim)
|
||||
|
||||
def cat(self, *args, dim=0):
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)
|
||||
|
||||
Reference in New Issue
Block a user