This commit is contained in:
geohotstan
2023-07-25 12:05:41 -07:00
committed by GitHub
parent 9d142430cb
commit 4056f97187
3 changed files with 25 additions and 1 deletions

View File

@@ -312,6 +312,13 @@ class Tensor:
final_shape.append(1)
return sliced_tensor.reshape(tuple(final_shape)) # Reshape
def gather(self, idx, dim):
idx = (idx < 0).where(idx+self.shape[dim], idx) # Turn neg idx pos
new_self = self.reshape(*self.shape[:dim+1], *[1]*idx.ndim, *self.shape[dim+1:])
arange = Tensor.arange(self.shape[dim], dtype=dtypes.int32, requires_grad=False).reshape(*[1]*dim, self.shape[dim], *[1]*idx.ndim, *[1]*(self.ndim-dim-1))
new_idx = idx.reshape(*[1]*dim, 1, *idx.shape, *[1]*(self.ndim-dim-1))
return (new_self * (arange == new_idx)).sum(dim)
def cat(self, *args, dim=0):
dim = (dim + len(self.shape)) if dim < 0 else dim
assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args)