diff --git a/test/unit/test_call.py b/test/unit/test_call.py
index 4fee7fd6eb..a38e69c90e 100644
--- a/test/unit/test_call.py
+++ b/test/unit/test_call.py
@@ -2,7 +2,7 @@ import unittest
 import numpy as np
 from tinygrad import Tensor, function
 from tinygrad.dtype import dtypes
-from tinygrad.uop.ops import UOp
+from tinygrad.uop.ops import UOp, Ops
 
 class TestCall(unittest.TestCase):
   def test_call_plus(self):
@@ -100,6 +100,42 @@ class TestCall(unittest.TestCase):
     c = Tensor.call(a, b, fxn=a.as_param(0) + b.as_param(1))
     np.testing.assert_equal(c.numpy(), 2 * np.ones((10, 10)))
 
+class TestCallShape(unittest.TestCase):
+  def test_call_shape_int(self):
+    # fixed-shape function: shape passes through unchanged
+    @function
+    def f(x:Tensor) -> Tensor: return x * 2
+    self.assertEqual(f(Tensor.empty(4, 8)).shape, (4, 8))
+
+  def test_call_shape_param_substitution(self):
+    # symbolic shape dimension is substituted: inner PARAM replaced with the BIND arg
+    @function
+    def f(x:Tensor) -> Tensor: return x * 2
+    sz = UOp.variable("sz", 1, 8)
+    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
+    # the PARAM should be gone, replaced with the BIND from the call arg
+    self.assertIsInstance(shape[0], UOp)
+    self.assertNotEqual(shape[0].op, Ops.PARAM)
+    self.assertEqual(shape[0], sz.bind(5))
+
+  def test_call_shape_expr_substitution(self):
+    # expression containing PARAMs in shape gets fully substituted
+    @function
+    def f(x:Tensor) -> Tensor: return x + 1
+    sz = UOp.variable("sz", 1, 10)
+    shape = f(Tensor.empty(10, 4)[:sz.bind(3)]).shape
+    self.assertIsInstance(shape[0], UOp)
+    self.assertNotEqual(shape[0].op, Ops.PARAM)
+    self.assertEqual(shape[1], 4)
+
+  def test_call_shape_no_param_passthrough(self):
+    # a non-PARAM UOp shape element passes through unchanged
+    @function
+    def f(x:Tensor) -> Tensor: return x * 3
+    sz = UOp.variable("sz", 1, 8)
+    shape = f(Tensor.empty(8)[:sz.bind(5)]).shape
+    self.assertEqual(shape[0], sz.bind(5))
+
 class TestCallSchedule(unittest.TestCase):
   def test_reshape_precompile(self):
     a = Tensor.empty(4, 8).realize()
diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py
index bf10ab6e7f..c43fa86f77 100644
--- a/test/unit/test_llm_server.py
+++ b/test/unit/test_llm_server.py
@@ -1,29 +1,84 @@
 import unittest
 from unittest.mock import patch
-from tinygrad import Tensor
+from tinygrad import Tensor, UOp
+from tinygrad.engine.schedule import schedule_cache
 
 class TestTransformerGenerate(unittest.TestCase):
-  def test_start_pos_parameter_is_used(self):
-    """Test that start_pos parameter is not ignored (regression test for always resetting to 0)."""
+  def test_kv_cache_reuse(self):
+    """Test that generate reuses the KV cache when tokens extend the cached prefix."""
     from tinygrad.apps.llm import Transformer
-    # Create a minimal transformer
     model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                         norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
 
     captured_inputs = []
     def mock_call(self, tokens, start_pos):
-      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.bind_val))
-      return Tensor([[42]])  # return a fake next token
+      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
+      return Tensor([[42]])
 
     with patch.object(Transformer, '__call__', mock_call):
+      # first conversation: prefill 5 tokens + 1 decode
       tokens = [1, 2, 3, 4, 5]
-      gen = model.generate(tokens, start_pos=3)
-      next(gen)  # get first token
+      gen = model.generate(tokens)
+      next(gen)  # prefill
+      next(gen)  # decode
 
-      # With start_pos=3, the initial tensor should only have tokens[3:] = [4, 5] (length 2)
-      # If the bug existed (start_pos always reset to 0), it would have all 5 tokens
-      self.assertEqual(captured_inputs[0][0][-1], 2)  # shape should be (1, 2)
-      self.assertEqual(captured_inputs[0][1], 3)  # start_pos should be 3, not 0
+      # second call extends the conversation: cached prefix should be reused
+      captured_inputs.clear()
+      tokens = [1, 2, 3, 4, 5, 42, 42, 10, 11, 12]
+      gen = model.generate(tokens)
+      next(gen)
+
+      # should only process tokens[7:] = [10, 11, 12] since first 7 are cached
+      toks_shape = captured_inputs[0][0][-1]
+      self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
+      self.assertEqual(captured_inputs[0][1], 7)
+
+  def test_kv_cache_invalidation(self):
+    """Test that generate invalidates the KV cache when tokens diverge from the cached prefix."""
+    from tinygrad.apps.llm import Transformer
+    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
+                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=32)
+
+    captured_inputs = []
+    def mock_call(self, tokens, start_pos):
+      captured_inputs.append((tokens.shape, start_pos if isinstance(start_pos, int) else start_pos.val))
+      return Tensor([[42]])
+
+    with patch.object(Transformer, '__call__', mock_call):
+      # first conversation
+      gen = model.generate([1, 2, 3, 4, 5])
+      next(gen)
+
+      # completely different prompt: KV cache should be invalidated
+      captured_inputs.clear()
+      gen = model.generate([10, 20, 30])
+      next(gen)
+
+      # should process all 3 tokens from start
+      toks_shape = captured_inputs[0][0][-1]
+      self.assertEqual(toks_shape.val if isinstance(toks_shape, UOp) else toks_shape, 3)
+      self.assertEqual(captured_inputs[0][1], 0)
+
+  def test_two_prompts_schedule_cache(self):
+    """Second prompt prefill should hit the schedule cache, not miss."""
+    from tinygrad.apps.llm import Transformer
+    model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
+                        norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)
+
+    # first prompt: prefill + a few decode steps
+    ids = list(range(1, 6))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+    cache_size_after_first = len(schedule_cache)
+
+    # second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
+    ids += list(range(10, 15))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+
+    # the second prompt should reuse the same schedule cache entries, not create new ones
+    self.assertEqual(cache_size_after_first, len(schedule_cache),
+                     f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
 
 if __name__ == '__main__':
   unittest.main()
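The two KV-cache tests above pin down a single rule: reuse exactly the longest common prefix between the incoming tokens and the previously cached ones, and recompute everything from the first mismatch. A minimal standalone sketch of that rule, mirroring the itertools.takewhile expression in generate() below (the helper name common_prefix_len is ours, not part of the patch):

    import itertools

    def common_prefix_len(tokens: list[int], cached: list[int]) -> int:
      # count leading positions where the new prompt agrees with the cache
      return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, cached)))

    # extending the conversation: the first 7 tokens are cached, only 3 are recomputed
    assert common_prefix_len([1, 2, 3, 4, 5, 42, 42, 10, 11, 12], [1, 2, 3, 4, 5, 42, 42]) == 7
    # diverging at position 0: the cache is effectively invalidated
    assert common_prefix_len([10, 20, 30], [1, 2, 3, 4, 5, 42, 42]) == 0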
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index cf773546d6..897e846de1 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
+import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,7 +144,7 @@ class TransformerBlock:
 
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal=True
     # TODO: this if statement should be removed and it shouldn't generate extra kernels
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if T > 1 else None
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True)  # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1)  # back to (B,T,D)
     attn = self.attn_output(attn)
@@ -178,6 +178,7 @@ class Transformer:
     self.output_norm = nn.RMSNorm(dim, norm_eps)
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
+    self._cached_tokens: list[int] = []
 
     # JIT is used if T=1 and start_pos is a UOp. TODO: make this not needed by including T in the JIT and making start_pos always a UOp
     self.forward_jit = TinyJit(self.forward)
@@ -187,8 +188,7 @@ class Transformer:
     # TODO: add temperature
     return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)
 
-  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
-    return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
+  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
 
   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 1))) -> tuple[Transformer, dict]:
@@ -226,15 +226,20 @@ class Transformer:
     Tensor.realize(*params)
     return model, kv
 
-  def generate(self, tokens:list[int], start_pos=0):
-    v_start_pos = UOp.variable("start_pos", 1, self.max_context-1)
-    t = Tensor([tokens[start_pos:]], dtype="int32")
+  def generate(self, tokens:list[int]):
+    v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
+    v_toks = UOp.variable("toks", 1, self.max_context)
+    # assign all input tokens once, then slice from start_pos for the model call
+    t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
+    # recompute start_pos from what's currently valid in the kv cache
+    start_pos = sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
     while len(tokens) < self.max_context:
-      t = self(t, v_start_pos.bind(start_pos) if getenv("SYM", 1) and start_pos != 0 and t.shape[-1] == 1 else start_pos)
-      next_id = int(t.item())
-      tokens.append(next_id)
-      start_pos = len(tokens) - 1
-      yield next_id
+      sp, nt = v_start_pos.bind(start_pos), v_toks.bind(len(tokens) - start_pos)
+      t[:, sp+nt:sp+nt+1] = out = self(t[:, sp:sp+nt], sp)
+      start_pos = len(tokens)
+      tokens.append(int(out.item()))
+      self._cached_tokens = tokens[:]
+      yield tokens[-1]
 
 models = {
   "llama3.2:1b": "https://huggingface.co/bartowski/Llama-3.2-1B-Instruct-GGUF/resolve/main/Llama-3.2-1B-Instruct-Q6_K.gguf",
@@ -363,7 +368,7 @@ if __name__ == "__main__":
 
   # do benchmark
   if args.benchmark:
-    gen = model.generate(toks:=[bos_id or 0], 0)
+    gen = model.generate(toks:=[bos_id or 0])
     for _ in range(args.benchmark):
       GlobalCounters.reset()
       with Timing(on_exit=lambda x: f", {1e9/x:6.2f} tok/s, {GlobalCounters.global_mem/x:7.2f} GB/s,"
@@ -377,12 +382,11 @@ if __name__ == "__main__":
 
   # interactive chat
   ids: list[int] = [bos_id] if bos_id is not None else []
   while 1:
-    start_pos = max(len(ids) - 1, 0)
     try:
       ids += tok.role("user") + tok.encode(input('>>> ')) + tok.end_turn(eos_id) + tok.role("assistant")
     except EOFError: break
-    for next_id in model.generate(ids, start_pos):
+    for next_id in model.generate(ids):
      sys.stdout.write(tok.decode([next_id]) if next_id != eos_id else "\n\n")
      sys.stdout.flush()
      if next_id == eos_id: break
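The rewritten generate() above drives every step through two bound symbolic variables (start_pos and toks), so a 5-token prefill and a 10-token prefill share one symbolic shape and, in principle, one set of compiled kernels; that is what test_two_prompts_schedule_cache checks. A small illustration of the variable/bind mechanism, using the same API as the tests (the concrete bounds and values here are ours):

    from tinygrad import Tensor, UOp

    sz = UOp.variable("sz", 1, 8)       # symbolic dimension with bounds [1, 8]
    a = Tensor.empty(8)[:sz.bind(5)]    # symbolic shape, currently bound to 5
    b = Tensor.empty(8)[:sz.bind(3)]    # same symbolic dimension, different bound value
    assert isinstance(a.shape[0], UOp) and a.shape[0] == sz.bind(5)
    # a and b differ only in the bound value, so work scheduled for one can be reused for the other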
diff --git a/tinygrad/codegen/gpudims.py b/tinygrad/codegen/gpudims.py
index 11fa8b7c72..91677ef001 100644
--- a/tinygrad/codegen/gpudims.py
+++ b/tinygrad/codegen/gpudims.py
@@ -1,15 +1,15 @@
 import math
 from tinygrad.uop.ops import UOp, Ops, sint, PatternMatcher, UPat, KernelInfo, ssimplify, AxisType, sint_to_uop
-from tinygrad.helpers import all_int, dedup, get_contraction
+from tinygrad.helpers import dedup, get_contraction
 from tinygrad.dtype import dtypes, AddrSpace, Invalid
 from tinygrad.renderer import Renderer
 
+def _dim_max(d:sint) -> int: return d if isinstance(d, int) else int(d.vmax)
+
 def _group_dims(dims:tuple[sint, ...], max_sizes:tuple[int, ...]):
-  # TODO: symbolic shape
-  if not all_int(dims): return dims
   while len(dims) > len(max_sizes) or any(d > m for d,m in zip(dims, max_sizes)):
     for i,m in enumerate(max_sizes):
-      if i < (len(dims)-1) and dims[i] * dims[i+1] <= m:
+      if i < (len(dims)-1) and _dim_max(dims[i]) * _dim_max(dims[i+1]) <= m:
         dims = dims[:i] + (dims[i]*dims[i+1],) + dims[i+2:]
         break
     else: return None
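The gpudims change drops the all_int bailout: a symbolic dimension now takes part in GPU-dimension grouping via its upper bound (vmax), so two adjacent dims are merged only when their worst-case product still fits the device limit. A rough restatement under an assumed limit (the max_size value is hypothetical):

    from tinygrad import UOp

    def dim_max(d) -> int:
      # an int dim is its own maximum; a symbolic dim is judged by its upper bound
      return d if isinstance(d, int) else int(d.vmax)

    toks = UOp.variable("toks", 1, 32)
    max_size = 64                       # hypothetical per-dimension launch limit
    # merging (toks, 4) into one dimension would need 32 * 4 <= 64, so here it stays split
    assert dim_max(toks) == 32 and dim_max(toks) * dim_max(4) > max_size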
diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py
index 31b19f4262..c471ac66b1 100644
--- a/tinygrad/uop/ops.py
+++ b/tinygrad/uop/ops.py
@@ -239,9 +239,15 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
         return None
 
       # passthrough ops
-      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END | Ops.CALL:
+      case Ops.REDUCE | Ops.MSTACK | Ops.MSELECT | Ops.DETACH | Ops.CONTIGUOUS | Ops.CONTIGUOUS_BACKWARD | Ops.AFTER | Ops.END:
         return self.src[0]._shape
 
+      case Ops.CALL:
+        inner_shape = self.src[0]._shape
+        if inner_shape is None: return None
+        # substitute internal PARAMs in the shape with corresponding args
+        return tuple(graph_rewrite(s, _pm_resolve_params, self.src[1:], walk=True) if isinstance(s, UOp) else s for s in inner_shape)
+
       # TODO: disallow shape changing bitcast
       case Ops.BITCAST:
         ps = self.src[0]._shape
@@ -1415,6 +1421,7 @@ pm_lower_index_dtype = PatternMatcher([
 
 def _index_to_concrete_int(u:UOp) -> UOp: return graph_rewrite(u.sink(), pm_lower_index_dtype).src[0]
 _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))])
+_pm_resolve_params = PatternMatcher([(UPat(Ops.PARAM, name="p"), lambda ctx,p: ctx[p.arg])])
 _remove_all_tags = PatternMatcher([(UPat(GroupOp.All, name="x"), lambda x: x.replace(tag=None) if x.tag is not None else None)])
 
 def gate_kernel_sink(x:UOp) -> bool:
diff --git a/tinygrad/uop/spec.py b/tinygrad/uop/spec.py
index c7f18f0794..971b81763e 100644
--- a/tinygrad/uop/spec.py
+++ b/tinygrad/uop/spec.py
@@ -74,8 +74,8 @@ movement_ops = PatternMatcher([
   (UPat((Ops.VECTORIZE, Ops.VCONST), dtype=dtypes.index), lambda: True),
   (UPat({Ops.ADD, Ops.MUL, Ops.IDIV}, dtype=dtypes.index), lambda: True),
 
-  # AFTER on Movement Op
-  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS})),), allow_any_len=True), lambda: True),
+  # AFTER on Movement Op or ASSIGN
+  (UPat(Ops.AFTER, src=(UPat(GroupOp.Movement.union({Ops.MULTI, Ops.CONTIGUOUS, Ops.ASSIGN})),), allow_any_len=True), lambda: True),
 ])
 
 _tensor_spec = PatternMatcher([
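The Ops.CALL shape rule in ops.py above is what the new TestCallShape tests observe from the outside: inside a @function body a symbolic dimension appears as an Ops.PARAM placeholder, and resolving the CALL's shape rewrites each PARAM to the caller's actual argument. A sketch of the observable behavior, mirroring test_call_shape_param_substitution (the function name double is ours):

    from tinygrad import Tensor, UOp, function
    from tinygrad.uop.ops import Ops

    @function
    def double(x: Tensor) -> Tensor: return x * 2

    sz = UOp.variable("sz", 1, 8)
    out = double(Tensor.empty(8)[:sz.bind(5)])
    # the function-internal PARAM is gone from the result shape; the caller's BIND took its place
    assert out.shape[0].op is not Ops.PARAM
    assert out.shape[0] == sz.bind(5)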