lazy setitem for realized target (#14735)

This commit is contained in:
chenyu
2026-02-13 12:20:14 -05:00
committed by GitHub
parent 3bee6638e3
commit 8b205a007e
3 changed files with 9 additions and 7 deletions

View File

@@ -1018,7 +1018,8 @@ class TestSchedule(unittest.TestCase):
a = Tensor.arange(16).contiguous().realize()
GlobalCounters.reset()
a[4] = 3
# TODO: update when this becomes lazy
self.assertEqual(GlobalCounters.kernel_count, 0)
a.realize()
self.assertEqual(GlobalCounters.kernel_count, 1)
self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])

View File

@@ -162,6 +162,8 @@ class TestSetitem(unittest.TestCase):
# Helper wrapped with TinyJit: writes `a` into the t[2:4, 3:5] window of `t` and returns `t`.
@TinyJit
def f(t:Tensor, a:Tensor):
t[2:4, 3:5] = a
# NOTE: without `return t` or an explicit realize, the setitem above stays lazy and is not captured
return t
for i in range(1, 6):
t = Tensor.zeros(6, 6).contiguous().realize()

View File

@@ -1304,15 +1304,14 @@ class Tensor(OpMixin):
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported")
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors")
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
self.assign(self._getitem(indices, v))
else: # basic setitem
if is_disk: self[indices].assign(v)
else:
self[indices].assign(v).realize()
elif self.uop.is_realized: # basic setitem, self is realized
self[indices].assign(v)
else: # basic setitem, self is not realized
self[indices].assign(v).realize()
# Item deletion (`del tensor[indices]`) is rejected unconditionally: this always raises TypeError,
# whatever `indices` is.
def __delitem__(self, indices) -> None:
raise TypeError("Tensor does not support deleting items")