fix getitem tensor indexing detection (#14712)

issue with sint
This commit is contained in:
chenyu
2026-02-12 16:04:37 -05:00
committed by GitHub
parent 86352988d8
commit 787998fac3
2 changed files with 7 additions and 2 deletions

View File

@@ -3,7 +3,7 @@
import unittest, random, warnings
import numpy as np
from tinygrad import Tensor, dtypes, Device, TinyJit
from tinygrad import Tensor, dtypes, Device, TinyJit, Variable
from tinygrad.helpers import all_same, prod
from test.helpers import slow
@@ -647,6 +647,11 @@ class TestIndexing(unittest.TestCase):
i, j = indices
numpy_testing_assert_equal_helper(x[i:j], x[0:1])
def test_variable_with_tensor_index(self):
t = Tensor.arange(12).reshape(3, 4)
v = Variable("v", 0, 2).bind(1)
numpy_testing_assert_equal_helper(t[v, Tensor([0, 1, 2])], t[1, Tensor([0, 1, 2])])
def test_ellipsis_tensor(self):
x = Tensor.arange(0, 9).reshape(3, 3)
idx = Tensor([0, 2])

View File

@@ -1215,7 +1215,7 @@ class Tensor(OpMixin):
x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], sint)))
# tensor indexing
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]:
if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], sint)) if isinstance(i['index'], Tensor)]:
# unload the tensor object into actual tensors
dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), []
big_shape = _broadcast_shape(*(t.shape for t in tensors))