diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index 3f971202b6..4f27c5b568 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -69,10 +69,31 @@ class TestSetitem(unittest.TestCase): t = Tensor.zeros(6, dtype=dtypes.float).contiguous().realize() with self.assertRaises(RuntimeError): t[2:4] = Tensor([1, 2], dtype=dtypes.int) - def test_setitem_into_noncontiguous(self): + def test_setitem_into_empty(self): + t = Tensor.empty(4) + t[1] = 5 + self.assertEqual(t[1].item(), 5) + + def test_setitem_into_cont(self): t = Tensor.ones(4) with self.assertRaises(RuntimeError): t[1] = 5 + def test_setitem_into_const_alu(self): + # TODO: this is not consistent + t = Tensor.ones(4) + Tensor.ones(4) + t[1] = 5 + self.assertListEqual(t.tolist(), [2, 5, 2, 2]) + + t = Tensor.ones(4) + Tensor.ones(4) + t.realize() + with self.assertRaises(RuntimeError): t[1] = 5 + + def test_setitem_into_arange(self): + # NOTE: arange has no real buffer, but assigning to it is fine + t = Tensor.arange(4) + t[1] = 5 + self.assertListEqual(t.tolist(), [0, 5, 2, 3]) + def test_setitem_chained_indexing(self): # N[i][j] must work the same as N[i, j] N1 = Tensor.zeros((3, 3)).contiguous().realize()