recursively resolve assign dependency (#14688)

remove the .realize in llm.py
This commit is contained in:
chenyu
2026-02-11 17:41:05 -05:00
committed by GitHub
parent 869083e373
commit 0c63f63ee4
3 changed files with 66 additions and 3 deletions

View File

@@ -539,6 +539,61 @@ class TestAssign(unittest.TestCase):
cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
self.assertEqual(cache.sum().item(), 4.0)
def test_chained_assign_slice_then_read(self):
    """Chained assign-then-read across three caches: every block writes a slice of its
    own cache and reads it back, and that read feeds the next block's assign. If assign
    dependencies are not tracked correctly, a later block can observe stale (pre-assign)
    data from the previous block's cache. Mirrors the multi-layer KV cache pattern in
    llm.py._attention.
    """
    dim, ctx_len = 4, 8
    # one zero-initialized cache per "layer", prefilled with 1s, 2s, 3s in rows 0..2
    caches = [Tensor.zeros(ctx_len, dim).contiguous().realize() for _ in range(3)]
    for fill, cache in enumerate(caches, start=1):
        cache[:3].assign(Tensor.ones(3, dim) * fill).realize()
    # block 1: write [10]*dim at row 3, read prefix sum -> out1 = [13]*dim
    caches[0][3:4].assign(Tensor.ones(1, dim) * 10)
    out1 = caches[0][:4].sum(0, keepdim=True)
    # block 2: write out1 at row 3, read prefix sum -> out2 = [19]*dim
    caches[1][3:4].assign(out1)
    out2 = caches[1][:4].sum(0, keepdim=True)
    # block 3: write out2 at row 3; total of the prefix is 112
    caches[2][3:4].assign(out2)
    self.assertEqual(caches[2][:4].sum().item(), 112.0)
def test_chained_assign_kernel_count(self):
    """A chain of pending assigns must be resolved transitively without blowing up the
    kernel count: realizing the final read should cost exactly one kernel per assign."""
    width, depth = 4, 5
    caches = [Tensor.zeros(8, width).contiguous().realize() for _ in range(depth)]
    # feed each cache's assigned slice forward into the next cache's assign
    acc = Tensor.ones(1, width) * 10
    for cache in caches:
        cache[0:1].assign(acc)
        acc = cache[:1].sum(0, keepdim=True)
    GlobalCounters.reset()
    acc.realize()
    # depth assigns, one kernel apiece
    self.assertEqual(GlobalCounters.kernel_count, depth)
def test_shared_computation_assign_kernel_count(self):
    """When one .contiguous() result is shared between an assign value and the next
    layer's input (like the QKV projection in an LLM), the substitute optimization must
    rewrite already-realized sub-graphs inside the remaining pending assigns. Without it,
    the pending assign graphs grow linearly and emit 153 kernels instead of 48."""
    dim, n_layers = 16, 16
    caches = [Tensor.zeros(4, dim).contiguous().realize() for _ in range(n_layers)]
    weights = [Tensor.full((dim, dim*2), 0.01).contiguous().realize() for _ in range(n_layers)]
    hidden = Tensor.ones(1, dim).contiguous().realize()
    for cache, weight in zip(caches, weights):
        # this CONTIGUOUS UOp is referenced by both the assign (key half)
        # and the next layer's input (query half)
        proj = (hidden @ weight).contiguous()
        key, query = proj[:, :dim], proj[:, dim:]
        cache[0:1].assign(key)
        hidden = query + cache[:1]
    GlobalCounters.reset()
    caches[-1][:1].contiguous().realize()
    # 2 kernels for the first assign + 3 per remaining assign (matmul, contiguous,
    # assign) + 1 final read = 3*n_layers
    self.assertEqual(GlobalCounters.kernel_count, 3*n_layers)
class TestAssignOrdering(unittest.TestCase):
"""Tests for complex assign orderings that could differ between lazy and eager execution.

View File

@@ -131,10 +131,9 @@ class TransformerBlock:
q = apply_rope(q, freqs_cis)
k = apply_rope(k, freqs_cis)
# TODO: remove these kv cache realizes
if not hasattr(self, "cache_kv"):
self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
v = self.cache_kv[1, :, :, 0:start_pos+T, :]

View File

@@ -274,11 +274,20 @@ class Tensor(OpMixin):
"""Triggers the computation needed to create these Tensor(s)."""
# side-realize pending assigns for buffers referenced by these tensors
if _pending_assigns:
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
# Realize every pending assign queued against `buf`, depth-first through the assigns'
# own buffer dependencies. NOTE(review): closure over `_pending_assigns` and
# `do_update_stats` from the enclosing realize() scope (not visible in this hunk).
def _realize_pending(buf):
# pop so this buffer's assigns are processed exactly once, even under recursion
for assign_uop in _pending_assigns.pop(buf, []):
# recursively realize pending assigns that this assign's value depends on
for u in assign_uop.toposort():
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
# schedule and run just this assign's graph
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
# NOTE(review): only assigns still stored in _pending_assigns are rewritten; the
# already-popped list being iterated above is presumably covered by the recursion — confirm
if becomes_map:
for assigns in _pending_assigns.values():
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
if buf in _pending_assigns: _realize_pending(buf)
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
return self