From 0ce4a55dad5c68a1e680d69edd06f443220bea00 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sat, 14 Feb 2026 14:29:16 -0500 Subject: [PATCH] clean up test_setitem_slice (#14750) moved to test_setitem_schedule, and use contiguous zeros as scheduler handles empty differently now --- test/null/test_uops_stats.py | 22 ---------------------- test/unit/test_setitem_schedule.py | 28 ++++++++++++++++++++++++++++ tinygrad/tensor.py | 2 +- 3 files changed, 29 insertions(+), 23 deletions(-) diff --git a/test/null/test_uops_stats.py b/test/null/test_uops_stats.py index f6580d8510..72e8b73d3b 100644 --- a/test/null/test_uops_stats.py +++ b/test/null/test_uops_stats.py @@ -68,28 +68,6 @@ class TestMemoryCount(unittest.TestCase): _, mem = get_stats(a.assign(a+a)) self.assertEqual(mem, 1024*1024*2) # 1 read + 1 write - def test_setitem_slice_const(self): - t = Tensor.empty(100, dtype=dtypes.int).realize() - GlobalCounters.reset() - t[20:50] = 3 - t.realize() - self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written - - def test_setitem_slice_tensor(self): - t = Tensor.empty(100, dtype=dtypes.int).realize() - v = Tensor.empty(30, dtype=dtypes.int).realize() - GlobalCounters.reset() - t[20:50] = v - t.realize() - self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written - - def test_setitem_full(self): - t = Tensor.empty(100, dtype=dtypes.int).realize() - GlobalCounters.reset() - t[:] = 3 - t.realize() - self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written - @unittest.skipIf(Device.DEFAULT == "CPU", "test copy to CPU from other device") def test_copyout(self): a = Tensor.empty(32, dtype=dtypes.uint8).to("CPU") diff --git a/test/unit/test_setitem_schedule.py b/test/unit/test_setitem_schedule.py index c583c316be..7bd8ad8765 100644 --- a/test/unit/test_setitem_schedule.py +++ b/test/unit/test_setitem_schedule.py @@ -105,5 +105,33 @@ class TestSetitemInto(unittest.TestCase): self.assertEqual(GlobalCounters.kernel_count, 2) self.assertListEqual(t.tolist(), [0, 5, 2, 3]) + def test_setitem_slice_const(self): + t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize() + GlobalCounters.reset() + t[20:50] = 3 + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 30*4) # 30 elements written + + def test_setitem_slice_tensor(self): + t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize() + v = Tensor.zeros(30, dtype=dtypes.int32).contiguous().realize() + GlobalCounters.reset() + t[20:50] = v + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 30*4*2) # 30 read + 30 written + + def test_setitem_full(self): + t = Tensor.zeros(100, dtype=dtypes.int32).contiguous().realize() + GlobalCounters.reset() + t[:] = 3 + self.assertEqual(GlobalCounters.kernel_count, 0) + t.realize() + self.assertEqual(GlobalCounters.kernel_count, 1) + self.assertEqual(GlobalCounters.global_mem, 100*4) # full buffer written + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index b433f335ea..c45300e0c9 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1306,7 +1306,7 @@ class Tensor(OpMixin): idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices) is_disk = isinstance(self.device, str) and self.device.startswith("DISK") if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem - if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors") + if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) self.assign(self._getitem(indices, v)) elif is_disk or self.uop.is_realized: # basic setitem, self is realized. TODO: disk uop.base is a COPY and not realized