Gather bugfix (#1561)

This commit is contained in:
geohotstan
2023-08-16 16:53:14 -07:00
committed by GitHub
parent cb62911f6b
commit a293c18d34
2 changed files with 2 additions and 1 deletions

View File

@@ -359,7 +359,7 @@ class Tensor:
def gather(self: Tensor, idx: Tensor, dim: int):
assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
assert all(s > i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
if dim < 0: dim += self.ndim
idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
permarg = list(range(self.ndim))