mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
few more tensor indexing test cases (#11608)
This commit is contained in:
@@ -2768,10 +2768,14 @@ class TestOps(unittest.TestCase):
|
||||
a = Tensor.ones(10,11,12)
|
||||
# tensors used as indices must be int tensors
|
||||
with self.assertRaises(IndexError): a[Tensor(1.1)]
|
||||
with self.assertRaises(IndexError): a[Tensor([True, True])]
|
||||
with self.assertRaises(IndexError): a[[1.1]]
|
||||
with self.assertRaises(IndexError): a[Tensor([True, False])]
|
||||
with self.assertRaises(IndexError): a[[True, False]]
|
||||
# shape mismatch, cannot broadcast. either exception is okay
|
||||
with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1), Tensor.randint(2,4,4,1)]
|
||||
with self.assertRaises((IndexError, ValueError)): a[Tensor.randint(3,1,1,1), Tensor.randint(1,4,1,1,1)]
|
||||
# this is fine
|
||||
helper_test_op([(5, 6)], lambda x: x[[True, False, 2]])
|
||||
|
||||
def test_gather(self):
|
||||
# indices cannot have gradient
|
||||
|
||||
Reference in New Issue
Block a user