From 0fc43c2e54584cedfdd96a5013dcaffe482cdc6c Mon Sep 17 00:00:00 2001 From: chenyu Date: Wed, 13 Aug 2025 16:50:16 -0700 Subject: [PATCH] fix test_const_tensor_index index (#11660) index should be ints --- test/test_const_folding.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index fbd219bfcb..becc9904b5 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -143,13 +143,12 @@ class TestIndexingConstFolding(unittest.TestCase): _check_ast_count(1, t[:,:,Tensor(1)+2,:]) _check_ast_count(1, t[:,:,Tensor(1),Tensor(0)]) - @unittest.expectedFailure def test_const_tensor_index(self): - # TODO: implement const tensor folded indexing + # TODO: these can be 0, implement const tensor folded indexing t = Tensor.arange(16).float().reshape(1,1,4,4).realize() - _check_ast_count(0, t[:,:,Tensor.ones(2,1),:]) - _check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:]) - _check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)]) + _check_ast_count(1, t[:,:,Tensor.ones(2,1,dtype=dtypes.int),:]) + _check_ast_count(1, t[:,:,Tensor.ones(1,2,dtype=dtypes.int)+2,:]) + _check_ast_count(1, t[:,:,Tensor.ones(1,1,dtype=dtypes.int),Tensor.zeros(2,1,2,dtype=dtypes.int)]) class TestMovedConstFolding(unittest.TestCase): def test_add_shrunk_zero(self):