From 787998fac3b27525f82853cad29fcdb006dbb6cd Mon Sep 17 00:00:00 2001 From: chenyu Date: Thu, 12 Feb 2026 16:04:37 -0500 Subject: [PATCH] fix getitem tensor indexing detection (#14712) issue with sint --- test/unit/test_indexing.py | 7 ++++++- tinygrad/tensor.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/test/unit/test_indexing.py b/test/unit/test_indexing.py index ddc2c7fa0e..c95e69193f 100644 --- a/test/unit/test_indexing.py +++ b/test/unit/test_indexing.py @@ -3,7 +3,7 @@ import unittest, random, warnings import numpy as np -from tinygrad import Tensor, dtypes, Device, TinyJit +from tinygrad import Tensor, dtypes, Device, TinyJit, Variable from tinygrad.helpers import all_same, prod from test.helpers import slow @@ -647,6 +647,11 @@ class TestIndexing(unittest.TestCase): i, j = indices numpy_testing_assert_equal_helper(x[i:j], x[0:1]) + def test_variable_with_tensor_index(self): + t = Tensor.arange(12).reshape(3, 4) + v = Variable("v", 0, 2).bind(1) + numpy_testing_assert_equal_helper(t[v, Tensor([0, 1, 2])], t[1, Tensor([0, 1, 2])]) + def test_ellipsis_tensor(self): x = Tensor.arange(0, 9).reshape(3, 3) idx = Tensor([0, 2]) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e15152e89d..dd85e88f04 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1215,7 +1215,7 @@ class Tensor(OpMixin): x = x.reshape(tuple(index['size'] for index in indices_parsed if not isinstance(index['index'], sint))) # tensor indexing - if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], int)) if isinstance(i['index'], Tensor)]: + if tops := [(d,i) for d,i in enumerate(i_ for i_ in indices_parsed if not isinstance(i_['index'], sint)) if isinstance(i['index'], Tensor)]: # unload the tensor object into actual tensors dims, tensors, masks = [d for d,_ in tops], cast(list[Tensor], [i['index'] for _,i in tops]), [] big_shape = _broadcast_shape(*(t.shape for t in tensors))