diff --git a/test/test_edgecases.py b/test/test_edgecases.py
index 59358a5a0e..4a497f7e40 100644
--- a/test/test_edgecases.py
+++ b/test/test_edgecases.py
@@ -175,7 +175,6 @@ class TestZeroFolding(unittest.TestCase):
 
 class TestAssignIssues(unittest.TestCase):
   # these are good failures. i'm not sure we need more, but we need to fix these.
-  @unittest.expectedFailure
   def test_assign_permuted_view_constant(self):
     # assigning to a permuted view should modify the underlying tensor
     arr = np.arange(6).reshape(2, 3).astype(np.float32)
@@ -185,7 +184,6 @@ class TestAssignIssues(unittest.TestCase):
     t.permute(1, 0).assign(Tensor([[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]))
     np.testing.assert_allclose(t.numpy(), torch_tensor.numpy())
 
-  @unittest.expectedFailure
   def test_assign_shrink_view_constant(self):
     # assigning to a shrunk view should update the base tensor
     arr = np.arange(9).reshape(3, 3).astype(np.float32)
diff --git a/test/test_jit_footguns.py b/test/test_jit_footguns.py
index b7df7f8eaa..1fd58c28ea 100644
--- a/test/test_jit_footguns.py
+++ b/test/test_jit_footguns.py
@@ -111,30 +111,18 @@ class TestJitFootguns(unittest.TestCase):
       self.assertEqual(first.numpy().item(), expected_first)
       buf = new_buf
 
-  def test_slice_assign_requires_realize(self):
-    """Slice assign then read from same buffer - assign isn't connected to read without explicit realize()."""
+  def test_slice_assign_works_without_realize(self):
+    """Slice assign then read from same buffer - pending assigns are side-realized."""
     from tinygrad import Variable
     v_pos = Variable("pos", 0, 3)
-
-    # without .realize() after assign, the read doesn't see the assigned values
     cache = Tensor.zeros(4, 4).contiguous().realize()
     @TinyJit
-    def f_broken(pos):
+    def f(pos):
      cache[pos:pos+1, :].assign(Tensor.ones(1, 4))
      return cache.sum().realize()
    for i in range(4):
      cache.assign(Tensor.zeros(4, 4)).realize()
-      self.assertEqual(f_broken(v_pos.bind(i)).item(), 0.0)  # should be 4.0!
-
-    # workaround: add .realize() after assign
-    cache2 = Tensor.zeros(4, 4).contiguous().realize()
-    @TinyJit
-    def f_fixed(pos):
-      cache2[pos:pos+1, :].assign(Tensor.ones(1, 4)).realize()
-      return cache2.sum().realize()
-    for i in range(4):
-      cache2.assign(Tensor.zeros(4, 4)).realize()
-      self.assertEqual(f_fixed(v_pos.bind(i)).item(), 4.0)
+      self.assertEqual(f(v_pos.bind(i)).item(), 4.0)
 
   def test_symbolic_pad_view_frozen(self):
     """Symbolic pad view has BIND values baked in at capture time. TODO: pad should be captured in jit."""
diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py
index be8c05d669..1c19525799 100644
--- a/test/unit/test_assign.py
+++ b/test/unit/test_assign.py
@@ -560,30 +560,16 @@ class TestAssignOrdering(unittest.TestCase):
 
   def test_overlapping_slice_assigns(self):
     """Overlapping slice assigns - later write should win for overlapping elements."""
-    # without .realize(): assigns not executed, buffer stays zeros
     buf = Tensor.zeros(8).contiguous().realize()
     buf[0:4].assign(Tensor.ones(4))
     buf[2:6].assign(Tensor.ones(4) * 2)
-    np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0])  # TODO: wrong! should be [1,1,2,2,2,2,0,0]
-
-    # with .realize(): assigns execute in order
-    buf = Tensor.zeros(8).contiguous().realize()
-    buf[0:4].assign(Tensor.ones(4)).realize()
-    buf[2:6].assign(Tensor.ones(4) * 2).realize()
     np.testing.assert_equal(buf.numpy(), [1,1,2,2,2,2,0,0])
 
   def test_overlapping_slice_assigns_reverse(self):
     """Overlapping slice assigns in reverse order."""
-    # without .realize(): assigns not executed
     buf = Tensor.zeros(8).contiguous().realize()
     buf[2:6].assign(Tensor.ones(4) * 2)
     buf[0:4].assign(Tensor.ones(4))
-    np.testing.assert_equal(buf.numpy(), [0,0,0,0,0,0,0,0])  # TODO: wrong! should be [1,1,1,1,2,2,0,0]
-
-    # with .realize(): assigns execute in order
-    buf = Tensor.zeros(8).contiguous().realize()
-    buf[2:6].assign(Tensor.ones(4) * 2).realize()
-    buf[0:4].assign(Tensor.ones(4)).realize()
     np.testing.assert_equal(buf.numpy(), [1,1,1,1,2,2,0,0])
 
   def test_read_between_writes(self):
@@ -619,26 +605,14 @@ class TestAssignOrdering(unittest.TestCase):
 
   def test_slice_write_then_full_read(self):
     """Write to slice, then read full buffer."""
-    # without .realize(): orphan slice assign not triggered by .numpy()
     buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
     buf[1:3].assign(Tensor([5, 6]))
-    np.testing.assert_equal(buf.numpy(), [0, 0, 0, 0])  # TODO: wrong! should be [0, 5, 6, 0]
-
-    # with .realize(): assign executes
-    buf = Tensor.zeros(4, dtype=dtypes.int32).contiguous().realize()
-    buf[1:3].assign(Tensor([5, 6])).realize()
     np.testing.assert_equal(buf.numpy(), [0, 5, 6, 0])
 
   def test_chained_slice_copies(self):
     """Copy from one slice to another within same buffer."""
-    # without .realize(): orphan slice assign not triggered
     buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
     buf[4:8].assign(buf[0:4].contiguous())
-    np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 5, 6, 7, 8])  # TODO: wrong! should be [1,2,3,4,1,2,3,4]
-
-    # with .realize(): assign executes
-    buf = Tensor([1, 2, 3, 4, 5, 6, 7, 8]).contiguous().realize()
-    buf[4:8].assign(buf[0:4].contiguous()).realize()
     np.testing.assert_equal(buf.numpy(), [1, 2, 3, 4, 1, 2, 3, 4])
 
   def test_swap_slices(self):
@@ -661,16 +635,9 @@ class TestAssignOrdering(unittest.TestCase):
 
   def test_reduction_after_partial_assign(self):
     """Reduction over buffer after partial assign - must see the assigned values."""
-    # without .realize(): orphan slice assign not triggered by reduction
     buf = Tensor.zeros(4, 4).contiguous().realize()
     buf[0:2, :].assign(Tensor.ones(2, 4))  # top half = 1
     total = buf.sum()
-    self.assertEqual(total.item(), 0)  # TODO: wrong! should be 8 (2*4 ones)
-
-    # with .realize(): assign executes before reduction
-    buf = Tensor.zeros(4, 4).contiguous().realize()
-    buf[0:2, :].assign(Tensor.ones(2, 4)).realize()
-    total = buf.sum()
     self.assertEqual(total.item(), 8)
 
   def test_multiple_reductions_different_views(self):
@@ -734,34 +701,18 @@ class TestAssignOrdering(unittest.TestCase):
   def test_variable_slice_ordering(self):
     """Variable-indexed slices - tests symbolic dependency tracking."""
     v_i = Variable("i", 0, 3)
-
-    # without .realize(): orphan slice assigns not triggered
     buf = Tensor.zeros(4, 4).contiguous().realize()
     buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4))
-    row0_sum = buf[0:1, :].sum()
-    self.assertEqual(row0_sum.item(), 0)  # TODO: wrong! should be 4
-
-    # with .realize(): assigns execute
-    buf = Tensor.zeros(4, 4).contiguous().realize()
-    buf[v_i.bind(0):v_i.bind(0)+1, :].assign(Tensor.ones(1, 4)).realize()
-    row0_sum = buf[0:1, :].sum()
-    buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2).realize()
-    row1_sum = buf[1:2, :].sum()
-    self.assertEqual(row0_sum.item(), 4)
-    self.assertEqual(row1_sum.item(), 8)
+    buf[v_i.bind(1):v_i.bind(1)+1, :].assign(Tensor.ones(1, 4) * 2)
+    self.assertEqual(buf[0:1, :].sum().item(), 4)
+    self.assertEqual(buf[1:2, :].sum().item(), 8)
 
   def test_multiple_slice_assigns_then_read(self):
-    """Multiple non-overlapping slice assigns then read - RAW dependencies must ensure all writes complete before read."""
+    """Multiple non-overlapping slice assigns then read."""
     buf = Tensor.zeros(4).contiguous().realize()
     buf[0:1].assign(Tensor.ones(1))
     buf[1:2].assign(Tensor.full((1,), 2.0))
     buf[2:3].assign(Tensor.full((1,), 3.0))
-    self.assertEqual(buf.sum().realize().item(), 0.0)  # TODO: wrong! should be 1 + 2 + 3 + 0 = 6
-
-    buf = Tensor.zeros(4).contiguous().realize()
-    buf[0:1].assign(Tensor.ones(1)).realize()
-    buf[1:2].assign(Tensor.full((1,), 2.0)).realize()
-    buf[2:3].assign(Tensor.full((1,), 3.0)).realize()
     self.assertEqual(buf.sum().realize().item(), 6.0)
 
 if __name__ == "__main__":
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index e4b208499c..ff01c33737 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -24,6 +24,7 @@ def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
 
 # *** all in scope Tensors are here. this gets relevant UOps ***
 all_tensors: dict[weakref.ref[Tensor], None] = {}
+_pending_assigns: dict[UOp, list[UOp]] = {}  # buffer_uop -> [assign_uops in insertion order]
 
 def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
   with cpu_profile(TracingKey(name), "TINY"):  # get tensors in scope
@@ -271,6 +272,13 @@ class Tensor(OpMixin):
   @disable_gc()
   def realize(self, *lst:Tensor, do_update_stats=True) -> Tensor:
     """Triggers the computation needed to create these Tensor(s)."""
+    # side-realize pending assigns for buffers referenced by these tensors
+    if _pending_assigns:
+      for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
+        for assign_uop in _pending_assigns.pop(buf, []):
+          becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
+          _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
+          run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
     if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
       run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
     return self
@@ -299,7 +307,11 @@ class Tensor(OpMixin):
     if is_disk:
       self._buffer().copyin(x._data())
       return self
-    return self.replace(self._apply_uop(UOp.assign, x))
+    result = self._apply_uop(UOp.assign, x)
+    # track view assigns (not full-buffer or assign-chain) so they can be side-realized when the buffer is read
+    if (buf_uop:=self.uop.base).op is Ops.BUFFER and self.uop.op is not Ops.ASSIGN and not self.uop.has_buffer_identity():
+      _pending_assigns.setdefault(buf_uop, []).append(result.uop)
+    return self.replace(result)
 
   def detach(self) -> Tensor:
     """