mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 02:21:40 -05:00
Gather bugfix (#1561)
This commit is contained in:
@@ -359,7 +359,7 @@ class Tensor:
|
||||
|
||||
def gather(self: Tensor, idx: Tensor, dim: int):
|
||||
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
|
||||
assert all(s > i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
|
||||
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
|
||||
if dim < 0: dim += self.ndim
|
||||
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
|
||||
permarg = list(range(self.ndim))
|
||||
|
||||
Reference in New Issue
Block a user