fix chained full-buffer assign (#14828)

this shows issue that pm_remove_bufferize drops tags, will fix in bufferize next. this also fixed rand being different in jit vs no-jit
This commit is contained in:
chenyu
2026-02-17 09:11:04 -05:00
committed by GitHub
parent 58fa82eef5
commit f2f039cc0f
3 changed files with 10 additions and 7 deletions

View File

@@ -289,7 +289,7 @@ class TestJit(unittest.TestCase):
with_jit.add(o1.numpy()[0][0])
with_jit.add(o2.numpy()[0][0])
assert len(with_jit) == 10, "All values should be different."
assert with_jit != without_jit, "TODO: fix. jit and non-jit should produce the same random values with the same seed"
assert with_jit == without_jit, "jit and non-jit should produce the same random values with the same seed"
def test_jit_multiple_random_regen(self):
def f(a, b):

View File

@@ -43,11 +43,7 @@ class TestAssign(unittest.TestCase):
x += 1
x.realize()
assert x.item() == T
if T == 1:
assert x.uop.base.realized is buf
else:
# TODO: this is wrong, it should always return the same buffer
assert x.uop.base.realized is not buf
assert x.uop.base.realized is buf
def test_assign_slice_add(self):
for T in (1, 2, 10, 100):

View File

@@ -315,7 +315,14 @@ class Tensor(OpMixin):
if is_disk:
self._buffer().copyin(x._data())
return self
result = self._apply_uop(UOp.assign, x)
# chained full-buffer assign should keep writing into the original target buffer
# TODO: move this to rangeify, currently pm_remove_bufferize drops some tags
if self.uop.op is Ops.ASSIGN and (target:=self.uop.src[0]).has_buffer_identity():
if self.uop in x.uop.toposort():
# break assign-in-source cycle lazily through a temporary
result = self._apply_uop(lambda _self, val: target.assign(val.contiguous()), x)
else: result = self._apply_uop(lambda _self, val: target.assign(val), x)
else: result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it