mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix chained full-buffer assign (#14828)
This exposes an issue where pm_remove_bufferize drops tags; that will be fixed in bufferize next. This also fixes rand producing different values in jit vs. non-jit.
This commit is contained in:
@@ -289,7 +289,7 @@ class TestJit(unittest.TestCase):
|
||||
with_jit.add(o1.numpy()[0][0])
|
||||
with_jit.add(o2.numpy()[0][0])
|
||||
assert len(with_jit) == 10, "All values should be different."
|
||||
assert with_jit != without_jit, "TODO: fix. jit and non-jit should produce the same random values with the same seed"
|
||||
assert with_jit == without_jit, "jit and non-jit should produce the same random values with the same seed"
|
||||
|
||||
def test_jit_multiple_random_regen(self):
|
||||
def f(a, b):
|
||||
|
||||
@@ -43,11 +43,7 @@ class TestAssign(unittest.TestCase):
|
||||
x += 1
|
||||
x.realize()
|
||||
assert x.item() == T
|
||||
if T == 1:
|
||||
assert x.uop.base.realized is buf
|
||||
else:
|
||||
# TODO: this is wrong, it should always return the same buffer
|
||||
assert x.uop.base.realized is not buf
|
||||
assert x.uop.base.realized is buf
|
||||
|
||||
def test_assign_slice_add(self):
|
||||
for T in (1, 2, 10, 100):
|
||||
|
||||
@@ -315,7 +315,14 @@ class Tensor(OpMixin):
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
result = self._apply_uop(UOp.assign, x)
|
||||
# chained full-buffer assign should keep writing into the original target buffer
|
||||
# TODO: move this to rangeify, currently pm_remove_bufferize drops some tags
|
||||
if self.uop.op is Ops.ASSIGN and (target:=self.uop.src[0]).has_buffer_identity():
|
||||
if self.uop in x.uop.toposort():
|
||||
# break assign-in-source cycle lazily through a temporary
|
||||
result = self._apply_uop(lambda _self, val: target.assign(val.contiguous()), x)
|
||||
else: result = self._apply_uop(lambda _self, val: target.assign(val), x)
|
||||
else: result = self._apply_uop(UOp.assign, x)
|
||||
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
|
||||
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
|
||||
# deduplicate: if the value is already a pending assign for this buffer (e.g. __iadd__ in __setitem__), remove it
|
||||
|
||||
Reference in New Issue
Block a user