LLM speedup with two jits, prefill/rollout (#15153)
* START_TIME
* print cleanup
* fix tests
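The idea in brief: instead of one TinyJit shared by prefill (many tokens at once) and rollout (one token per decode step), the Transformer now holds two TinyJit instances, so each input-shape class keeps its own stable capture instead of invalidating the other's. A minimal sketch of the dispatch pattern (TwoJitWrapper and its names are illustrative, not tinygrad internals):

from tinygrad import Tensor, TinyJit

class TwoJitWrapper:  # hypothetical wrapper, for illustration only
  def __init__(self, fn):
    self.prefill_jit, self.rollout_jit = TinyJit(fn), TinyJit(fn)
  def __call__(self, x:Tensor, start_pos:int):
    # many tokens -> prefill jit; exactly one token -> rollout jit
    return (self.prefill_jit if x.shape[1] != 1 else self.rollout_jit)(x, start_pos)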
@@ -60,25 +60,28 @@ class TestTransformerGenerate(unittest.TestCase):
     self.assertEqual(captured_inputs[0][1], 0)

   def test_two_prompts_schedule_cache(self):
-    """Second prompt prefill should hit the schedule cache, not miss."""
+    """Third prompt should hit the schedule cache, not miss (first two warm up both jits: prefill + decode)."""
     from tinygrad.apps.llm import Transformer
     model = Transformer(num_blocks=1, dim=64, hidden_dim=128, n_heads=2, n_kv_heads=2,
                         norm_eps=1e-5, vocab_size=100, head_dim=32, rope_theta=10000.0, max_context=64)

-    # first prompt: prefill + a few decode steps
+    # first two prompts warm up both jits (prefill + decode)
     ids = list(range(1, 6))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
-    cache_size_after_first = len(schedule_cache)

     # second prompt: simulates multi-turn chat (KV cache prefix is automatically reused)
     ids += list(range(10, 15))
     gen = model.generate(ids)
     for _ in range(3): next(gen)
+    cache_size_after_warmup = len(schedule_cache)

-    # the second prompt should reuse the same schedule cache entries, not create new ones
-    self.assertEqual(cache_size_after_first, len(schedule_cache),
-                     f"second prompt added {len(schedule_cache) - cache_size_after_first} new schedule cache entries (expected 0)")
+    # third prompt should reuse the same schedule cache entries, not create new ones
+    ids += list(range(20, 25))
+    gen = model.generate(ids)
+    for _ in range(3): next(gen)
+
+    self.assertEqual(cache_size_after_warmup, len(schedule_cache),
+                     f"third prompt added {len(schedule_cache) - cache_size_after_warmup} new schedule cache entries (expected 0)")

 if __name__ == '__main__':
   unittest.main()
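A note on the test structure: TinyJit runs eagerly on its first call and only captures on the second, so each of the two jits needs two invocations before it replays. That is why the first two prompts are treated as warmup and only the third is asserted against the schedule cache. A tiny demonstration of that warmup behavior (the function `double` is just an example):

from tinygrad import Tensor, TinyJit

@TinyJit
def double(x:Tensor) -> Tensor: return (x * 2).realize()

for i in range(3):
  double(Tensor([float(i)]).realize())  # call 1 runs eagerly, call 2 captures, call 3 replays the capture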
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import sys, argparse, typing, re, unicodedata, json, uuid, time, functools, itertools
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv, function
+from tinygrad.uop.ops import resolve
 from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored, Context
 from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
@@ -144,7 +145,7 @@ class TransformerBlock:
     # NOTE: this mask is causal_lower_right, not the causal_upper_left generated by is_causal=True
+    # TODO: this if statement should be removed and it shouldn't generate extra kernels
-    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1)
+    mask = Tensor.full((1, 1, T, start_pos+T), float("-inf"), dtype=x.dtype, device=x.device).triu(start_pos+1) if resolve(T != 1) else None
     attn = q.scaled_dot_product_attention(k, v, attn_mask=mask, enable_gqa=True) # (B,H,T,Hd)
     attn = attn.transpose(1, 2).reshape(B, T, -1) # back to (B,T,D)
     attn = self.attn_output(attn)
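For a single decode step (T == 1) the causal mask is all zeros anyway: the one query row may attend to every cached position, and triu(start_pos+1) zeroes out the entire (1, start_pos+1) row, so skipping the mask avoids building a tensor that contributes nothing. A quick check of that claim, with numpy standing in for tinygrad:

import numpy as np

start_pos, T = 5, 1
mask = np.triu(np.full((T, start_pos + T), -np.inf), k=start_pos + 1)
print(mask)  # [[0. 0. 0. 0. 0. 0.]] -- no -inf entries, the mask is a no-op when T == 1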
@@ -179,7 +180,9 @@ class Transformer:
     self.output = nn.Linear(dim, vocab_size, bias=False)
     self.max_context = max_context
     self._cached_tokens: list[int] = []
-    self.forward_jit = TinyJit(self.forward)
+    # we specialize the JIT for prefill and rollout
+    self.prefill_jit = TinyJit(self.forward)
+    self.rollout_jit = TinyJit(self.forward)

   def forward(self, tokens:Tensor, start_pos:int|UOp) -> Tensor:
     x = self.token_embd(tokens) # (B, T, D)
@@ -187,7 +190,8 @@ class Transformer:
     # TODO: add temperature
     return self.output(self.output_norm(x))[:, -1, :].softmax(-1, dtype="float").argmax(-1, keepdim=True)

-  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor: return self.forward_jit(tokens, start_pos)
+  def __call__(self, tokens:Tensor, start_pos:int|UOp=0) -> Tensor:
+    return (self.prefill_jit if resolve(tokens.shape[1] != 1) else self.rollout_jit)(tokens, start_pos)

   @staticmethod
   def from_gguf(gguf:Tensor, max_context:int|None=None, realize=bool(getenv("REALIZE", 0))) -> tuple[Transformer, dict]:
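Because the token count is a bound symbolic variable inside generate, tokens.shape[1] can be a UOp rather than a plain int; resolve collapses the symbolic comparison into a concrete bool so the dispatch can pick a jit. Roughly, a sketch using the same API the diff imports:

from tinygrad import UOp
from tinygrad.uop.ops import resolve

v_toks = UOp.variable("toks", 1, 32).bind(8)  # symbolic token count, bound to 8
print(resolve(v_toks != 1))                   # True -> this call would route to the prefill jit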
@@ -226,7 +230,7 @@ class Transformer:
     return model, kv

   def get_start_pos(self, tokens:list[int]):
-    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens, self._cached_tokens)))
+    return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(tokens[:-1], self._cached_tokens)))

   def generate(self, tokens:list[int], chunk_size:int=32):
     v_start_pos = UOp.variable("start_pos", 0, self.max_context-1)
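The tokens[:-1] change keeps the model from being starved on a fully cached prompt: in multi-turn chat the entire new prompt can already be a prefix of _cached_tokens, and matching all of it would leave zero tokens to feed the forward pass. Excluding the final token guarantees at least one token is always re-run to produce the next prediction. For example (common_prefix mirrors the takewhile expression above):

import itertools

def common_prefix(a, b): return sum(1 for _ in itertools.takewhile(lambda ab: ab[0] == ab[1], zip(a, b)))

cached = [1, 2, 3, 4]            # tokens already in the kv cache
prompt = [1, 2, 3, 4]            # resent prompt, fully cached
print(common_prefix(prompt, cached))       # 4 -> start_pos == len(tokens), nothing left to run
print(common_prefix(prompt[:-1], cached))  # 3 -> the last token is still fed through the model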
@@ -235,15 +239,13 @@ class Transformer:
     t = Tensor(tokens + [0] * (self.max_context - len(tokens)), dtype="int32").reshape(1, self.max_context)
     # recompute start_pos from what's currently valid in the kv cache
     start_pos = self.get_start_pos(tokens)
+    out = None
     while len(tokens) < self.max_context:
       sp, nt = v_start_pos.bind(start_pos), v_toks.bind(min(chunk_size, len(tokens) - start_pos))
-      out = self(t[:, sp:sp+nt], sp)
+      out = self(t[:, sp:sp+nt] if out is None else out, sp).realize()
       start_pos += nt.val
-      # chunked prefill: keep processing until all prompt tokens are consumed
-      if start_pos < len(tokens):
-        out.realize()
-        continue
+      t[:, sp+nt:sp+nt+1] = out
+      if start_pos < len(tokens): continue
       tokens.append(int(out.item()))
       self._cached_tokens = tokens[:]
       yield tokens[-1]
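In the rewritten loop, decode steps feed the previous step's (1, 1) output tensor straight back in as the next input, so the rollout jit replays with an identical input signature instead of re-slicing t each step; the generated token is also written into t at sp+nt so the padded buffer stays consistent with the kv cache. A toy version of that feedback loop (toy_model is a stand-in, not tinygrad's Transformer):

from tinygrad import Tensor

def toy_model(x:Tensor, start_pos:int) -> Tensor:
  # stand-in for Transformer.__call__: "predict" the next token as last token + 1, shape (1, 1)
  return (x[:, -1:] + 1).realize()

out, start_pos, prompt = None, 0, Tensor([[3, 5, 7]], dtype="int32")
for _ in range(4):
  out = toy_model(prompt if out is None else out, start_pos)  # prompt once, then (1,1) fed back
  start_pos += int(prompt.shape[1]) if start_pos == 0 else 1
print(int(out.item()))  # 11: three decode steps after the prefill "prediction" of 8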
@@ -275,7 +277,7 @@ CHAT_HTML = b'''<!DOCTYPE html><html><head><title>tinygrad chat</title><style>
   background: #2f2f2f; color: inherit; font: inherit;
   border: none; outline: none; resize: none; border-radius: 24px; field-sizing: content }
 </style></head><body><div id="chat"></div>
-<textarea id="input" rows="1" placeholder="Ask anything"></textarea>
+<textarea id="input" rows="1" placeholder="Ask anything" autofocus></textarea>
 <script>
 input.onkeydown = (e) => { if (e.key === 'Enter' && !e.shiftKey && !e.isComposing) { e.preventDefault(); send() } }
 const msgs = [];
@@ -318,7 +320,7 @@ class Handler(HTTPRequestHandler):
       yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
       if include_usage:
         yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
-      stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
+      stderr_log(f"gen:{len(out)/(time.perf_counter()-pt):4.0f} tok/s {colored('--', 'BLACK')} out:{len(out):5d}\n")

   def do_POST(self):
     raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
@@ -387,8 +389,9 @@ if __name__ == "__main__":

   # start server
   if args.serve:
-    # warmup: run 2 tokens through the model to capture the JIT before serving
-    with Context(DEBUG=max(DEBUG.value, 1)): list(zip(range(2), model.generate([0])))
+    # warmup: run 2 tokens through the model twice to capture the JIT before serving
+    with Context(DEBUG=max(DEBUG.value, 1)):
+      for _ in range(2): list(zip(range(2), model.generate([0])))
     TCPServerWithReuse(('', args.serve), Handler).serve_forever()

   # interactive chat
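Running the warmup loop twice matches the capture-on-second-call behavior noted above: the first pass through each jit runs eagerly, the second captures, so both prefill_jit and rollout_jit are replaying before the server accepts its first request.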
@@ -1,5 +1,7 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
+import time
+START_TIME = time.perf_counter()
+import os, functools, platform, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
 from collections import defaultdict
 import subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
 from dataclasses import dataclass, field
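Importing time on its own line and taking START_TIME before the heavy imports means the timestamp approximates actual process start, so the server banner below reports wall time from launch rather than from whenever helpers happened to finish importing.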
@@ -7,7 +7,7 @@ from urllib.parse import parse_qs, urlparse
 from http.server import BaseHTTPRequestHandler
 from typing import Any, TypedDict, TypeVar, Generator, Callable
 from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
-from tinygrad.helpers import printable, Context
+from tinygrad.helpers import printable, Context, START_TIME
 from tinygrad.renderer.amd.dsl import Inst
 from tinygrad.renderer.amd import detect_format
@@ -15,7 +15,7 @@ from tinygrad.renderer.amd import detect_format
 class TCPServerWithReuse(socketserver.TCPServer):
   allow_reuse_address = True
   def __init__(self, server_address, RequestHandlerClass):
-    print(f"*** started server on http://127.0.0.1:{server_address[1]}")
+    print(f"*** started server on http://127.0.0.1:{server_address[1]} at {time.perf_counter()-START_TIME:.2f} s")
     super().__init__(server_address, RequestHandlerClass)

 class HTTPRequestHandler(BaseHTTPRequestHandler):