diff --git a/test/test_setitem.py b/test/test_setitem.py index 934c726e77..b70cca8784 100644 --- a/test/test_setitem.py +++ b/test/test_setitem.py @@ -35,14 +35,13 @@ class TestSetitem(unittest.TestCase): for v in (5., 5, True): t = Tensor.ones(6,6, dtype=dt).contiguous() t[1] = v - assert t.dtype == dt + self.assertEqual(t.dtype, dt) def test_setitem_into_noncontiguous(self): t = Tensor.ones(4) - assert not t.lazydata.st.contiguous + self.assertFalse(t.lazydata.st.contiguous) with self.assertRaises(RuntimeError): t[1] = 5 - @unittest.skip("TODO: broken") def test_setitem_inplace_operator(self): t = Tensor.arange(4).reshape(2, 2).contiguous() t[1] += 2 @@ -74,7 +73,7 @@ class TestSetitem(unittest.TestCase): t = Tensor.arange(4).reshape(2, 2).contiguous() t[1] += 2 t = t.contiguous() - # TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),)) + # TODO: RuntimeError: can't double realize in one schedule t[1] -= 1 np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]]) @@ -120,7 +119,7 @@ class TestSetitem(unittest.TestCase): def test_setitem_overlapping_inplace(self): t = Tensor([[3.0], [2.0], [1.0]]).contiguous() t[1:] = t[:-1] - assert t.tolist() == [[3.0], [3.0], [2.0]] + self.assertEqual(t.tolist(), [[3.0], [3.0], [2.0]]) class TestWithGrad(unittest.TestCase): def test_no_requires_grad_works(self):