list indexing can normalize in python (#11609)

* list indexing can normalize in python

list index does not need to be normalized in tensor

* update those
This commit is contained in:
chenyu
2025-08-10 17:02:38 -07:00
committed by GitHub
parent 1181ec0cd2
commit a67e0917c3
3 changed files with 15 additions and 13 deletions

View File

@@ -2742,10 +2742,10 @@ class TestOps(unittest.TestCase):
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[0]]], lambda x: x[[[0]]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[0],b,c,d,:], lambda x: x[[0],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[[[0]]],b,c,d,[[1]]], lambda x: x[[[[0]]],j,k,o,[[1]]])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0],b,c,d,:], lambda x: x[[1,0],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[[1,0,-1],b,c,d,:], lambda x: x[[1,0,-1],j,k,o,:])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[1,2,3],...], lambda x: x[i,j,k,[1,2,3],...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,b,c,[[1],[2],[3]],...], lambda x: x[i,j,k,[[1],[2],[3]],...])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[2,1,0],e], lambda x: x[i,[2,1,0],k,[2,1,0],p])
helper_test_op([(2,5,6,5,3,4)], lambda x: x[a,[2,1,0],c,[-2,1,0],e], lambda x: x[i,[2,1,0],k,[-2,1,0],p])
def test_slice_fancy_indexing_tuple_indices(self):
a,b,c,d,e,i,j,k,o,p = self._get_index_randoms()

View File

@@ -1741,15 +1741,15 @@ class TestIndexing(unittest.TestCase):
def test_simple_indexing_alt(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[[1, 2], [1, 2]]
self.check_schedule(xt, 3)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [1, 2]])
xt = X[[1, 2], [-1, 2]]
self.check_schedule(xt, 1)
np.testing.assert_equal(xt.numpy(), (np.arange(16).reshape(4, 4))[[1, 2], [-1, 2]])
def test_advanced_indexing(self):
X = Tensor.arange(10)+1
xt = X[[0]]
self.check_schedule(xt, 2)
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0]])
xt = X[[0, -1]]
self.check_schedule(xt, 1)
np.testing.assert_equal(xt.numpy(), (np.arange(10)+1)[[0, -1]])
def test_advanced_indexing_alt(self):
X = Tensor.arange(6).reshape(3, 2)+1
@@ -1759,8 +1759,8 @@ class TestIndexing(unittest.TestCase):
def test_advanced_simple_indexing_combined(self):
X = Tensor.arange(16).reshape(4, 4)
xt = X[1:2, [1, 2]]
self.check_schedule(xt, 2)
xt = X[1:2, [-1, 2]]
self.check_schedule(xt, 1)
def test_push_through_reshape(self):
Tensor.manual_seed(0)

View File

@@ -1140,10 +1140,12 @@ class Tensor(MathTrait):
size = 1 if index is None else self.shape[dim]
boundary, stride = [0, size], 1 # defaults
match index:
case list() | tuple() | Tensor():
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
case Tensor():
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
index = (index < 0).where(index+size, index).to(self.device) # treat negative index values
case list() | tuple():
if not dtypes.is_int((ti:=Tensor(index)).dtype): raise IndexError(f"{index=} contains non-int element")
index = Tensor([i+size if i<0 else i for i in fully_flatten(index)], self.device, requires_grad=False).reshape(ti.shape)
case int() | UOp(): # sint
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]