From 0c63f63ee4a8397f7681f1f5d2ff77dadc331062 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Wed, 11 Feb 2026 17:41:05 -0500
Subject: [PATCH] recursive resolve assign dependency (#14688)

remove the .realize in llm.py
---
 test/unit/test_assign.py | 55 ++++++++++++++++++++++++++++++++++++++++
 tinygrad/apps/llm.py     |  3 +--
 tinygrad/tensor.py       | 11 +++++++-
 3 files changed, 66 insertions(+), 3 deletions(-)

diff --git a/test/unit/test_assign.py b/test/unit/test_assign.py
index 5e41c36908..2c61a0c504 100644
--- a/test/unit/test_assign.py
+++ b/test/unit/test_assign.py
@@ -539,6 +539,61 @@ class TestAssign(unittest.TestCase):
     cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
     self.assertEqual(cache.sum().item(), 4.0)
 
+  def test_chained_assign_slice_then_read(self):
+    """Three caches with chained assign-then-read: each block writes to its cache and reads back,
+    feeding the result to the next block's assign. Without proper dependency tracking, block N's read
+    may see stale data from block N-1's cache (pre-assign zeros instead of the assigned values).
+    This is the multi-layer KV cache pattern from llm.py._attention.
+    """
+    D, max_ctx = 4, 8
+    cache1 = Tensor.zeros(max_ctx, D).contiguous().realize()
+    cache2 = Tensor.zeros(max_ctx, D).contiguous().realize()
+    cache3 = Tensor.zeros(max_ctx, D).contiguous().realize()
+    cache1[:3].assign(Tensor.ones(3, D)).realize()
+    cache2[:3].assign(Tensor.ones(3, D) * 2).realize()
+    cache3[:3].assign(Tensor.ones(3, D) * 3).realize()
+    # block 1: assign [10]*D at position 3, read sum -> c1=[13]*D
+    cache1[3:4].assign(Tensor.ones(1, D) * 10)
+    c1 = cache1[:4].sum(0, keepdim=True)
+    # block 2: assign c1 at position 3, read sum -> c2=[19]*D
+    cache2[3:4].assign(c1)
+    c2 = cache2[:4].sum(0, keepdim=True)
+    # block 3: assign c2 at position 3, read sum -> 112
+    cache3[3:4].assign(c2)
+    self.assertEqual(cache3[:4].sum().item(), 112.0)
+
+  def test_chained_assign_kernel_count(self):
+    """Chained pending assigns must not produce excessive kernels (tests recursive transitive processing)."""
+    D, N = 4, 5
+    caches = [Tensor.zeros(8, D).contiguous().realize() for _ in range(N)]
+    caches[0][0:1].assign(Tensor.ones(1, D) * 10)
+    x = caches[0][:1].sum(0, keepdim=True)
+    for i in range(1, N):
+      caches[i][0:1].assign(x)
+      x = caches[i][:1].sum(0, keepdim=True)
+    GlobalCounters.reset()
+    x.realize()
+    # N assigns (1 kernel each) producing N kernels total
+    self.assertEqual(GlobalCounters.kernel_count, N)
+
+  def test_shared_computation_assign_kernel_count(self):
+    """When a .contiguous() is shared between an assign value and the next layer's input (like QKV projection in LLM),
+    substitute optimization replaces already-realized sub-graphs in remaining pending assigns, preventing kernel escalation.
+    Without substitute, pending assign graphs grow linearly and produce 153 kernels instead of 48."""
+    D, N = 16, 16
+    caches = [Tensor.zeros(4, D).contiguous().realize() for _ in range(N)]
+    W = [Tensor.full((D, D*2), 0.01).contiguous().realize() for _ in range(N)]
+    x = Tensor.ones(1, D).contiguous().realize()
+    for i in range(N):
+      shared = (x @ W[i]).contiguous()  # .contiguous() UOp is shared between assign (k) and next layer (q)
+      k, q = shared[:, :D], shared[:, D:]
+      caches[i][0:1].assign(k)  # assign references the CONTIGUOUS
+      x = q + caches[i][:1]  # next layer also references the same CONTIGUOUS through q
+    GlobalCounters.reset()
+    caches[-1][:1].contiguous().realize()
+    # 2 kernels for first assign + 3 per remaining assign (matmul, contiguous, assign) + 1 final read = 3*N
+    self.assertEqual(GlobalCounters.kernel_count, 3*N)
+
 class TestAssignOrdering(unittest.TestCase):
   """Tests for complex assign orderings that could differ between lazy and eager execution.
 
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index 604336a893..274703eb61 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -131,10 +131,9 @@ class TransformerBlock:
     q = apply_rope(q, freqs_cis)
     k = apply_rope(k, freqs_cis)
 
-    # TODO: remove these kv cache realizes
     if not hasattr(self, "cache_kv"):
       self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
-    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()
+    self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
     k = self.cache_kv[0, :, :, 0:start_pos+T, :]
     v = self.cache_kv[1, :, :, 0:start_pos+T, :]
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 330e64012c..b6e6abd09e 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -274,11 +274,20 @@ class Tensor(OpMixin):
     """Triggers the computation needed to create these Tensor(s)."""
     # side-realize pending assigns for buffers referenced by these tensors
     if _pending_assigns:
-      for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
+      def _realize_pending(buf):
         for assign_uop in _pending_assigns.pop(buf, []):
+          # recursively realize pending assigns that this assign's value depends on
+          for u in assign_uop.toposort():
+            if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
           becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
           _apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
           run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
+          # update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
+          if becomes_map:
+            for assigns in _pending_assigns.values():
+              for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
+      for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
+        if buf in _pending_assigns: _realize_pending(buf)
     if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
       run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
     return self