Summary of this patch (everything before the first "diff --git" header is leading
commentary, not part of any hunk):

  * test/test_setitem.py: adds two tests that run setitem on tensors whose graph
    contains an unrealized slice (SHRINK).
  * tinygrad/schedule/rangeify.py: assign_to_contiguous gains a branch for a target
    that is not its own base and has SHRINK in its backward slice. If the base is
    already CONTIGUOUS, src is made contiguous only when the base appears in
    src.toposort(), otherwise None is returned. A CONST base raises RuntimeError.
    Otherwise the base is wrapped in CONTIGUOUS and the collected movement ops are
    re-applied on top of it.
  * tinygrad/tensor.py: an entry equal to x.uop is removed from the pending assigns
    of the buffer before the new assign is appended; the basic setitem path no
    longer calls self.realize() or checks is_writable_view().
  * tinygrad/uop/ops.py: removes UOp.is_writable_view.

NOTE(review): the leading whitespace inside the hunks was reconstructed from Python
syntax and the hunk line counts (2-space indentation assumed) -- confirm against the
target tree before applying.

diff --git a/test/test_setitem.py b/test/test_setitem.py
index 0614f068fc..649757c90a 100644
--- a/test/test_setitem.py
+++ b/test/test_setitem.py
@@ -41,6 +41,23 @@ class TestSetitem(unittest.TestCase):
     t[1] = 5
     np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
 
+  def test_setitem_into_unrealized_sliced_compute(self):
+    # base computation contains SHRINK from prior slicing (like QR decomposition pattern)
+    a = Tensor.arange(6, dtype=dtypes.float).reshape(2, 3)
+    w = a[0] + a[1] # unrealized ADD with SHRINK in graph: [3, 5, 7]
+    w[1] = 99
+    np.testing.assert_allclose(w.numpy(), [3, 99, 7])
+
+  def test_setitem_fancy_on_unrealized_view(self):
+    # fancy indexing setitem on unrealized SHRINK view (triggered infinite loop in graph_rewrite)
+    base = Tensor.arange(20, dtype=dtypes.float).reshape(4, 5)
+    sub = base[1:3]
+    flat = sub.reshape(sub.numel()).contiguous()
+    idx = Tensor([0, 3, 7, 9])
+    flat[idx] = Tensor([99, 98, 97, 96], dtype=dtypes.float)
+    sub.assign(flat.reshape(2, 5))
+    np.testing.assert_allclose(sub.numpy(), [[99, 6, 7, 98, 9], [10, 11, 97, 13, 96]])
+
   def test_setitem_dtype(self):
     for dt in (dtypes.int, dtypes.float, dtypes.bool):
       for v in (5., 5, True):
diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py
index 809a6cf5fa..b466e8bc4c 100644
--- a/tinygrad/schedule/rangeify.py
+++ b/tinygrad/schedule/rangeify.py
@@ -39,6 +39,18 @@ def collapse_nested_assign(assign:UOp, target:UOp, src:UOp):
 
 def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
   if (t := target.base).op is Ops.PARAM or (t.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in t.src)): return None
+  # partial view of unrealized graph: insert CONTIGUOUS at base to realize it
+  if target is not t and target.op_in_backward_slice_with_self(Ops.SHRINK):
+    # base already realized: copy src only if it reads from the same buffer (overlapping read/write hazard)
+    if t.op is Ops.CONTIGUOUS: return assign.replace(src=(target, src.contiguous())) if t in src.toposort() else None
+    if t.op is Ops.CONST: raise RuntimeError("setitem target must be a writable view backed by a buffer")
+    mops: list[UOp] = []
+    while target.op in GroupOp.Movement:
+      mops.append(target)
+      target = target.src[0]
+    new_target = t.f(Ops.CONTIGUOUS, tag=t.tag)
+    for m in reversed(mops): new_target = m.replace(src=(new_target,)+m.src[1:])
+    return assign.replace(src=(new_target, src))
   return src.f(Ops.CONTIGUOUS, tag=assign.tag)
 
 def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e59fed79f4..330e64012c 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -310,6 +310,8 @@ class Tensor(OpMixin):
     result = self._apply_uop(UOp.assign, x)
     # track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
     if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
+      # deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
+      if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop)
       _pending_assigns.setdefault(buf_uop, []).append(result.uop)
     return self.replace(result)
 
@@ -1302,8 +1304,6 @@ class Tensor(OpMixin):
     else: # basic setitem
       if is_disk: self[indices].assign(v)
       else:
-        self.realize()
-        if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
         self[indices].assign(v).realize()
 
   def __delitem__(self, indices) -> None:
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 892211032c..2ca38b4daa 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -459,12 +459,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
     return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
   def reduce(self, *src:UOp, **kwargs): return UOp(Ops.REDUCE, kwargs.pop('dtype', self.dtype), src=(self,)+src, **kwargs)
 
-  def is_writable_view(self) -> bool:
-    """Check if this UOp is a writable view backed by a buffer (injective mapping)."""
-    if self.op in {Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.DETACH}: return self.src[0].is_writable_view()
-    if self.op is Ops.MULTI: return all(x.is_writable_view() for x in self.src)
-    return self.op is Ops.BUFFER
-
   def contiguous(self, *args, **kwargs):
     if self.op is Ops.CONTIGUOUS: return self
     if self.has_buffer_identity(): return self