update test_setitem (#7493)

some tests passed now
This commit is contained in:
chenyu
2024-11-02 17:53:04 -04:00
committed by GitHub
parent 49ae2df036
commit f887de0fd6

View File

@@ -35,14 +35,13 @@ class TestSetitem(unittest.TestCase):
for v in (5., 5, True):
t = Tensor.ones(6,6, dtype=dt).contiguous()
t[1] = v
assert t.dtype == dt
self.assertEqual(t.dtype, dt)
def test_setitem_into_noncontiguous(self):
t = Tensor.ones(4)
assert not t.lazydata.st.contiguous
self.assertFalse(t.lazydata.st.contiguous)
with self.assertRaises(RuntimeError): t[1] = 5
@unittest.skip("TODO: broken")
def test_setitem_inplace_operator(self):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
@@ -74,7 +73,7 @@ class TestSetitem(unittest.TestCase):
t = Tensor.arange(4).reshape(2, 2).contiguous()
t[1] += 2
t = t.contiguous()
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
# TODO: RuntimeError: can't double realize in one schedule
t[1] -= 1
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
@@ -120,7 +119,7 @@ class TestSetitem(unittest.TestCase):
def test_setitem_overlapping_inplace(self):
t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
t[1:] = t[:-1]
assert t.tolist() == [[3.0], [3.0], [2.0]]
self.assertEqual(t.tolist(), [[3.0], [3.0], [2.0]])
class TestWithGrad(unittest.TestCase):
def test_no_requires_grad_works(self):