diff --git a/test/test_ops.py b/test/test_ops.py index b25bc0d00e..c346b4cdbd 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1155,6 +1155,7 @@ class TestOps(unittest.TestCase): helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=1), lambda x: x.gather(idx=a, dim=1)) helper_test_op([(4,5,6)], lambda x: x.gather(index=b, dim=2), lambda x: x.gather(idx=a, dim=2)) + helper_test_op([(3,4,5)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0)) self.helper_test_exception([(4,5,6)], lambda x: x.gather(index=torch.tensor([1], dtype=torch.int64), dim=0), lambda x: x.gather(idx=Tensor([1], dtype=dtypes.int32), dim=0), expected=(RuntimeError, AssertionError)) self.helper_test_exception([(2,1,1)], lambda x: x.gather(index=b, dim=0), lambda x: x.gather(idx=a, dim=0), expected=(RuntimeError, AssertionError)) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d51c70bd79..b26c354606 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -359,7 +359,7 @@ class Tensor: def gather(self: Tensor, idx: Tensor, dim: int): assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" - assert all(s > i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" + assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" if dim < 0: dim += self.ndim idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) permarg = list(range(self.ndim))