mirror of https://github.com/tinygrad/tinygrad.git
fully symbolic llm (#15097)
* work
* llm symbolic (almost)
* work
* revert that
* llm sym
* works
* cleanups
* cache tokens with the kv cache
* cleanups
* cleanups
@@ -2,7 +2,7 @@ import unittest
 import numpy as np
 from tinygrad import Tensor, function
 from tinygrad.dtype import dtypes
-from tinygrad.uop.ops import UOp
+from tinygrad.uop.ops import UOp, Ops
 
 class TestCall(unittest.TestCase):
   def test_call_plus(self):
@@ -100,6 +100,42 @@ class TestCall(unittest.TestCase):
     c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
     np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
 
+class TestCallShape(unittest.TestCase):
+  def test_call_shape_int(self):
+    # fixed-shape function: shape passes through unchanged
+    @function
+    def f(x:Tensor) -> Tensor: return x * 2
+    self.assertEqual(f(Tensor.empty(4, 8)).shape, (4, 8))
+
+  def test_call_shape_param_substitution(self):
+    # symbolic shape dimension is substituted: inner PARAM replaced with the BIND arg
+    @function
+    def f(x:Tensor) -> Tensor: return x * 2
+    sz = UOp.variable("sz", 1, 8)
+    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
+    # the PARAM should be gone, replaced with the BIND from the call arg
+    self.assertIsInstance(shape[0], UOp)
+    self.assertNotEqual(shape[0].op, Ops.PARAM)
+    self.assertEqual(shape[0], sz.bind(5))
+
+  def test_call_shape_expr_substitution(self):
+    # expression containing PARAMs in shape gets fully substituted
+    @function
+    def f(x:Tensor) -> Tensor: return x + 1
+    sz = UOp.variable("sz", 1, 10)
+    shape = f(Tensor.empty(10, 4)[:sz.bind(3)]).shape
+    self.assertIsInstance(shape[0], UOp)
+    self.assertNotEqual(shape[0].op, Ops.PARAM)
+    self.assertEqual(shape[1], 4)
+
+  def test_call_shape_no_param_passthrough(self):
+    # a non-PARAM UOp shape element passes through unchanged
+    @function
+    def f(x:Tensor) -> Tensor: return x * 3
+    sz = UOp.variable("sz", 1, 8)
+    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
+    self.assertEqual(shape[0], sz.bind(5))
+
 class TestCallSchedule(unittest.TestCase):
   def test_reshape_precompile(self):
     a = Tensor.empty(4, 8).realize()
@@ -1,29 +1,84 @@
 import unittest
 from unittest.mock import patch
-from tinygrad import Tensor
+from tinygrad import Tensor, UOp
+from tinygrad.engine.schedule import schedule_cache
 
 class TestTransformerGenerate(unittest.TestCase):
-  def test_start_pos_parameter_is_used(self):
-    """Test that start_pos parameter is not ignored (regression test for always resetting to 0)."""
+  def test_kv_cache_reuse(self):
+    """Test that generate reuses the KV cache when tokens extend the cached prefix."""
     from tinygrad.apps.llm import Transformer
     # Create a minimal transformer
     model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                         norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
 
     captured_inputs = []
     def mock_call(self, tokens, start_pos):
-      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.bind_val))
-      return Tensor([[42]]) # return a fake next token
+      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
+      return Tensor([[42]])
 
     with patch.object(Transformer, '__call__', mock_call):
+      # first conversation: prefill 5 tokens + 1 decode
       tokens = [1, 2, 3, 4, 5]
-      gen = model.generate(tokens, start_pos=3)
-      next(gen)  # get first token
+      gen = model.generate(tokens)
+      next(gen)  # prefill
+      next(gen)  # decode
 
-      # With start_pos=3, the initial tensor should only have tokens[3:] = [4, 5] (length 2)
-      # If the bug existed (start_pos always reset to 0), it would have all 5 tokens
-      self.assertEqual(captured_inputs[0][0][-1], 2)  # shape should be (1, 2)
-      self.assertEqual(captured_inputs[0][1], 3)  # start_pos should be 3, not 0
+      # second call extends the conversation: the cached prefix should be reused
+      captured_inputs.clear()
+      tokens = [1, 2, 3, 4, 5, 42, 42, 10, 11, 12]
+      gen = model.generate(tokens)
+      next(gen)
+
+      # should only process tokens[7:] = [10, 11, 12] since first 7 are cached
+      toks_shape = captured_inputs[0][0][-1]
+      self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
+      self.assertEqual(captured_inputs[0][1], 7)
+
+  def test_kv_cache_invalidation(self):
+    """Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
+    from tinygrad.apps.llm import Transformer
+    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
+                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
+
+    captured_inputs = []
+    def mock_call(self, tokens, start_pos):
+      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
+      return Tensor([[42]])
+
+    with patch.object(Transformer, '__call__', mock_call):
+      # first conversation
+      gen = model.generate([1, 2, 3, 4, 5])
+      next(gen)
+
+      # completely different prompt: the KV cache should be invalidated
+      captured_inputs.clear()
+      gen = model.generate([10, 20, 30])
+      next(gen)
+
+      # should process all 3 tokens from start
+      toks_shape = captured_inputs[0][0][-1]
+      self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
+      self.assertEqual(captured_inputs[0][1], 0)
+
+  def test_two_prompts_schedule_cache(self):
+    """Second prompt prefill should hit the schedule cache, not miss."""
+    from tinygrad.apps.llm import Transformer
+    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
+                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
+
+    # first prompt: prefill + a few decode steps
+    ids = list(range(1, 6))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+    cache_size_after_first = len(schedule_cache)
+
+    # second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
+    ids += list(range(10, 15))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+
+    # the second prompt should reuse the same schedule cache entries, not create new ones
+    self.assertEqual(cache_size_after_first, len(schedule_cache),
+                     f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
 
 if __name__ == '__main__':
   unittest.main()
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
+import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,7 +144,7 @@ class TransformerBlock:
 
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_casual = True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1)  # back to (B,T,D)
     attn = self.attn_output(attn)
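As a standalone illustration of the causal_lower_right mask built above (not part of the diff): with start_pos positions already in the KV cache and T new tokens, triu(start_pos+1) leaves -inf only where a query would attend past its own position.

from tinygrad import Tensor

T, start_pos = 2, 3
mask = Tensor.full((1, 1, T, start_pos+T), float("-inf")).triu(start_pos+1)
print(mask.numpy()[0, 0])
# the query at position 3 (row 0) sees keys 0..3, the one at position 4 (row 1) sees keys 0..4:
# [[  0.   0.   0.   0. -inf]
#  [  0.   0.   0.   0.   0.]]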
@@ -178,6 +178,7 @@ class Transformer:
     self.output_norm = nn.RMSNorm(dim, norm_eps)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
+    self._cached_tokens: list[int] = []
     # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
     self.forward_jit = TinyJit(self.forward)
@@ -187,8 +188,7 @@ class Transformer:
     # TODO: add temperature
     return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
 
-  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
-    return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
+  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
 
   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 1))) -> tuple[Transformer, dict]:
@@ -226,15 +226,20 @@ class Transformer:
     Tensor.realize(*params)
     return model, kv
 
-  def generate(self, tokens:list[int], start_pos=0):
-    v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
-    t = Tensor([tokens[start_pos:]], dtype="int32")
+  def generate(self, tokens:list[int]):
+    v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
+    v_toks = UOp.variable("toks", 1, self.max_context)
+    # assign all input tokens once, then slice from start_pos for the model call
+    t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
+    # recompute start_pos from what's currently valid in the kv cache
+    start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
     while len(tokens) < self.max_context:
-      t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
-      next_id = int(t.item())
-      tokens.append(next_id)
-      start_pos = len(tokens) - 1
-      yield next_id
+      sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
+      t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
+      start_pos = len(tokens)
+      tokens.append(int(out.item()))
+      self._cached_tokens = tokens[:]
+      yield tokens[-1]
 
 models = {
   "llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
@@ -363,7 +368,7 @@ if __name__ == "__main__":
 
   # do benchmark
   if args.benchmark:
-    gen = model.generate(toks:=[bos_id or 0], 0)
+    gen = model.generate(toks:=[bos_id or 0])
     for _ in range(args.benchmark):
       GlobalCounters.reset()
       with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
@@ -377,12 +382,11 @@ if __name__ == "__main__":
   # interactive chat
   ids: list[int] = [bos_id] if bos_id is not None else []
   while 1:
-    start_pos = max(len(ids) - 1, 0)
     try:
       ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
     except EOFError:
       break
-    for next_id in model.generate(ids, start_pos):
+    for next_id in model.generate(ids):
       sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
       sys.stdout.flush()
       if next_id == eos_id: break
@@ -1,15 +1,15 @@
 import math
 from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
-from tinygrad.helpers import all_int, dedup, get_contraction
+from tinygrad.helpers import dedup, get_contraction
 from tinygrad.dtype import dtypes, AddrSpace, Invalid
 from tinygrad.renderer import Renderer
 
+def _dim_max(d:sint) -> int: return d if isinstance(d, int) else int(d.vmax)
+
 def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
-  # TODO: symbolic shape
-  if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
     for i,m in enumerate(max_sizes):
-      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
+      if i < (len(dims)-1) and _dim_max(dims[i]) * _dim_max(dims[i+1]) <= m:
         dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
         break
     else: return None
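A worked instance of the new _dim_max check (values are illustrative, not from the diff): a symbolic dim is judged by its upper bound, so a Variable in [1, 8] next to a fixed dim of 4 may be merged under a launch limit of 32, because the product can never exceed 8 * 4 = 32.

from tinygrad import UOp

sz = UOp.variable("sz", 1, 8)          # symbolic dim, 1 <= sz <= 8
def dim_max(d) -> int: return d if isinstance(d, int) else int(d.vmax)
assert dim_max(sz) * dim_max(4) <= 32  # safe to group (sz*4) under a max size of 32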
@@ -239,9 +239,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
        return None
 
      # passthrough ops
-      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
        return self.src[0]._shape
 
+      case Ops.CALL:
+        inner_shape = self.src[0]._shape
+        if inner_shape is None: return None
+        # substitute internal PARAMs in the shape with corresponding args
+        return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
+
      # TODO: disallow shape changing bitcast
      case Ops.BITCAST:
        ps = self.src[0]._shape
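The CALL case above resolves a traced function's symbolic shape against the caller's arguments. A miniature of the idea on a plain expression tree (Param and resolve are hypothetical stand-ins, not tinygrad API): every PARAM carries an index (its arg), and resolution swaps it for the matching call source.

from dataclasses import dataclass

@dataclass(frozen=True)
class Param:
  arg: int  # index into the call's argument list

def resolve(expr, args):
  # recursively replace each Param with the argument it points at
  if isinstance(expr, Param): return args[expr.arg]
  if isinstance(expr, tuple): return tuple(resolve(e, args) for e in expr)
  return expr

# a function traced with shape (Param(0), 4), called with the bound value 5
assert resolve((Param(0), 4), [5]) == (5, 4)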
@@ -1415,6 +1421,7 @@ pm_lower_index_dtype = PatternMatcher([
 def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
 
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
+_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
 _remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 def gate_kernel_sink(x:UOp) -> bool:
@@ -74,8 +74,8 @@ movement_ops = PatternMatcher([
   (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
 
-  # AFTER on Movement Op
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
+  # AFTER on Movement Op or ASSIGN
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True),
 ])
 
 _tensor_spec = PatternMatcher([