mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
@@ -35,14 +35,13 @@ class TestSetitem(unittest.TestCase):
|
||||
for v in (5., 5, True):
|
||||
t = Tensor.ones(6,6, dtype=dt).contiguous()
|
||||
t[1] = v
|
||||
assert t.dtype == dt
|
||||
self.assertEqual(t.dtype, dt)
|
||||
|
||||
def test_setitem_into_noncontiguous(self):
|
||||
t = Tensor.ones(4)
|
||||
assert not t.lazydata.st.contiguous
|
||||
self.assertFalse(t.lazydata.st.contiguous)
|
||||
with self.assertRaises(RuntimeError): t[1] = 5
|
||||
|
||||
@unittest.skip("TODO: broken")
|
||||
def test_setitem_inplace_operator(self):
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] += 2
|
||||
@@ -74,7 +73,7 @@ class TestSetitem(unittest.TestCase):
|
||||
t = Tensor.arange(4).reshape(2, 2).contiguous()
|
||||
t[1] += 2
|
||||
t = t.contiguous()
|
||||
# TODO: RuntimeError: must be contiguous for assign ShapeTracker(views=(View(shape=(2,), strides=(1,), offset=2, mask=None, contiguous=False),))
|
||||
# TODO: RuntimeError: can't double realize in one schedule
|
||||
t[1] -= 1
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [3, 4]])
|
||||
|
||||
@@ -120,7 +119,7 @@ class TestSetitem(unittest.TestCase):
|
||||
def test_setitem_overlapping_inplace(self):
|
||||
t = Tensor([[3.0], [2.0], [1.0]]).contiguous()
|
||||
t[1:] = t[:-1]
|
||||
assert t.tolist() == [[3.0], [3.0], [2.0]]
|
||||
self.assertEqual(t.tolist(), [[3.0], [3.0], [2.0]])
|
||||
|
||||
class TestWithGrad(unittest.TestCase):
|
||||
def test_no_requires_grad_works(self):
|
||||
|
||||
Reference in New Issue
Block a user