From a67e0917c3710f761c461dc9a5be193163da1286 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 10 Aug 2025 17:02:38 -0700 Subject: [PATCH] list indexing can normalize in python (#11609) * list indexing can normalize in python list index does not need to be normalized in tensor * update those --- test/test_ops.py | 4 ++-- test/test_schedule.py | 16 ++++++++-------- tinygrad/tensor.py | 8 +++++--- 3 files changed, 15 insertions(+), 13 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 582876ccd8..df2feb9512 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2742,10 +2742,10 @@ class TestOps(unittest.TestCase): helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0,-1],b,c,d,:], lambda x: x[[1,0,-1],j,k,o,:]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...]) helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...]) - helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p]) + helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[-2,1,0],e], lambda x: x[i,[2,1,0],k,[-2,1,0],p]) def test_slice_fancy_indexing_tuple_indices(self): a,b,c,d,e,i,j,k,o,p = self._get_index_randoms() diff --git a/test/test_schedule.py b/test/test_schedule.py index 59aba3ee74..94d338b543 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -1741,15 +1741,15 @@ class TestIndexing(unittest.TestCase): def test_simple_indexing_alt(self): X = Tensor.arange(16).reshape(4, 4) - xt = X[[1, 2], [1, 2]] - self.check_schedule(xt, 3) - np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]]) + xt = X[[1, 2], [-1, 2]] + self.check_schedule(xt, 1) + np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]]) def test_advanced_indexing(self): X = Tensor.arange(10)+1 - xt = X[[0]] - self.check_schedule(xt, 2) - np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]]) + xt = X[[0, -1]] + self.check_schedule(xt, 1) + np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0, -1]]) def test_advanced_indexing_alt(self): X = Tensor.arange(6).reshape(3, 2)+1 @@ -1759,8 +1759,8 @@ class TestIndexing(unittest.TestCase): def test_advanced_simple_indexing_combined(self): X = Tensor.arange(16).reshape(4, 4) - xt = X[1:2, [1, 2]] - self.check_schedule(xt, 2) + xt = X[1:2, [-1, 2]] + self.check_schedule(xt, 1) def test_push_through_reshape(self): Tensor.manual_seed(0) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index aaf5c05a88..741bd30fa1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1140,10 +1140,12 @@ class Tensor(MathTrait): size = 1 if index is None else self.shape[dim] boundary, stride = [0, size], 1 # defaults match index: - case list() | tuple() | Tensor(): - if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False) + case Tensor(): if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported") - index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values + index = (index < 0).where(index+size, index).to(self.device) # treat negative index values + case list() | tuple(): + if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element") + index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device, requires_grad=False).reshape(ti.shape) case int() | UOp(): # sint if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}") boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]