clean up test_setitem_slice (#14750)

moved to test_setitem_schedule, and use contiguous zeros as scheduler handles empty differently now
This commit is contained in:
chenyu
2026-02-14 14:29:16 -05:00
committed by GitHub
parent 8f6772fd8c
commit 0ce4a55dad
3 changed files with 29 additions and 23 deletions

View File

@@ -68,28 +68,6 @@ class TestMemoryCount(unittest.TestCase):
_, mem = get_stats(a.assign(a+a))
self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write
def test_setitem_slice_const(self):
t = Tensor.empty(100, dtype=dtypes.int).realize()
GlobalCounters.reset()
t[20:50] = 3
t.realize()
self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written
def test_setitem_slice_tensor(self):
t = Tensor.empty(100, dtype=dtypes.int).realize()
v = Tensor.empty(30, dtype=dtypes.int).realize()
GlobalCounters.reset()
t[20:50] = v
t.realize()
self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written
def test_setitem_full(self):
t = Tensor.empty(100, dtype=dtypes.int).realize()
GlobalCounters.reset()
t[:] = 3
t.realize()
self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written
@unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device")
def test_copyout(self):
a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU")

View File

@@ -105,5 +105,33 @@ class TestSetitemInto(unittest.TestCase):
self.assertEqual(GlobalCounters.kernel_count, 2)
self.assertListEqual(t.tolist(), [0, 5, 2, 3])
def test_setitem_slice_const(self):
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
GlobalCounters.reset()
t[20:50] = 3
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written
def test_setitem_slice_tensor(self):
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
v = Tensor.zeros(30, dtype=dtypes.int32).contiguous().realize()
GlobalCounters.reset()
t[20:50] = v
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written
def test_setitem_full(self):
t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize()
GlobalCounters.reset()
t[:] = 3
self.assertEqual(GlobalCounters.kernel_count, 0)
t.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written
if __name__ == '__main__':
unittest.main()

View File

@@ -1306,7 +1306,7 @@ class Tensor(OpMixin):
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors")
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
self.assign(self._getitem(indices, v))
elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized