few more tensor indexing test cases (#11608)

This commit is contained in:
chenyu
2025-08-10 15:56:42 -07:00
committed by GitHub
parent 996c907c0b
commit 1181ec0cd2

View File

@@ -2768,10 +2768,14 @@ class TestOps(unittest.TestCase):
a = Tensor.ones(10,11,12)
# tensors used as indices must be int tensors
with self.assertRaises(IndexError): a[Tensor(1.1)]
with self.assertRaises(IndexError): a[Tensor([True, True])]
with self.assertRaises(IndexError): a[[1.1]]
with self.assertRaises(IndexError): a[Tensor([True, False])]
with self.assertRaises(IndexError): a[[True, False]]
# shape mismatch, cannot broadcast. either exception is okay
with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)]
with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)]
# this is fine
helper_test_op([(5, 6)], lambda x: x[[True, False, 2]])
def test_gather(self):
# indices cannot have gradient