mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
assign realize fix (#14649)
fix the need for explicit assign. track pending assigns for each buffer, and run those before the main realize in order
This commit is contained in:
@@ -175,7 +175,6 @@ class TestZeroFolding(unittest.TestCase):
|
||||
class TestAssignIssues(unittest.TestCase):
|
||||
# these are good failures. i'm not sure we need more, but we need to fix these.
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_permuted_view_constant(self):
|
||||
# assigning to a permuted view should modify the underlying tensor
|
||||
arr = np.arange(6).reshape(2, 3).astype(np.float32)
|
||||
@@ -185,7 +184,6 @@ class TestAssignIssues(unittest.TestCase):
|
||||
t.permute(1, 0).assign(Tensor([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]))
|
||||
np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_assign_shrink_view_constant(self):
|
||||
# assigning to a shrunk view should update the base tensor
|
||||
arr = np.arange(9).reshape(3, 3).astype(np.float32)
|
||||
|
||||
@@ -111,30 +111,18 @@ class TestJitFootguns(unittest.TestCase):
|
||||
self.assertEqual(first.numpy().item(), expected_first)
|
||||
buf = new_buf
|
||||
|
||||
def test_slice_assign_requires_realize(self):
|
||||
"""Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
|
||||
def test_slice_assign_works_without_realize(self):
|
||||
"""Slice assign then read from same buffer - pending assigns are side-realized."""
|
||||
from tinygrad import Variable
|
||||
v_pos = Variable("pos", 0, 3)
|
||||
|
||||
# without .realize() after assign, the read doesn't see the assigned values
|
||||
cache = Tensor.zeros(4, 4).contiguous().realize()
|
||||
@TinyJit
|
||||
def f_broken(pos):
|
||||
def f(pos):
|
||||
cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
|
||||
return cache.sum().realize()
|
||||
for i in range(4):
|
||||
cache.assign(Tensor.zeros(4, 4)).realize()
|
||||
self.assertEqual(f_broken(v_pos.bind(i)).item(), 0.0) # should be 4.0!
|
||||
|
||||
# workaround: add .realize() after assign
|
||||
cache2 = Tensor.zeros(4, 4).contiguous().realize()
|
||||
@TinyJit
|
||||
def f_fixed(pos):
|
||||
cache2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize()
|
||||
return cache2.sum().realize()
|
||||
for i in range(4):
|
||||
cache2.assign(Tensor.zeros(4, 4)).realize()
|
||||
self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0)
|
||||
self.assertEqual(f(v_pos.bind(i)).item(), 4.0)
|
||||
|
||||
def test_symbolic_pad_view_frozen(self):
|
||||
"""Symbolic pad view has BIND values baked in at capture time. TODO: pad should be captured in jit."""
|
||||
|
||||
@@ -560,30 +560,16 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
|
||||
def test_overlapping_slice_assigns(self):
|
||||
"""Overlapping slice assigns - later write should win for overlapping elements."""
|
||||
# without .realize(): assigns not executed, buffer stays zeros
|
||||
buf = Tensor.zeros(8).contiguous().realize()
|
||||
buf[0:4].assign(Tensor.ones(4))
|
||||
buf[2:6].assign(Tensor.ones(4) * 2)
|
||||
np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0]) # TODO: wrong! should be [1,1,2,2,2,2,0,0]
|
||||
|
||||
# with .realize(): assigns execute in order
|
||||
buf = Tensor.zeros(8).contiguous().realize()
|
||||
buf[0:4].assign(Tensor.ones(4)).realize()
|
||||
buf[2:6].assign(Tensor.ones(4) * 2).realize()
|
||||
np.testing.assert_equal(buf.numpy(), [1,1,2,2,2,2,0,0])
|
||||
|
||||
def test_overlapping_slice_assigns_reverse(self):
|
||||
"""Overlapping slice assigns in reverse order."""
|
||||
# without .realize(): assigns not executed
|
||||
buf = Tensor.zeros(8).contiguous().realize()
|
||||
buf[2:6].assign(Tensor.ones(4) * 2)
|
||||
buf[0:4].assign(Tensor.ones(4))
|
||||
np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0]) # TODO: wrong! should be [1,1,1,1,2,2,0,0]
|
||||
|
||||
# with .realize(): assigns execute in order
|
||||
buf = Tensor.zeros(8).contiguous().realize()
|
||||
buf[2:6].assign(Tensor.ones(4) * 2).realize()
|
||||
buf[0:4].assign(Tensor.ones(4)).realize()
|
||||
np.testing.assert_equal(buf.numpy(), [1,1,1,1,2,2,0,0])
|
||||
|
||||
def test_read_between_writes(self):
|
||||
@@ -619,26 +605,14 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
|
||||
def test_slice_write_then_full_read(self):
|
||||
"""Write to slice, then read full buffer."""
|
||||
# without .realize(): orphan slice assign not triggered by .numpy()
|
||||
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
buf[1:3].assign(Tensor([5, 6]))
|
||||
np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0]) # TODO: wrong! should be [0, 5, 6, 0]
|
||||
|
||||
# with .realize(): assign executes
|
||||
buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
|
||||
buf[1:3].assign(Tensor([5, 6])).realize()
|
||||
np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
|
||||
|
||||
def test_chained_slice_copies(self):
|
||||
"""Copy from one slice to another within same buffer."""
|
||||
# without .realize(): orphan slice assign not triggered
|
||||
buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
|
||||
buf[4:8].assign(buf[0:4].contiguous())
|
||||
np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 5, 6, 7, 8]) # TODO: wrong! should be [1,2,3,4,1,2,3,4]
|
||||
|
||||
# with .realize(): assign executes
|
||||
buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
|
||||
buf[4:8].assign(buf[0:4].contiguous()).realize()
|
||||
np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 1, 2, 3, 4])
|
||||
|
||||
def test_swap_slices(self):
|
||||
@@ -661,16 +635,9 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
|
||||
def test_reduction_after_partial_assign(self):
|
||||
"""Reduction over buffer after partial assign - must see the assigned values."""
|
||||
# without .realize(): orphan slice assign not triggered by reduction
|
||||
buf = Tensor.zeros(4, 4).contiguous().realize()
|
||||
buf[0:2, :].assign(Tensor.ones(2, 4)) # top half = 1
|
||||
total = buf.sum()
|
||||
self.assertEqual(total.item(), 0) # TODO: wrong! should be 8 (2*4 ones)
|
||||
|
||||
# with .realize(): assign executes before reduction
|
||||
buf = Tensor.zeros(4, 4).contiguous().realize()
|
||||
buf[0:2, :].assign(Tensor.ones(2, 4)).realize()
|
||||
total = buf.sum()
|
||||
self.assertEqual(total.item(), 8)
|
||||
|
||||
def test_multiple_reductions_different_views(self):
|
||||
@@ -734,34 +701,18 @@ class TestAssignOrdering(unittest.TestCase):
|
||||
def test_variable_slice_ordering(self):
|
||||
"""Variable-indexed slices - tests symbolic dependency tracking."""
|
||||
v_i = Variable("i", 0, 3)
|
||||
|
||||
# without .realize(): orphan slice assigns not triggered
|
||||
buf = Tensor.zeros(4, 4).contiguous().realize()
|
||||
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
|
||||
row0_sum = buf[0:1, :].sum()
|
||||
self.assertEqual(row0_sum.item(), 0) # TODO: wrong! should be 4
|
||||
|
||||
# with .realize(): assigns execute
|
||||
buf = Tensor.zeros(4, 4).contiguous().realize()
|
||||
buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4)).realize()
|
||||
row0_sum = buf[0:1, :].sum()
|
||||
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2).realize()
|
||||
row1_sum = buf[1:2, :].sum()
|
||||
self.assertEqual(row0_sum.item(), 4)
|
||||
self.assertEqual(row1_sum.item(), 8)
|
||||
buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
|
||||
self.assertEqual(buf[0:1, :].sum().item(), 4)
|
||||
self.assertEqual(buf[1:2, :].sum().item(), 8)
|
||||
|
||||
def test_multiple_slice_assigns_then_read(self):
|
||||
"""Multiple non-overlapping slice assigns then read - RAW dependencies must ensure all writes complete before read."""
|
||||
"""Multiple non-overlapping slice assigns then read."""
|
||||
buf = Tensor.zeros(4).contiguous().realize()
|
||||
buf[0:1].assign(Tensor.ones(1))
|
||||
buf[1:2].assign(Tensor.full((1,), 2.0))
|
||||
buf[2:3].assign(Tensor.full((1,), 3.0))
|
||||
self.assertEqual(buf.sum().realize().item(), 0.0) # TODO: wrong! should be 1 + 2 + 3 + 0 = 6
|
||||
|
||||
buf = Tensor.zeros(4).contiguous().realize()
|
||||
buf[0:1].assign(Tensor.ones(1)).realize()
|
||||
buf[1:2].assign(Tensor.full((1,), 2.0)).realize()
|
||||
buf[2:3].assign(Tensor.full((1,), 3.0)).realize()
|
||||
self.assertEqual(buf.sum().realize().item(), 6.0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -24,6 +24,7 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
all_tensors: dict[weakref.ref[Tensor], None] = {}
|
||||
_pending_assigns: dict[UOp, list[UOp]] = {} # buffer_uop -> [assign_uops in insertion order]
|
||||
def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||
with cpu_profile(TracingKey(name), "TINY"):
|
||||
# get tensors in scope
|
||||
@@ -271,6 +272,13 @@ class Tensor(OpMixin):
|
||||
@disable_gc()
|
||||
def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
# side-realize pending assigns for buffers referenced by these tensors
|
||||
if _pending_assigns:
|
||||
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
|
||||
for assign_uop in _pending_assigns.pop(buf, []):
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
|
||||
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
return self
|
||||
@@ -299,7 +307,11 @@ class Tensor(OpMixin):
|
||||
if is_disk:
|
||||
self._buffer().copyin(x._data())
|
||||
return self
|
||||
return self.replace(self._apply_uop(UOp.assign, x))
|
||||
result = self._apply_uop(UOp.assign, x)
|
||||
# track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
|
||||
if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
|
||||
_pending_assigns.setdefault(buf_uop, []).append(result.uop)
|
||||
return self.replace(result)
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user