mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
fast tinygrad.apps.llm (#13685)
* llm: add --benchmark support * fix speed * debug logging * fix test attention
This commit is contained in:
@@ -1,8 +1,14 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor, dtypes, TinyJit, UOp
|
||||
from tinygrad.apps.llm import apply_rope
|
||||
from tinygrad.apps.llm import apply_rope as apply_rope_new, precompute_freqs_cis
|
||||
#from tinygrad.engine.realize import run_schedule
|
||||
|
||||
def apply_rope(x:Tensor, start_pos:int):
|
||||
B, H, T, Hd = x.shape
|
||||
precompute_freqs_cis.cache_clear()
|
||||
freqs_cis = precompute_freqs_cis(Hd, start_pos+T)[start_pos:start_pos+T]
|
||||
return apply_rope_new(x, freqs_cis)
|
||||
|
||||
# TODO: test_scheduler, but just in uint
|
||||
class TestAttention(unittest.TestCase):
|
||||
def test_half_qkv_buffers(self):
|
||||
@@ -39,7 +45,7 @@ class TestAttention(unittest.TestCase):
|
||||
prune_size = len(rope_prune.captured.jit_cache)
|
||||
|
||||
self.assertGreater(noprune_size, prune_size)
|
||||
self.assertGreaterEqual(noprune_size, 3)
|
||||
self.assertGreaterEqual(noprune_size, 2)
|
||||
self.assertEqual(prune_size, 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
import sys, argparse, typing, re, unicodedata, json, uuid, time
|
||||
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
|
||||
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
|
||||
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG, Timing, GlobalCounters
|
||||
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
|
||||
@@ -52,15 +52,18 @@ class SimpleTokenizer:
|
||||
def decode(self, ids:list[int]) -> str: return b''.join(self._tok2bytes[tid] for tid in ids).decode()
|
||||
def role(self, role:str): return self.encode("<|start_header_id|>" + role + "<|end_header_id|>\n\n")
|
||||
|
||||
def apply_rope(x:Tensor, start_pos:int|UOp, base:float = 10000.0) -> Tensor:
|
||||
@functools.cache
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> Tensor:
|
||||
freqs = 1.0 / (theta ** (Tensor.arange(0, dim, 2)[:(dim // 2)] / dim))
|
||||
freqs = Tensor.arange(end).unsqueeze(dim=1) * freqs.unsqueeze(dim=0)
|
||||
return Tensor.stack(freqs.cos(), freqs.sin(), dim=-1).contiguous()
|
||||
|
||||
def apply_rope(x:Tensor, freqs_cis:Tensor) -> Tensor:
|
||||
B, H, T, Hd = x.shape
|
||||
assert isinstance(Hd, int) and (Hd & 1) == 0, "RoPE requires an even head dimension"
|
||||
half = Hd // 2
|
||||
t_start_pos = start_pos if isinstance(start_pos, int) else Tensor(start_pos)
|
||||
angles = (Tensor.arange(T, dtype="float32") + t_start_pos)[:, None] * (base ** (-(Tensor.arange(half, dtype="float32") / half)))[None, :]
|
||||
# contiguous here allows RoPE to be pruned in the JIT
|
||||
cos, sin = angles.cos().reshape(1, 1, T, half).cast(x.dtype).contiguous(), angles.sin().reshape(1, 1, T, half).cast(x.dtype).contiguous()
|
||||
x_pairs = x.reshape(B, H, T, half, 2)
|
||||
x_pairs = x.reshape(B, H, T, Hd//2, 2)
|
||||
cos = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 0]
|
||||
sin = freqs_cis.reshape(1, 1, T, Hd//2, 2)[..., 1]
|
||||
return Tensor.stack(x_pairs[..., 0] * cos - x_pairs[..., 1] * sin,
|
||||
x_pairs[..., 0] * sin + x_pairs[..., 1] * cos, dim=-1).reshape(B, H, T, Hd)
|
||||
|
||||
@@ -96,8 +99,10 @@ class TransformerBlock:
|
||||
k = k.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
v = v.reshape(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) # (B,KvH,T,Hd)
|
||||
|
||||
q = apply_rope(q, start_pos)
|
||||
k = apply_rope(k, start_pos)
|
||||
# TODO: make UOp have SupportsIndex
|
||||
freqs_cis = precompute_freqs_cis(self.head_dim, self.max_context)[start_pos:start_pos+T] # type: ignore
|
||||
q = apply_rope(q, freqs_cis)
|
||||
k = apply_rope(k, freqs_cis)
|
||||
|
||||
# TODO: remove these kv cache realizes
|
||||
if not hasattr(self, "cache_kv"):
|
||||
@@ -115,7 +120,8 @@ class TransformerBlock:
|
||||
|
||||
def _feed_forward(self, h: Tensor) -> Tensor:
|
||||
h_norm = self.ffn_norm(h)
|
||||
gated = self.ffn_gate(h_norm).silu() * self.ffn_up(h_norm)
|
||||
# TODO: remove the need for this contiguous
|
||||
gated = self.ffn_gate(h_norm).silu().contiguous() * self.ffn_up(h_norm)
|
||||
return h + self.ffn_down(gated)
|
||||
|
||||
def __call__(self, x: Tensor, start_pos: int|UOp):
|
||||
@@ -185,24 +191,27 @@ models = {
|
||||
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
|
||||
|
||||
class Handler(HTTPRequestHandler):
|
||||
def log_request(self, code='-', size='-'): pass
|
||||
def run_model(self, ids:list[int], model_name:str, include_usage=False):
|
||||
stderr_log(f"{self.path} {colored('--', 'BLACK')} in:{len(ids):5d} {colored('--', 'BLACK')} ")
|
||||
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
|
||||
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
|
||||
out = []
|
||||
for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
|
||||
out: list[int] = []
|
||||
st = time.perf_counter()
|
||||
for next_id in model.generate(ids):
|
||||
if len(out) == 0: stderr_log(f"prefill:{len(ids)/((pt:=time.perf_counter())-st):4.0f} tok/s {colored('--', 'BLACK')} ")
|
||||
if next_id == eos_id: break
|
||||
out.append(next_id)
|
||||
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
|
||||
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
|
||||
if include_usage:
|
||||
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
|
||||
stderr_log(f"out:{len(out):5d} {colored('--', 'BLACK')} gen: {len(out)/(time.perf_counter()-pt):4.0f} tok/s\n")
|
||||
|
||||
def do_POST(self):
|
||||
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
|
||||
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
|
||||
if DEBUG >= 1:
|
||||
print(self.path)
|
||||
print(json.dumps(body, indent=2))
|
||||
if DEBUG >= 1: print(json.dumps(body, indent=2))
|
||||
if self.path == "/v1/chat/completions":
|
||||
# extract tokens
|
||||
ids = [bos_id]
|
||||
|
||||
@@ -149,6 +149,10 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
|
||||
def temp(x:str, append_user:bool=False) -> str:
|
||||
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
|
||||
|
||||
def stderr_log(msg):
|
||||
sys.stderr.write(msg)
|
||||
sys.stderr.flush()
|
||||
|
||||
class Context(contextlib.ContextDecorator):
|
||||
def __init__(self, **kwargs): self.kwargs = kwargs
|
||||
def __enter__(self):
|
||||
|
||||
Reference in New Issue
Block a user