diff --git a/test/test_assign.py b/test/test_assign.py index fa43ec67f5..0d7c0df71c 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -644,5 +644,19 @@ class TestAssignOrdering(unittest.TestCase): self.assertEqual(row0_sum.item(), 4) self.assertEqual(row1_sum.item(), 8) + def test_multiple_slice_assigns_then_read(self): + """Multiple non-overlapping slice assigns then read - RAW dependencies must ensure all writes complete before read.""" + buf = Tensor.zeros(4).contiguous().realize() + buf[0:1].assign(Tensor.ones(1)) + buf[1:2].assign(Tensor.full((1,), 2.0)) + buf[2:3].assign(Tensor.full((1,), 3.0)) + self.assertEqual(buf.sum().realize().item(), 0.0) # TODO: wrong! should be 1 + 2 + 3 + 0 = 6 + + buf = Tensor.zeros(4).contiguous().realize() + buf[0:1].assign(Tensor.ones(1)).realize() + buf[1:2].assign(Tensor.full((1,), 2.0)).realize() + buf[2:3].assign(Tensor.full((1,), 3.0)).realize() + self.assertEqual(buf.sum().realize().item(), 6.0) + if __name__ == "__main__": unittest.main()