fully symbolic llm (#15097)

* work

* llm symbolic (almost)

* work

* revert that

* llm sym

* works

* cleanups

* cache tokens with the kv cache

* cleanups

* cleanups
George Hotz
2026-03-05 10:22:11 +08:00
committed by GitHub
parent 33a1970045
commit ac1847cbf7
6 changed files with 137 additions and 35 deletions

View File

@@ -2,7 +2,7 @@ import unittest
import numpy as np
from tinygrad import Tensor, function
from tinygrad.dtype import dtypes
from tinygrad.uop.ops import UOp
from tinygrad.uop.ops import UOp, Ops
class TestCall(unittest.TestCase):
def test_call_plus(self):
@@ -100,6 +100,42 @@ class TestCall(unittest.TestCase):
c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
class TestCallShape(unittest.TestCase):
def test_call_shape_int(self):
# fixed-shape function: shape passes through unchanged
@function
def f(x:Tensor) -> Tensor: return x * 2
self.assertEqual(f(Tensor.empty(4, 8)).shape, (4, 8))
def test_call_shape_param_substitution(self):
# symbolic shape dimension is substituted: inner PARAM replaced with the BIND arg
@function
def f(x:Tensor) -> Tensor: return x * 2
sz = UOp.variable("sz", 1, 8)
shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
# the PARAM should be gone, replaced with the BIND from the call arg
self.assertIsInstance(shape[0], UOp)
self.assertNotEqual(shape[0].op, Ops.PARAM)
self.assertEqual(shape[0], sz.bind(5))
def test_call_shape_expr_substitution(self):
# expression containing PARAMs in shape gets fully substituted
@function
def f(x:Tensor) -> Tensor: return x + 1
sz = UOp.variable("sz", 1, 10)
shape = f(Tensor.empty(10, 4)[:sz.bind(3)]).shape
self.assertIsInstance(shape[0], UOp)
self.assertNotEqual(shape[0].op, Ops.PARAM)
self.assertEqual(shape[1], 4)
def test_call_shape_no_param_passthrough(self):
# a non-PARAM UOp shape element passes through unchanged
@function
def f(x:Tensor) -> Tensor: return x * 3
sz = UOp.variable("sz", 1, 8)
shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
self.assertEqual(shape[0], sz.bind(5))
class TestCallSchedule(unittest.TestCase):
def test_reshape_precompile(self):
a = Tensor.empty(4, 8).realize()
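
For reference, a minimal sketch of the variable/bind pattern these shape tests exercise (standalone illustration, not part of the commit; assumes the tinygrad API as used above):

from tinygrad import Tensor, function
from tinygrad.uop.ops import UOp

@function
def double(x: Tensor) -> Tensor: return x * 2

sz = UOp.variable("sz", 1, 8)               # symbolic dim bounded to [1, 8]
out = double(Tensor.empty(8)[:sz.bind(5)])  # bind a concrete value for this call
# the returned shape carries the caller's BIND, not the callee's internal PARAM
print(out.shape)                            # (sz.bind(5),) -- resolves to 5 at realize time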

View File

@@ -1,29 +1,84 @@
import unittest
from unittest.mock import patch
from tinygrad import Tensor
from tinygrad import Tensor, UOp
from tinygrad.engine.schedule import schedule_cache
class TestTransformerGenerate(unittest.TestCase):
def test_start_pos_parameter_is_used(self):
"""Test that start_pos parameter is not ignored (regression test for always resetting to 0)."""
def test_kv_cache_reuse(self):
"""Test that generate reuses the KV cache when tokens extend the cached prefix."""
from tinygrad.apps.llm import Transformer
# Create a minimal transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
captured_inputs = []
def mock_call(self, tokens, start_pos):
captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.bind_val))
return Tensor([[42]]) # return a fake next token
captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
return Tensor([[42]])
with patch.object(Transformer, '__call__', mock_call):
# first conversation: prefill 5 tokens + 1 decode
tokens = [1, 2, 3, 4, 5]
gen = model.generate(tokens, start_pos=3)
next(gen) # get first token
gen = model.generate(tokens)
next(gen) # prefill
next(gen) # decode
# With start_pos=3, the initial tensor should only have tokens[3:] = [4, 5] (length 2)
# If the bug existed (start_pos always reset to 0), it would have all 5 tokens
self.assertEqual(captured_inputs[0][0][-1], 2) # shape should be (1, 2)
self.assertEqual(captured_inputs[0][1], 3) # start_pos should be 3, not 0
# second call extends the conversation — cached prefix should be reused
captured_inputs.clear()
tokens = [1, 2, 3, 4, 5, 42, 42, 10, 11, 12]
gen = model.generate(tokens)
next(gen)
# should only process tokens[7:] = [10, 11, 12] since first 7 are cached
toks_shape = captured_inputs[0][0][-1]
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
self.assertEqual(captured_inputs[0][1], 7)
def test_kv_cache_invalidation(self):
"""Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
captured_inputs = []
def mock_call(self, tokens, start_pos):
captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
return Tensor([[42]])
with patch.object(Transformer, '__call__', mock_call):
# first conversation
gen = model.generate([1, 2, 3, 4, 5])
next(gen)
# completely different prompt — KV cache should be invalidated
captured_inputs.clear()
gen = model.generate([10, 20, 30])
next(gen)
# should process all 3 tokens from start
toks_shape = captured_inputs[0][0][-1]
self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
self.assertEqual(captured_inputs[0][1], 0)
def test_two_prompts_schedule_cache(self):
"""Second prompt prefill should hit the schedule cache, not miss."""
from tinygrad.apps.llm import Transformer
model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
# first prompt: prefill + a few decode steps
ids = list(range(1, 6))
gen = model.generate(ids)
for _ in range(3): next(gen)
cache_size_after_first = len(schedule_cache)
# second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
ids += list(range(10, 15))
gen = model.generate(ids)
for _ in range(3): next(gen)
# the second prompt should reuse the same schedule cache entries, not create new ones
self.assertEqual(cache_size_after_first, len(schedule_cache),
f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
if __name__ == '__main__':
unittest.main()
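
The prefix matching these tests verify is the itertools.takewhile expression from generate (see the llm.py diff below); sketched standalone with a hypothetical helper name:

import itertools

def cached_prefix_len(tokens: list[int], cached: list[int]) -> int:
  # count leading tokens shared between the new prompt and the cached conversation
  return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, cached)))

cached_prefix_len([1, 2, 3, 4, 5, 42, 42, 10, 11, 12], [1, 2, 3, 4, 5, 42, 42])  # 7 -> only [10, 11, 12] runs
cached_prefix_len([10, 20, 30], [1, 2, 3, 4, 5, 42])                             # 0 -> full prefill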

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,7 +144,7 @@ class TransformerBlock:
# NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal = True
# TODO: this if statement should be removed and it shouldn't generate extra kernels
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
attn = self.attn_output(attn)
@@ -178,6 +178,7 @@ class Transformer:
self.output_norm = nn.RMSNorm(dim, norm_eps)
self.output = nn.Linear(dim, vocab_size, bias=False)
self.max_context = max_context
self._cached_tokens: list[int] = []
# JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
self.forward_jit = TinyJit(self.forward)
@@ -187,8 +188,7 @@ class Transformer:
# TODO: add temperature
return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 1))) -> tuple[Transformer, dict]:
@@ -226,15 +226,20 @@ class Transformer:
Tensor.realize(*params)
return model, kv
def generate(self, tokens:list[int], start_pos=0):
v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
t = Tensor([tokens[start_pos:]], dtype="int32")
def generate(self, tokens:list[int]):
v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
v_toks = UOp.variable("toks", 1, self.max_context)
# assign all input tokens once, then slice from start_pos for the model call
t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
# recompute start_pos from what's currently valid in the kv cache
start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
while len(tokens) < self.max_context:
t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
next_id = int(t.item())
tokens.append(next_id)
start_pos = len(tokens) - 1
yield next_id
sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
start_pos = len(tokens)
tokens.append(int(out.item()))
self._cached_tokens = tokens[:]
yield tokens[-1]
models = {
"llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
@@ -363,7 +368,7 @@ if __name__ == "__main__":
# do benchmark
if args.benchmark:
gen = model.generate(toks:=[bos_id or 0], 0)
gen = model.generate(toks:=[bos_id or 0])
for _ in range(args.benchmark):
GlobalCounters.reset()
with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
@@ -377,12 +382,11 @@ if __name__ == "__main__":
# interactive chat
ids: list[int] = [bos_id] if bos_id is not None else []
while 1:
start_pos = max(len(ids) - 1, 0)
try:
ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
except EOFError:
break
for next_id in model.generate(ids, start_pos):
for next_id in model.generate(ids):
sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
sys.stdout.flush()
if next_id == eos_id: break
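
The core of the rewrite: the token buffer has a fixed (1, max_context) shape, and every step slices it with bound variables, so the traced graph is step-invariant and the schedule cache hits across prefill and decode. A simplified sketch of the slicing pattern (assumes the same API as the diff; the model call is elided):

from tinygrad import Tensor, UOp

max_context = 32
v_start_pos = UOp.variable("start_pos", 0, max_context - 1)
v_toks = UOp.variable("toks", 1, max_context)
t = Tensor(list(range(max_context)), dtype="int32").reshape(1, max_context)

for start, n in [(0, 5), (5, 1), (6, 1)]:  # prefill 5 tokens, then two decode steps
  sp, nt = v_start_pos.bind(start), v_toks.bind(n)
  window = t[:, sp:sp+nt]  # same symbolic slice every step -> same kernels/schedule
  # the real loop calls self(window, sp) and writes the result back in place:
  # t[:, sp+nt:sp+nt+1] = out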

View File

@@ -1,15 +1,15 @@
import math
from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
from tinygrad.helpers import all_int, dedup, get_contraction
from tinygrad.helpers import dedup, get_contraction
from tinygrad.dtype import dtypes, AddrSpace, Invalid
from tinygrad.renderer import Renderer
def _dim_max(d:sint) -> int: return d if isinstance(d, int) else int(d.vmax)
def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
# TODO: symbolic shape
if not all_int(dims): return dims
while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
for i,m in enumerate(max_sizes):
if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
if i < (len(dims)-1) and _dim_max(dims[i]) * _dim_max(dims[i+1]) <= m:
dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
break
else: return None
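
_dim_max makes the grouping check work with symbolic dims by comparing their worst-case extent; a quick illustration (the symbolic value assumes tinygrad's vmax bound propagation):

sz = UOp.variable("sz", 1, 8)
_dim_max(4)   # 4: plain ints pass through
_dim_max(sz)  # 8: the variable's upper bound is what must fit under max_sizes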

View File

@@ -239,9 +239,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
return None
# passthrough ops
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
return self.src[0]._shape
case Ops.CALL:
inner_shape = self.src[0]._shape
if inner_shape is None: return None
# substitute internal PARAMs in the shape with corresponding args
return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
# TODO: disallow shape changing bitcast
case Ops.BITCAST:
ps = self.src[0]._shape
@@ -1415,6 +1421,7 @@ pm_lower_index_dtype = PatternMatcher([
def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
_substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
_remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
def gate_kernel_sink(x:UOp) -> bool:
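
In other words, _pm_resolve_params is a one-rule substitution: each PARAM node carries its positional index in .arg, and the rewrite context handed in from the CALL case is the argument sources self.src[1:], so ctx[p.arg] fetches the corresponding caller-side UOp. Per the new TestCallShape tests, a shape dim that was a PARAM inside the function comes out as the caller's sz.bind(5), with no PARAM left behind.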

View File

@@ -74,8 +74,8 @@ movement_ops = PatternMatcher([
(UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
(UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
# AFTER on Movement Op
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
# AFTER on Movement Op or ASSIGN
(UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True),
])
_tensor_spec = PatternMatcher([
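
Presumably this spec relaxation (allowing ASSIGN as an AFTER source) is what lets the new generate write decoded tokens back into the fixed buffer in place (t[:, sp+nt:sp+nt+1] = out) without tripping the movement-op spec.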