mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
lazy setitem for realized target (#14735)
This commit is contained in:
@@ -1018,7 +1018,8 @@ class TestSchedule(unittest.TestCase):
|
||||
a = Tensor.arange(16).contiguous().realize()
|
||||
GlobalCounters.reset()
|
||||
a[4] = 3
|
||||
# TODO: update when this becomes lazy
|
||||
self.assertEqual(GlobalCounters.kernel_count, 0)
|
||||
a.realize()
|
||||
self.assertEqual(GlobalCounters.kernel_count, 1)
|
||||
self.assertListEqual(a.tolist(), [0, 1, 2, 3, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
|
||||
|
||||
|
||||
@@ -162,6 +162,8 @@ class TestSetitem(unittest.TestCase):
|
||||
@TinyJit
|
||||
def f(t:Tensor, a:Tensor):
|
||||
t[2:4, 3:5] = a
|
||||
# NOTE: without return t or an explicit realize, it's lazy and not captured
|
||||
return t
|
||||
|
||||
for i in range(1, 6):
|
||||
t = Tensor.zeros(6, 6).contiguous().realize()
|
||||
|
||||
@@ -1304,15 +1304,14 @@ class Tensor(OpMixin):
|
||||
if isinstance(v, Tensor) and v.dtype != self.dtype: raise RuntimeError(f"setitem dtype mismatch: {self.dtype=} != {v.dtype=}")
|
||||
if self.requires_grad or (isinstance(v, Tensor) and v.requires_grad): raise NotImplementedError("setitem with requires_grad is not supported")
|
||||
idx = [indices] if (isinstance(indices, list) and all_int(indices)) or not isinstance(indices, (tuple, list)) else list(indices)
|
||||
is_disk = isinstance(self.device, str) and self.device.startswith("DISK")
|
||||
if any(isinstance(i, (Tensor, list, tuple)) for i in idx): # advanced setitem
|
||||
if is_disk: raise RuntimeError("advanced setitem is not supported for DISK tensors")
|
||||
if isinstance(self.device, str) and self.device.startswith("DISK"): raise RuntimeError("advanced setitem is not supported for DISK tensors")
|
||||
if not isinstance(v, Tensor): v = Tensor(v, device=self.device, dtype=self.dtype)
|
||||
self.assign(self._getitem(indices, v))
|
||||
else: # basic setitem
|
||||
if is_disk: self[indices].assign(v)
|
||||
else:
|
||||
self[indices].assign(v).realize()
|
||||
elif self.uop.is_realized: # basic setitem, self is realized
|
||||
self[indices].assign(v)
|
||||
else: # basic setitem, self is not realized
|
||||
self[indices].assign(v).realize()
|
||||
|
||||
def __delitem__(self, indices) -> None:
|
||||
raise TypeError("Tensor does not support deleting items")
|
||||
|
||||
Reference in New Issue
Block a user