fix test_const_tensor_index index (#11660)

index should be ints
This commit is contained in:
chenyu
2025-08-13 16:50:16 -07:00
committed by GitHub
parent 4fe19eec72
commit 0fc43c2e54

View File

@@ -143,13 +143,12 @@ class TestIndexingConstFolding(unittest.TestCase):
_check_ast_count(1, t[:,:,Tensor(1)+2,:])
_check_ast_count(1, t[:,:,Tensor(1),Tensor(0)])
@unittest.expectedFailure
def test_const_tensor_index(self):
# TODO: implement const tensor folded indexing
# TODO: these can be 0, implement const tensor folded indexing
t = Tensor.arange(16).float().reshape(1,1,4,4).realize()
_check_ast_count(0, t[:,:,Tensor.ones(2,1),:])
_check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:])
_check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)])
_check_ast_count(1, t[:,:,Tensor.ones(2,1,dtype=dtypes.int),:])
_check_ast_count(1, t[:,:,Tensor.ones(1,2,dtype=dtypes.int)+2,:])
_check_ast_count(1, t[:,:,Tensor.ones(1,1,dtype=dtypes.int),Tensor.zeros(2,1,2,dtype=dtypes.int)])
class TestMovedConstFolding(unittest.TestCase):
def test_add_shrunk_zero(self):