mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
recursive resolve assign dependency (#14688)
remove the .realize in llm.py
This commit is contained in:
@@ -539,6 +539,61 @@ class TestAssign(unittest.TestCase):
|
||||
cache[v_pos:v_pos+1, :].assign(Tensor.ones(1, 4))
|
||||
self.assertEqual(cache.sum().item(), 4.0)
|
||||
|
||||
def test_chained_assign_slice_then_read(self):
    """Three KV-style caches chained: each block assigns into its own cache, reads it
    back, and the readback feeds the next block's assign. Without proper dependency
    tracking, block N's read may observe the pre-assign contents (stale zeros) of
    block N-1's cache. Mirrors the multi-layer KV cache pattern in llm.py._attention."""
    dim, ctx_len = 4, 8
    kv1 = Tensor.zeros(ctx_len, dim).contiguous().realize()
    kv2 = Tensor.zeros(ctx_len, dim).contiguous().realize()
    kv3 = Tensor.zeros(ctx_len, dim).contiguous().realize()
    kv1[:3].assign(Tensor.ones(3, dim)).realize()
    kv2[:3].assign(Tensor.ones(3, dim) * 2).realize()
    kv3[:3].assign(Tensor.ones(3, dim) * 3).realize()
    # layer 1: write [10]*dim at row 3, then read the running sum -> out1 == [13]*dim
    kv1[3:4].assign(Tensor.ones(1, dim) * 10)
    out1 = kv1[:4].sum(0, keepdim=True)
    # layer 2: write out1 at row 3, then read the running sum -> out2 == [19]*dim
    kv2[3:4].assign(out1)
    out2 = kv2[:4].sum(0, keepdim=True)
    # layer 3: write out2 at row 3; the total over rows 0..3 must come out to 112
    kv3[3:4].assign(out2)
    self.assertEqual(kv3[:4].sum().item(), 112.0)
|
||||
|
||||
def test_chained_assign_kernel_count(self):
    """A chain of pending assigns must stay at one kernel per assign — guards against
    kernel blow-up from non-recursive handling of transitive assign dependencies."""
    dim, n_layers = 4, 5
    kv = [Tensor.zeros(8, dim).contiguous().realize() for _ in range(n_layers)]
    kv[0][0:1].assign(Tensor.ones(1, dim) * 10)
    acc = kv[0][:1].sum(0, keepdim=True)
    for layer in range(1, n_layers):
        kv[layer][0:1].assign(acc)
        acc = kv[layer][:1].sum(0, keepdim=True)
    GlobalCounters.reset()
    acc.realize()
    # one kernel per assign -> exactly n_layers kernels in total
    self.assertEqual(GlobalCounters.kernel_count, n_layers)
|
||||
|
||||
def test_shared_computation_assign_kernel_count(self):
    """When a single .contiguous() result feeds both an assign value and the next layer's
    input (the QKV-projection pattern in an LLM), the substitute optimization replaces
    already-realized sub-graphs inside the remaining pending assigns. Without substitute,
    the pending assign graphs grow linearly and produce 153 kernels instead of 48."""
    dim, n_layers = 16, 16
    kv = [Tensor.zeros(4, dim).contiguous().realize() for _ in range(n_layers)]
    weights = [Tensor.full((dim, dim*2), 0.01).contiguous().realize() for _ in range(n_layers)]
    x = Tensor.ones(1, dim).contiguous().realize()
    for layer in range(n_layers):
        proj = (x @ weights[layer]).contiguous()  # this CONTIGUOUS UOp is shared by the assign (k) and the next layer's input (q)
        k, q = proj[:, :dim], proj[:, dim:]
        kv[layer][0:1].assign(k)  # the assign references the shared CONTIGUOUS
        x = q + kv[layer][:1]     # ...and the next layer references it too, through q
    GlobalCounters.reset()
    kv[-1][:1].contiguous().realize()
    # first assign: 2 kernels; each later assign: 3 (matmul, contiguous, assign); plus 1 final read -> 3*n_layers
    self.assertEqual(GlobalCounters.kernel_count, 3*n_layers)
|
||||
|
||||
|
||||
class TestAssignOrdering(unittest.TestCase):
|
||||
"""Tests for complex assign orderings that could differ between lazy and eager execution.
|
||||
|
||||
@@ -131,10 +131,9 @@ class TransformerBlock:
|
||||
q = apply_rope(q, freqs_cis)
|
||||
k = apply_rope(k, freqs_cis)
|
||||
|
||||
# TODO: remove these kv cache realizes
|
||||
if not hasattr(self, "cache_kv"):
|
||||
self.cache_kv = Tensor.zeros(2, B, self.n_kv_heads, self.max_context, self.head_dim, dtype=k.dtype, device=k.device).contiguous().realize()
|
||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v)).realize()
|
||||
self.cache_kv[:, :, :, start_pos:start_pos+T, :].assign(Tensor.stack(k, v))
|
||||
k = self.cache_kv[0, :, :, 0:start_pos+T, :]
|
||||
v = self.cache_kv[1, :, :, 0:start_pos+T, :]
|
||||
|
||||
|
||||
@@ -274,11 +274,20 @@ class Tensor(OpMixin):
|
||||
"""Triggers the computation needed to create these Tensor(s)."""
|
||||
# side-realize pending assigns for buffers referenced by these tensors
|
||||
if _pending_assigns:
|
||||
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
|
||||
def _realize_pending(buf):
|
||||
for assign_uop in _pending_assigns.pop(buf, []):
|
||||
# recursively realize pending assigns that this assign's value depends on
|
||||
for u in assign_uop.toposort():
|
||||
if u.op is Ops.BUFFER and u in _pending_assigns: _realize_pending(u)
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(UOp.sink(assign_uop))
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Pending Assign")
|
||||
run_schedule(schedule, var_vals, do_update_stats=do_update_stats)
|
||||
# update remaining pending assigns so they reference realized buffers instead of stale lazy graphs
|
||||
if becomes_map:
|
||||
for assigns in _pending_assigns.values():
|
||||
for i in range(len(assigns)): assigns[i] = assigns[i].substitute(becomes_map)
|
||||
for buf in {u for t in (self,)+lst for u in t.uop.toposort() if u.op is Ops.BUFFER}:
|
||||
if buf in _pending_assigns: _realize_pending(buf)
|
||||
if len(to_realize:=[x for x in (self,)+lst if not x.uop.has_buffer_identity()]):
|
||||
run_schedule(*Tensor.schedule_with_vars(*to_realize), do_update_stats=do_update_stats)
|
||||
return self
|
||||
|
||||
Reference in New Issue
Block a user