assign realize fix (#14649)

Fix the need for explicit assign: track pending assigns for each buffer, and run those before the main realize, in order.
This commit is contained in:
chenyu
2026-02-09 17:46:46 -05:00
committed by GitHub
parent 0913c068ea
commit 9e3f24db9f
4 changed files with 21 additions and 72 deletions

View File

@@ -175,7 +175,6 @@ class TestZeroFolding(unittest.TestCase):
class TestAssignIssues(unittest.TestCase):
# these are good failures. I'm not sure we need more, but we need to fix these.
@unittest.expectedFailure
def test_assign_permuted_view_constant(self):
# assigning to a permuted view should modify the underlying tensor
arr = np.arange(6).reshape(2, 3).astype(np.float32)
@@ -185,7 +184,6 @@ class TestAssignIssues(unittest.TestCase):
t.permute(1, 0).assign(Tensor([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]))
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
@unittest.expectedFailure
def test_assign_shrink_view_constant(self):
# assigning to a shrunk view should update the base tensor
arr = np.arange(9).reshape(3, 3).astype(np.float32)

View File

@@ -111,30 +111,18 @@ class TestJitFootguns(unittest.TestCase):
self.assertEqual(first.numpy().item(), expected_first)
buf = new_buf
def test_slice_assign_requires_realize(self):
"""Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
def test_slice_assign_works_without_realize(self):
"""Slice assign then read from same buffer - pending assigns are side-realized."""
from tinygrad import Variable
v_pos = Variable("pos", 0, 3)
# without .realize() after assign, the read doesn't see the assigned values
cache = Tensor.zeros(4, 4).contiguous().realize()
@TinyJit
def f_broken(pos):
def f(pos):
cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
return cache.sum().realize()
for i in range(4):
cache.assign(Tensor.zeros(4, 4)).realize()
self.assertEqual(f_broken(v_pos.bind(i)).item(), 0.0) # should be 4.0!
# workaround: add .realize() after assign
cache2 = Tensor.zeros(4, 4).contiguous().realize()
@TinyJit
def f_fixed(pos):
cache2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize()
return cache2.sum().realize()
for i in range(4):
cache2.assign(Tensor.zeros(4, 4)).realize()
self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0)
self.assertEqual(f(v_pos.bind(i)).item(), 4.0)
def test_symbolic_pad_view_frozen(self):
"""Symbolic pad view has BIND values baked in at capture time. TODO: pad should be captured in jit."""

View File

@@ -560,30 +560,16 @@ class TestAssignOrdering(unittest.TestCase):
def test_overlapping_slice_assigns(self):
"""Overlapping slice assigns - later write should win for overlapping elements."""
# without .realize(): assigns not executed, buffer stays zeros
buf = Tensor.zeros(8).contiguous().realize()
buf[0:4].assign(Tensor.ones(4))
buf[2:6].assign(Tensor.ones(4) * 2)
np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0]) # TODO: wrong! should be [1,1,2,2,2,2,0,0]
# with .realize(): assigns execute in order
buf = Tensor.zeros(8).contiguous().realize()
buf[0:4].assign(Tensor.ones(4)).realize()
buf[2:6].assign(Tensor.ones(4) * 2).realize()
np.testing.assert_equal(buf.numpy(), [1,1,2,2,2,2,0,0])
def test_overlapping_slice_assigns_reverse(self):
"""Overlapping slice assigns in reverse order."""
# without .realize(): assigns not executed
buf = Tensor.zeros(8).contiguous().realize()
buf[2:6].assign(Tensor.ones(4) * 2)
buf[0:4].assign(Tensor.ones(4))
np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0]) # TODO: wrong! should be [1,1,1,1,2,2,0,0]
# with .realize(): assigns execute in order
buf = Tensor.zeros(8).contiguous().realize()
buf[2:6].assign(Tensor.ones(4) * 2).realize()
buf[0:4].assign(Tensor.ones(4)).realize()
np.testing.assert_equal(buf.numpy(), [1,1,1,1,2,2,0,0])
def test_read_between_writes(self):
@@ -619,26 +605,14 @@ class TestAssignOrdering(unittest.TestCase):
def test_slice_write_then_full_read(self):
"""Write to slice, then read full buffer."""
# without .realize(): orphan slice assign not triggered by .numpy()
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6]))
np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0]
# with .realize(): assign executes
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
buf[1:3].assign(Tensor([5, 6])).realize()
np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
def test_chained_slice_copies(self):
"""Copy from one slice to another within same buffer."""
# without .realize(): orphan slice assign not triggered
buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
buf[4:8].assign(buf[0:4].contiguous())
np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 5, 6, 7, 8]) # TODO: wrong! should be [1,2,3,4,1,2,3,4]
# with .realize(): assign executes
buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
buf[4:8].assign(buf[0:4].contiguous()).realize()
np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 1, 2, 3, 4])
def test_swap_slices(self):
@@ -661,16 +635,9 @@ class TestAssignOrdering(unittest.TestCase):
def test_reduction_after_partial_assign(self):
"""Reduction over buffer after partial assign - must see the assigned values."""
# without .realize(): orphan slice assign not triggered by reduction
buf = Tensor.zeros(4, 4).contiguous().realize()
buf[0:2, :].assign(Tensor.ones(2, 4)) # top half = 1
total = buf.sum()
self.assertEqual(total.item(), 0) # TODO: wrong! should be 8 (2*4 ones)
# with .realize(): assign executes before reduction
buf = Tensor.zeros(4, 4).contiguous().realize()
buf[0:2, :].assign(Tensor.ones(2, 4)).realize()
total = buf.sum()
self.assertEqual(total.item(), 8)
def test_multiple_reductions_different_views(self):
@@ -734,34 +701,18 @@ class TestAssignOrdering(unittest.TestCase):
def test_variable_slice_ordering(self):
"""Variable-indexed slices - tests symbolic dependency tracking."""
v_i = Variable("i", 0, 3)
# without .realize(): orphan slice assigns not triggered
buf = Tensor.zeros(4, 4).contiguous().realize()
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
row0_sum = buf[0:1, :].sum()
self.assertEqual(row0_sum.item(), 0) # TODO: wrong! should be 4
# with .realize(): assigns execute
buf = Tensor.zeros(4, 4).contiguous().realize()
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4)).realize()
row0_sum = buf[0:1, :].sum()
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2).realize()
row1_sum = buf[1:2, :].sum()
self.assertEqual(row0_sum.item(), 4)
self.assertEqual(row1_sum.item(), 8)
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
self.assertEqual(buf[0:1, :].sum().item(), 4)
self.assertEqual(buf[1:2, :].sum().item(), 8)
def test_multiple_slice_assigns_then_read(self):
"""Multiple non-overlapping slice assigns then read - RAW dependencies must ensure all writes complete before read."""
"""Multiple non-overlapping slice assigns then read."""
buf = Tensor.zeros(4).contiguous().realize()
buf[0:1].assign(Tensor.ones(1))
buf[1:2].assign(Tensor.full((1,), 2.0))
buf[2:3].assign(Tensor.full((1,), 3.0))
self.assertEqual(buf.sum().realize().item(), 0.0) # TODO: wrong! should be 1 + 2 + 3 + 0 = 6
buf = Tensor.zeros(4).contiguous().realize()
buf[0:1].assign(Tensor.ones(1)).realize()
buf[1:2].assign(Tensor.full((1,), 2.0)).realize()
buf[2:3].assign(Tensor.full((1,), 3.0)).realize()
self.assertEqual(buf.sum().realize().item(), 6.0)
if __name__ == "__main__":

View File

@@ -24,6 +24,7 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
# *** all in scope Tensors are here. this gets relevant UOps ***
all_tensors: dict[weakref.ref[Tensor], None] = {}
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
with cpu_profile(TracingKey(name), "TINY"):
# get tensors in scope
@@ -271,6 +272,13 @@ class Tensor(OpMixin):
@disable_gc()
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
"""Triggers the computation needed to create these Tensor(s)."""
# side-realize pending assigns for buffers referenced by these tensors
if _pending_assigns:
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
for assign_uop in _pending_assigns.pop(buf, []):
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
return self
@@ -299,7 +307,11 @@ class Tensor(OpMixin):
if is_disk:
self._buffer().copyin(x._data())
return self
return self.replace(self._apply_uop(UOp.assign, x))
result = self._apply_uop(UOp.assign, x)
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
return self.replace(result)
def detach(self) -> Tensor:
"""