mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
handle setitem target in rangeify (#14685)
This commit is contained in:
@@ -41,6 +41,23 @@ class TestSetitem(unittest.TestCase):
|
||||
t[1] = 5
|
||||
np.testing.assert_allclose(t.numpy(), [[0, 1], [5, 5]])
|
||||
|
||||
def test_setitem_into_unrealized_sliced_compute(self):
  """Basic setitem into an unrealized result whose inputs were produced by slicing."""
  # the summed rows come from indexing, so the target's graph contains SHRINK from prior slicing
  # (like the QR decomposition pattern)
  grid = Tensor.arange(6, dtype=dtypes.float).reshape(2, 3)
  row_sum = grid[0] + grid[1]  # unrealized ADD with SHRINK in graph: [3, 5, 7]
  row_sum[1] = 99
  np.testing.assert_allclose(row_sum.numpy(), [3, 99, 7])
|
||||
|
||||
def test_setitem_fancy_on_unrealized_view(self):
  """Fancy-index setitem on a flattened copy of an unrealized slice, then assign back into the slice."""
  # fancy indexing setitem on unrealized SHRINK view (triggered infinite loop in graph_rewrite)
  full = Tensor.arange(20, dtype=dtypes.float).reshape(4, 5)
  window = full[1:3]  # rows 1 and 2 of the 4x5 arange
  flattened = window.reshape(window.numel()).contiguous()
  positions = Tensor([0, 3, 7, 9])
  flattened[positions] = Tensor([99, 98, 97, 96], dtype=dtypes.float)
  window.assign(flattened.reshape(2, 5))
  np.testing.assert_allclose(window.numpy(), [[99, 6, 7, 98, 9], [10, 11, 97, 13, 96]])
|
||||
|
||||
def test_setitem_dtype(self):
|
||||
for dt in (dtypes.int, dtypes.float, dtypes.bool):
|
||||
for v in (5., 5, True):
|
||||
|
||||
@@ -39,6 +39,18 @@ def collapse_nested_assign(assign:UOp, target:UOp, src:UOp):
|
||||
|
||||
def assign_to_contiguous(assign:UOp, target:UOp, src:UOp):
  """Rewrite an assign whose target is not backed by PARAM(s).

  Returns None (no rewrite) when the target's base is a PARAM, or an MSTACK made only of PARAMs.
  When the target is a view with a SHRINK somewhere in its backward slice:
    * base is CONTIGUOUS: return the assign with `src` made contiguous if `src` reads from that base, else None
    * base is CONST: raise RuntimeError
    * otherwise: return the assign with the target's movement-op chain rebuilt on top of a CONTIGUOUS of the base
  In every other case the assign is replaced by a CONTIGUOUS of `src` carrying the assign's tag.
  """
  base = target.base
  if base.op is Ops.PARAM: return None
  if base.op is Ops.MSTACK and all(s.op is Ops.PARAM for s in base.src): return None

  # not a partial view of an unrealized graph: the assign becomes a CONTIGUOUS of src
  if target is base or not target.op_in_backward_slice_with_self(Ops.SHRINK):
    return src.f(Ops.CONTIGUOUS, tag=assign.tag)

  # base already realized: copy src only if it reads from the same buffer (overlapping read/write hazard)
  if base.op is Ops.CONTIGUOUS:
    if base not in src.toposort(): return None
    return assign.replace(src=(target, src.contiguous()))
  if base.op is Ops.CONST: raise RuntimeError("setitem target must be a writable view backed by a buffer")

  # partial view of unrealized graph: insert CONTIGUOUS at base to realize it
  # collect the movement ops from the target inwards, outermost first
  chain: list[UOp] = []
  node = target
  while node.op in GroupOp.Movement:
    chain.append(node)
    node = node.src[0]
  # replay them innermost first on top of the CONTIGUOUS base, keeping each op's extra sources
  rebuilt = base.f(Ops.CONTIGUOUS, tag=base.tag)
  for mop in reversed(chain): rebuilt = mop.replace(src=(rebuilt, *mop.src[1:]))
  return assign.replace(src=(rebuilt, src))
|
||||
|
||||
def fix_assign_hazard(assign:UOp, target:UOp, src:UOp):
|
||||
|
||||
@@ -310,6 +310,8 @@ class Tensor(OpMixin):
|
||||
result = self._apply_uop(UOp.assign, x)
|
||||
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
|
||||
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
|
||||
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
|
||||
if x.uop in _pending_assigns.get(buf_uop, []): _pending_assigns[buf_uop].remove(x.uop)
|
||||
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
|
||||
return self.replace(result)
|
||||
|
||||
@@ -1302,8 +1304,6 @@ class Tensor(OpMixin):
|
||||
else: # basic setitem
|
||||
if is_disk: self[indices].assign(v)
|
||||
else:
|
||||
self.realize()
|
||||
if not self.uop.is_writable_view(): raise RuntimeError("setitem target must be a writable view backed by a buffer")
|
||||
self[indices].assign(v).realize()
|
||||
|
||||
def __delitem__(self, indices) -> None:
|
||||
|
||||
@@ -459,12 +459,6 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0] if self.op is Ops.WHERE and self.src[2].arg is Invalid else UOp.const(dtypes.bool, self.arg is not Invalid)
|
||||
def reduce(self, *src:UOp, **kwargs):
  """Build a REDUCE UOp with this UOp as the first source, followed by `src`.

  A `dtype` keyword overrides the result dtype (default: this UOp's dtype); all other keywords are passed to UOp.
  """
  # pop dtype first so it is not passed a second time through **kwargs
  out_dtype = kwargs.pop('dtype', self.dtype)
  return UOp(Ops.REDUCE, out_dtype, src=(self, *src), **kwargs)
|
||||
|
||||
def is_writable_view(self) -> bool:
  """Check if this UOp is a writable view backed by a buffer (injective mapping).

  BUFFER is writable; RESHAPE/SHRINK/PERMUTE/FLIP/DETACH defer to their first source;
  MULTI requires every source to be writable; any other op is not writable.
  """
  op = self.op
  if op is Ops.BUFFER: return True
  if op is Ops.MULTI: return all(part.is_writable_view() for part in self.src)
  if op in {Ops.RESHAPE, Ops.SHRINK, Ops.PERMUTE, Ops.FLIP, Ops.DETACH}: return self.src[0].is_writable_view()
  return False
|
||||
|
||||
def contiguous(self, *args, **kwargs):
|
||||
if self.op is Ops.CONTIGUOUS: return self
|
||||
if self.has_buffer_identity(): return self
|
||||
|
||||
Reference in New Issue
Block a user