diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py index cf91290a9c..e4819bfe13 100644 --- a/test/backend/test_schedule.py +++ b/test/backend/test_schedule.py @@ -1018,7 +1018,8 @@ class TestSchedule(unittest.TestCase): a = Tensor.arange(16).contiguous().realize() GlobalCounters.reset() a[4] = 3 - # TODO: update when this becomes lazy + self.assertEqual(GlobalCounters.kernel_count, 0) + a.realize() self.assertEqual(GlobalCounters.kernel_count, 1) self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]) diff --git a/test/backend/test_setitem.py b/test/backend/test_setitem.py index 649757c90a..3f971202b6 100644 --- a/test/backend/test_setitem.py +++ b/test/backend/test_setitem.py @@ -162,6 +162,8 @@ class TestSetitem(unittest.TestCase): @TinyJit def f(t:Tensor, a:Tensor): t[2:4, 3:5] = a + # NOTE: without return t or an explicit realize, it's lazy and not captured + return t for i in range(1, 6): t = Tensor.zeros(6, 6).contiguous().realize() diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index e69db10c36..e529240135 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1304,15 +1304,14 @@ class Tensor(OpMixin): if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}") if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported") idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices) - is_disk = isinstance(self.device, str) and self.device.startswith("DISK") if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem - if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors") + if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors") if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype) self.assign(self._getitem(indices, v)) - else: # basic setitem - if is_disk: self[indices].assign(v) - else: - self[indices].assign(v).realize() + elif self.uop.is_realized: # basic setitem, self is realized + self[indices].assign(v) + else: # basic setitem, self is not realized + self[indices].assign(v).realize() def __delitem__(self, indices) -> None: raise TypeError("Tensor does not support deleting items")