diff --git a/test/test_const_folding.py b/test/test_const_folding.py index fbd219bfcb..becc9904b5 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -143,13 +143,12 @@ class TestIndexingConstFolding(unittest.TestCase): _check_ast_count(1, t[:,:,Tensor(1)+2,:]) _check_ast_count(1, t[:,:,Tensor(1),Tensor(0)]) - @unittest.expectedFailure def test_const_tensor_index(self): - # TODO: implement const tensor folded indexing + # TODO: these can be 0, implement const tensor folded indexing t = Tensor.arange(16).float().reshape(1,1,4,4).realize() - _check_ast_count(0, t[:,:,Tensor.ones(2,1),:]) - _check_ast_count(0, t[:,:,Tensor.ones(1,2)+2,:]) - _check_ast_count(0, t[:,:,Tensor.ones(1,1),Tensor.zeros(2,1,2)]) + _check_ast_count(1, t[:,:,Tensor.ones(2,1,dtype=dtypes.int),:]) + _check_ast_count(1, t[:,:,Tensor.ones(1,2,dtype=dtypes.int)+2,:]) + _check_ast_count(1, t[:,:,Tensor.ones(1,1,dtype=dtypes.int),Tensor.zeros(2,1,2,dtype=dtypes.int)]) class TestMovedConstFolding(unittest.TestCase): def test_add_shrunk_zero(self):