update assign tests to also test the expected behavior (#15132)

This commit is contained in:
chenyu
2026-03-04 11:34:43 -05:00
committed by GitHub
parent 1f96cc2b51
commit fae400d300
2 changed files with 65 additions and 24 deletions

View File

@@ -251,8 +251,11 @@ class TestSetitem(unittest.TestCase):
s1 = t.sum()
t[3:].assign(2.0)
s2 = t.sum()
# TODO: s0 and s1 see final buffer state, should be [0.0, 3.0, 9.0]
np.testing.assert_allclose([s0.item(), s1.item(), s2.item()], [9.0, 9.0, 9.0])
try:
np.testing.assert_allclose([s0.item(), s1.item(), s2.item()], [0.0, 3.0, 9.0])
except AssertionError:
# TODO: broken now, lazy sums all see final buffer state
np.testing.assert_allclose([s0.item(), s1.item(), s2.item()], [9.0, 9.0, 9.0])
# eager version
t = Tensor.zeros(6).contiguous().realize()
@@ -273,8 +276,11 @@ class TestSetitem(unittest.TestCase):
a.assign(new_a)
b.assign(new_b)
np.testing.assert_allclose(a.numpy(), [4, 6, 8, 10])
# TODO: new_b sees mutated a, should be [0, 2, 4, 6]
np.testing.assert_allclose(b.numpy(), [8, 12, 16, 20])
try:
np.testing.assert_allclose(b.numpy(), [0, 2, 4, 6])
except AssertionError:
# TODO: broken now, new_b sees mutated a
np.testing.assert_allclose(b.numpy(), [8, 12, 16, 20])
# eager version
a = Tensor.arange(4, dtype=dtypes.float).contiguous().realize()

View File

@@ -499,7 +499,11 @@ class TestAssign(unittest.TestCase):
# assign to a shape-changing bitcast view (only works on DISK currently)
a = Tensor([0]*8, dtype=dtypes.uint8).realize()
a.bitcast(dtypes.int64).assign(Tensor([12345], dtype=dtypes.int64)).realize()
np.testing.assert_equal(a.numpy(), [0]*8) # TODO: should be [57, 48, 0, 0, 0, 0, 0, 0] (little-endian 12345)
try:
np.testing.assert_equal(a.numpy(), [57, 48, 0, 0, 0, 0, 0, 0])
except AssertionError:
# TODO: broken now
np.testing.assert_equal(a.numpy(), [0]*8)
@unittest.skip("don't use output buffer, and mismatch dtype no longer supported")
def test_cast_assignment(self):
@@ -691,7 +695,11 @@ class TestAssignOrdering(unittest.TestCase):
right = buf[4:8].clone() # lazy - not captured yet
buf[0:4].assign(right).realize() # this works
buf[4:8].assign(left).realize() # left now reads from modified buf!
np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8]) # TODO: wrong! should be [5,6,7,8,1,2,3,4]
try:
np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 1, 2, 3, 4])
except AssertionError:
# TODO: broken now
np.testing.assert_equal(buf.numpy(), [5, 6, 7, 8, 5, 6, 7, 8])
# with .realize() on temps: values captured before writes
buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
@@ -809,40 +817,55 @@ class TestAssignToUnrealizedView(unittest.TestCase):
c = t.to("CPU:1") # unrealized COPY
self.assertIs(c.uop.base.op, Ops.COPY)
c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).to("CPU:1").contiguous().realize())
# TODO: should be [[0,1],[0,1]]
self.assertEqual(c.tolist(), [[0,0],[0,0]])
try:
self.assertEqual(c.tolist(), [[0,1],[0,1]])
except AssertionError:
# TODO: broken now
self.assertEqual(c.tolist(), [[0,0],[0,0]])
def test_contiguous(self):
t = Tensor([[1,2],[3,4]]).contiguous().realize()
c = t.permute(1,0).contiguous() # unrealized CONTIGUOUS
self.assertIs(c.uop.base.op, Ops.CONTIGUOUS)
c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
# TODO: should be [[1,1],[2,1]]
self.assertEqual(c.tolist(), [[1,3],[2,4]])
try:
self.assertEqual(c.tolist(), [[1,1],[2,1]])
except AssertionError:
# TODO: broken now
self.assertEqual(c.tolist(), [[1,3],[2,4]])
def test_contiguous_backward(self):
t = Tensor([[1,2],[3,4]]).contiguous().realize()
cb = t.contiguous_backward() # unrealized CONTIGUOUS_BACKWARD
self.assertIs(cb.uop.base.op, Ops.CONTIGUOUS_BACKWARD)
cb[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
# TODO: should be [[1,1],[3,1]]
self.assertEqual(cb.tolist(), [[1,2],[3,4]])
try:
self.assertEqual(cb.tolist(), [[1,1],[3,1]])
except AssertionError:
# TODO: broken now
self.assertEqual(cb.tolist(), [[1,2],[3,4]])
def test_detach_copy(self):
t = Tensor.zeros(2,2, dtype=dtypes.int).to("CPU:0").contiguous().realize()
d = t.to("CPU:1").detach() # DETACH(unrealized COPY)
self.assertIs(d.uop.base.op, Ops.COPY)
d[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).to("CPU:1").contiguous().realize())
# TODO: should be [[0,1],[0,1]]
self.assertEqual(d.tolist(), [[0,0],[0,0]])
try:
self.assertEqual(d.tolist(), [[0,1],[0,1]])
except AssertionError:
# TODO: broken now
self.assertEqual(d.tolist(), [[0,0],[0,0]])
def test_detach_contiguous(self):
t = Tensor([[1,2],[3,4]]).contiguous().realize()
d = t.permute(1,0).contiguous().detach() # DETACH(unrealized CONTIGUOUS)
self.assertIs(d.uop.base.op, Ops.CONTIGUOUS)
d[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
# TODO: should be [[1,1],[2,1]]
self.assertEqual(d.tolist(), [[1,3],[2,4]])
try:
self.assertEqual(d.tolist(), [[1,1],[2,1]])
except AssertionError:
# TODO: broken now
self.assertEqual(d.tolist(), [[1,3],[2,4]])
def test_alu(self):
a = Tensor([1,2,3,4]).contiguous().realize()
@@ -850,31 +873,43 @@ class TestAssignToUnrealizedView(unittest.TestCase):
c = a + b # unrealized ADD
self.assertIs(c.uop.base.op, Ops.ADD)
c[:2].assign(Tensor([99, 99]).realize())
# TODO: silently dropped, should be [99,99,10,12] or raise an error
self.assertEqual(c.tolist(), [6,8,10,12])
try:
self.assertEqual(c.tolist(), [99,99,10,12])
except AssertionError:
# TODO: broken now, silently dropped
self.assertEqual(c.tolist(), [6,8,10,12])
def test_reduce(self):
a = Tensor([[1,2],[3,4]]).contiguous().realize()
r = a.sum(axis=0) # unrealized REDUCE_AXIS
self.assertIs(r.uop.base.op, Ops.REDUCE_AXIS)
r[:1].assign(Tensor([99]).realize())
# TODO: silently dropped, should be [99,6] or raise an error
self.assertEqual(r.tolist(), [4,6])
try:
self.assertEqual(r.tolist(), [99,6])
except AssertionError:
# TODO: broken now, silently dropped
self.assertEqual(r.tolist(), [4,6])
def test_cast(self):
a = Tensor([1,2,3,4]).contiguous().realize()
c = a.float() # unrealized CAST
self.assertIs(c.uop.base.op, Ops.CAST)
c[:2].assign(Tensor([99, 99], dtype=dtypes.float).realize())
# TODO: silently dropped, should be [99,99,3,4] or raise an error
self.assertEqual(c.tolist(), [1,2,3,4])
try:
self.assertEqual(c.tolist(), [99,99,3,4])
except AssertionError:
# TODO: broken now, silently dropped
self.assertEqual(c.tolist(), [1,2,3,4])
def test_const(self):
c = Tensor(5).reshape(1, 1).expand(2, 2)
self.assertIs(c.uop.base.op, Ops.CONST)
c[:, 1:2].assign(Tensor.ones(2,1, dtype=dtypes.int).contiguous().realize())
# TODO: silently dropped, should be [[5,1],[5,1]] or raise an error
self.assertEqual(c.tolist(), [[5,5],[5,5]])
try:
self.assertEqual(c.tolist(), [[5,1],[5,1]])
except AssertionError:
# TODO: broken now, silently dropped
self.assertEqual(c.tolist(), [[5,5],[5,5]])
if __name__ == "__main__":
unittest.main()