diff --git a/test/test_ops.py b/test/test_ops.py index 46f86364f5..a6b631d85d 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1857,6 +1857,9 @@ class TestOps(unittest.TestCase): helper_test_op([(4,5,6)], lambda x: x.gather(dim=1, index=b), lambda x: x.gather(dim=1, index=a)) helper_test_op([(4,5,6)], lambda x: x.gather(dim=2, index=b), lambda x: x.gather(dim=2, index=a)) helper_test_op([(3,4,5)], lambda x: x.gather(dim=0, index=b), lambda x: x.gather(dim=0, index=a)) + helper_test_op([(4,5,6)], lambda x: x.gather(dim=-1, index=b), lambda x: x.gather(dim=-1, index=a)) + helper_test_op([(4,5,6)], lambda x: x.gather(dim=-2, index=b), lambda x: x.gather(dim=-2, index=a)) + helper_test_op([(4,5,6)], lambda x: x.gather(dim=-3, index=b), lambda x: x.gather(dim=-3, index=a)) self.helper_test_exception([(4,5,6)], lambda x: x.gather(dim=0, index=torch.tensor([1], dtype=torch.int64)), lambda x: x.gather(dim=0, index=Tensor([1], dtype=dtypes.int32)), expected=(RuntimeError, AssertionError)) self.helper_test_exception([(2,1,1)], lambda x: x.gather(dim=0, index=b), diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 0155a8f303..0c29a3eb70 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1041,8 +1041,8 @@ class Tensor: ``` """ assert index.ndim == self.ndim, f"self.ndim must equal index.ndim, {self.ndim=}, {index.ndim=}" - assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" dim = self._resolve_dim(dim) + assert all(s >= i for d,(s,i) in enumerate(zip(self.shape, index.shape)) if d != dim), "requires self.shape[d] >= index.shape[d] for all d != dim" index = index.to(self.device) x = self.shrink(tuple((0, i) if d != dim else None for d,i in enumerate(index.shape))).unsqueeze(-1).transpose(-1, dim) return ((index.unsqueeze(-1) == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * x).sum(-1, acc_dtype=self.dtype)