handle setitem target in rangeify (#14685)

This commit is contained in:
chenyu
2026-02-11 11:38:59 -05:00
committed by GitHub
parent 0d215b962e
commit 7465b22ba0
4 changed files with 31 additions and 8 deletions

View File

@@ -41,6 +41,23 @@ class TestSetitem(unittest.TestCase):
t[1] = 5
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
def test_setitem_into_unrealized_sliced_compute(self):
  # setitem target is a lazy sum of two row slices, so its graph already carries SHRINK from the
  # earlier slicing (same shape of computation as the QR decomposition pattern)
  grid = Tensor.arange(6, dtype=dtypes.float).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
  row_sum = grid[0] + grid[1]  # unrealized ADD; evaluates to [3, 5, 7]
  row_sum[1] = 99
  # only the middle element may change
  expected = [3, 99, 7]
  np.testing.assert_allclose(row_sum.numpy(), expected)
def test_setitem_fancy_on_unrealized_view(self):
  # regression: Tensor-indexed (fancy) setitem on an unrealized SHRINK view used to trigger an
  # infinite loop in graph_rewrite
  source = Tensor.arange(20, dtype=dtypes.float).reshape(4, 5)
  window = source[1:3]  # rows 1 and 2, i.e. the values 5..14
  flattened = window.reshape(window.numel()).contiguous()
  positions = Tensor([0, 3, 7, 9])
  replacements = Tensor([99, 98, 97, 96], dtype=dtypes.float)
  flattened[positions] = replacements
  # write the patched flat copy back through the slice
  window.assign(flattened.reshape(2, 5))
  expected = [[99, 6, 7, 98, 9], [10, 11, 97, 13, 96]]
  np.testing.assert_allclose(window.numpy(), expected)
def test_setitem_dtype(self):
for dt in (dtypes.int, dtypes.float, dtypes.bool):
for v in (5., 5, True):

View File

@@ -39,6 +39,18 @@ def collapse_nested_assign(assign:UOp, target:UOp, src:UOp):
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
  """Pick the replacement for an ASSIGN of `src` into `target`, based on what `target.base` is.

  Outcomes, checked in this order:
  * base is a PARAM, or an MSTACK whose sources are all PARAMs: returns None.
  * `target` is the base itself, or has no SHRINK in its backward slice:
    returns `src.f(Ops.CONTIGUOUS, tag=assign.tag)`.
  * base is a CONTIGUOUS: returns `assign` with `src` replaced by `src.contiguous()` when the base
    appears in `src.toposort()`, otherwise None.
  * base is a CONST: raises RuntimeError.
  * anything else: returns `assign` whose target is the original chain of movement ops rebuilt on top
    of `base.f(Ops.CONTIGUOUS, tag=base.tag)`.

  NOTE(review): a None return presumably tells the caller to leave this ASSIGN untouched -- confirm at the call site.
  """
  base = target.base
  if base.op is Ops.PARAM: return None
  if base.op is Ops.MSTACK and all(part.op is Ops.PARAM for part in base.src): return None

  # not a SHRINK-containing view of the base: fall through to the plain CONTIGUOUS of the source
  if target is base or not target.op_in_backward_slice_with_self(Ops.SHRINK):
    return src.f(Ops.CONTIGUOUS, tag=assign.tag)

  if base.op is Ops.CONTIGUOUS:
    # base already realized: copy src only if it reads from the same buffer (overlapping read/write hazard)
    if base not in src.toposort(): return None
    return assign.replace(src=(target, src.contiguous()))
  if base.op is Ops.CONST: raise RuntimeError("setitem target must be a writable view backed by a buffer")

  # partial view of unrealized graph: insert CONTIGUOUS at base to realize it
  # first collect the movement ops between the target and the first non-movement op, outermost first
  view_chain: list[UOp] = []
  cursor = target
  while cursor.op in GroupOp.Movement:
    view_chain.append(cursor)
    cursor = cursor.src[0]
  # then replay them innermost-first over the new CONTIGUOUS; only src[0] is swapped, any other sources are kept
  rebuilt = base.f(Ops.CONTIGUOUS, tag=base.tag)
  for view in reversed(view_chain):
    rebuilt = view.replace(src=(rebuilt,) + view.src[1:])
  return assign.replace(src=(rebuilt, src))
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):

View File

@@ -310,6 +310,8 @@ class Tensor(OpMixin):
result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop)
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
return self.replace(result)
@@ -1302,8 +1304,6 @@ class Tensor(OpMixin):
else: # basic setitem
if is_disk: self[indices].assign(v)
else:
self.realize()
if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
self[indices].assign(v).realize()
def __delitem__(self, indices) -> None:

View File

@@ -459,12 +459,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
def reduce(self, *src:UOp, **kwargs):
  """Build a REDUCE UOp whose sources are this UOp followed by `src`.

  The result dtype is the `dtype` keyword if one is passed, otherwise `self.dtype`; every other
  keyword argument is forwarded unchanged to the UOp constructor.
  """
  # take dtype out of kwargs first so it is not forwarded a second time as a keyword
  out_dtype = kwargs.pop('dtype', self.dtype)
  sources = (self,) + src
  return UOp(Ops.REDUCE, out_dtype, src=sources, **kwargs)
def is_writable_view(self) -> bool:
  """Report whether this UOp is a writable view backed by a buffer (injective mapping).

  RESHAPE, SHRINK, PERMUTE, FLIP and DETACH defer to their first source; MULTI requires every
  source to qualify; anything else qualifies only if it is a BUFFER.
  """
  passthrough_ops = {Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.DETACH}
  if self.op in passthrough_ops:
    # the answer is whatever the first source says
    parent = self.src[0]
    return parent.is_writable_view()
  if self.op is Ops.MULTI:
    # all sources have to be writable views
    return all(part.is_writable_view() for part in self.src)
  return self.op is Ops.BUFFER
def contiguous(self, *args, **kwargs):
if self.op is Ops.CONTIGUOUS: return self
if self.has_buffer_identity(): return self