diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index a55117fa90..3148bc8a6a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -289,8 +289,8 @@ jobs:
         python extra/optimization/extract_dataset.py
         gzip -c /tmp/sops > extra/datasets/sops.gz
         #DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
-    - name: Repo line count < 19150 lines
-      run: MAX_LINE_COUNT=19150 python sz.py
+    - name: Repo line count < 20000 lines
+      run: MAX_LINE_COUNT=20000 python sz.py

   spec:
     strategy:
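
The llm.py changes below add a `realize` flag to `Transformer.from_gguf` so callers can defer weight realization. A minimal load sketch of the new flag, assuming `Tensor` can wrap the `pathlib.Path` that `fetch` returns (the actual loading line is not shown in this diff):

```python
# Sketch only: exercises the new realize flag added to from_gguf below.
from tinygrad import Tensor
from tinygrad.apps.llm import Transformer, models
from tinygrad.helpers import fetch

gguf = Tensor(fetch(models["8B"]))  # assumption: Tensor accepts the fetched pathlib.Path
# realize=True (the default) materializes all parameters up front via Tensor.realize(*params);
# realize=False keeps the weights lazy until first use.
model, kv = Transformer.from_gguf(gguf, max_context=4096, realize=False)
```
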
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index df0d6d6db7..89b64629e0 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata
-from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
+import sys, argparse, typing, re, unicodedata, json, uuid
+from tinygrad import Tensor, nn, UOp, TinyJit, getenv
+from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG
 
 class SimpleTokenizer:
   def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@@ -24,7 +25,7 @@ class SimpleTokenizer:
     # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
     if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
     vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
-    normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
+    normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
     return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
 
   def _encode_word(self, word:bytes) -> list[int]:
@@ -139,7 +140,7 @@ class Transformer:
     return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
 
   @staticmethod
-  def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
+  def from_gguf(gguf:Tensor, max_context:int|None=None, realize=True) -> tuple[Transformer, dict]:
     # TODO: remove the need for copy to default device
     kv, state_dict = nn.state.gguf_load(gguf.to(None))
 
@@ -156,7 +157,8 @@ class Transformer:
       norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False)  # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
-    for s in nn.state.get_parameters(model): s.replace(s.contiguous())
+    for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
+    if realize: Tensor.realize(*params)
     return model, kv
 
   def generate(self, tokens:list[int], start_pos=0):
@@ -178,10 +180,55 @@ models = {
   "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
 }
 
+# *** simple OpenAI compatible server on 11434 to match ollama ***
+# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
+
+class Handler(HTTPRequestHandler):
+  def run_model(self, ids:list[int], include_usage=False):
+    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
+    yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
+    out = []
+    for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
+      if next_id == eos_id: break
+      out.append(next_id)
+      yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
+    yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
+    if include_usage:
+      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
+
+  def do_POST(self):
+    raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
+    body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
+    if DEBUG >= 1:
+      print(self.path)
+      print(json.dumps(body, indent=2))
+    if self.path == "/v1/chat/completions":
+      assert body["stream"], "we only support stream mode"
+
+      # extract tokens
+      ids = [bos_id]
+      for msg in body["messages"]:
+        ids += tok.role(msg["role"])
+        # content can be a str or a list
+        content = msg["content"]
+        if isinstance(content, str): ids += tok.encode(content)
+        elif isinstance(content, list):
+          for c in content:
+            if c["type"] == "text": ids += tok.encode(c["text"])
+            else: raise RuntimeError(f"unhandled type: {c['type']}")
+        else: raise RuntimeError(f"unknown content type: {type(content)}")
+      ids += tok.role("assistant")
+
+      # stream reply
+      self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
+    else:
+      raise RuntimeError(f"unhandled path {self.path}")
+
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument("--size", choices=list(models.keys()), default=list(models.keys())[0], help="Model size")
   parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
+  parser.add_argument("--serve", action="store_true", help="Run OpenAI compatible API")
   args = parser.parse_args()
 
   # load the model
@@ -192,6 +239,9 @@ if __name__ == "__main__":
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
 
+  # start server
+  if args.serve: TCPServerWithReuse(('', 11434), Handler).serve_forever()
+
   ids: list[int] = [bos_id]
   while 1:
     start_pos = len(ids) - 1
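
`Handler.run_model` above wraps the existing generation loop in OpenAI-style chunks. For reference, a rough sketch of the same flow without the HTTP layer, reusing the `model`, `tok`, `bos_id`, and `eos_id` globals that `__main__` sets up (the prompt text is illustrative):

```python
# Assemble a chat prompt with the role tokens, then decode until EOS,
# mirroring what Handler.do_POST and run_model do over HTTP.
ids = [bos_id] + tok.role("user") + tok.encode("What is tinygrad?") + tok.role("assistant")
out: list[int] = []
for next_id in model.generate(ids):
  if next_id == eos_id: break
  out.append(next_id)
print(tok.decode(out))
```
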
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index ff81940ad4..61974d5b4e 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
-import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
+import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json
 from dataclasses import dataclass, field
 from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload
+from http.server import BaseHTTPRequestHandler
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -404,6 +405,27 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
     if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
   return fp
 
+# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
+class TCPServerWithReuse(socketserver.TCPServer):
+  allow_reuse_address = True
+  def __init__(self, server_address, RequestHandlerClass):
+    print(f"*** started server on http://127.0.0.1:{server_address[1]}")
+    super().__init__(server_address, RequestHandlerClass)
+
+class HTTPRequestHandler(BaseHTTPRequestHandler):
+  def stream_json(self, source:Generator):
+    try:
+      self.send_response(200)
+      self.send_header("Content-Type", "text/event-stream")
+      self.send_header("Cache-Control", "no-cache")
+      self.end_headers()
+      for r in source:
+        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
+        self.wfile.flush()
+      self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
+    # pass if client closed connection
+    except (BrokenPipeError, ConnectionResetError): return
+
 # *** Exec helpers
 
 def system(cmd:str, **kwargs) -> str:
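
Since `TCPServerWithReuse` and `HTTPRequestHandler.stream_json` now live in helpers, any tool can reuse them for server-sent events. A minimal sketch of an unrelated endpoint (the handler name, `/events` path, and port are made up for illustration):

```python
from tinygrad.helpers import TCPServerWithReuse, HTTPRequestHandler

def ticker():
  # each yielded object becomes one `data: {...}` SSE event; stream_json
  # appends the `data: [DONE]` terminator once the generator is exhausted
  for i in range(3): yield {"tick": i}

class EventsHandler(HTTPRequestHandler):
  def do_GET(self):
    if self.path == "/events": self.stream_json(ticker())

if __name__ == "__main__":
  TCPServerWithReuse(('', 8080), EventsHandler).serve_forever()
```
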
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 0751d67863..5576e4879d 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
 import ctypes, pathlib, traceback, itertools
 from contextlib import redirect_stdout, redirect_stderr, contextmanager
 from decimal import Decimal
-from http.server import BaseHTTPRequestHandler
 from urllib.parse import parse_qs, urlparse
 from typing import Any, TypedDict, TypeVar, Generator, Callable
 from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
-from tinygrad.helpers import printable, system
+from tinygrad.helpers import printable, system, TCPServerWithReuse, HTTPRequestHandler
 from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
 from tinygrad.uop.ops import print_uops, range_start, multirange_str
 from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
@@ -431,7 +430,7 @@ def get_render(i:int, j:int, fmt:str) -> dict:
 
 def get_int(query:dict[str, list[str]], k:str) -> int: return int(query.get(k,["0"])[0])
 
-class Handler(BaseHTTPRequestHandler):
+class Handler(HTTPRequestHandler):
   def do_GET(self):
     ret, status_code, content_type = b"", 200, "text/html"
 
@@ -465,19 +464,6 @@ class Handler(BaseHTTPRequestHandler):
       self.end_headers()
       return self.wfile.write(ret)
 
-  def stream_json(self, source:Generator):
-    try:
-      self.send_response(200)
-      self.send_header("Content-Type", "text/event-stream")
-      self.send_header("Cache-Control", "no-cache")
-      self.end_headers()
-      for r in source:
-        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
-        self.wfile.flush()
-      self.wfile.write("data: END\n\n".encode("utf-8"))
-    # pass if client closed connection
-    except (BrokenPipeError, ConnectionResetError): return
-
 # ** main loop
 
 def reloader():
@@ -493,9 +479,6 @@ def load_pickle(path:pathlib.Path, default:T) -> T:
   if not path.exists(): return default
   with path.open("rb") as f: return pickle.load(f)
 
-# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
-class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
-
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
@@ -516,7 +499,6 @@ if __name__ == "__main__":
   server = TCPServerWithReuse(('', PORT), Handler)
   reloader_thread = threading.Thread(target=reloader)
   reloader_thread.start()
-  print(f"*** started viz on {HOST}:{PORT}")
   print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
   if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
   try: server.serve_forever()
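
End to end, the new endpoint should work with any OpenAI-compatible client. A dependency-free sketch using only the standard library, assuming `llm.py` was started with `--serve` and is listening on localhost:11434:

```python
import json, urllib.request

# the handler asserts stream=True, so request a streaming completion
body = {"stream": True, "messages": [{"role": "user", "content": "hello"}]}
req = urllib.request.Request("http://localhost:11434/v1/chat/completions",
                             data=json.dumps(body).encode("utf-8"),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
  for raw in resp:  # one SSE line at a time, as the server flushes
    line = raw.decode("utf-8").strip()
    if not line.startswith("data: "): continue
    payload = line[len("data: "):]
    if payload == "[DONE]": break  # stream_json's end-of-stream marker
    for choice in json.loads(payload).get("choices", []):
      print(choice["delta"].get("content", ""), end="", flush=True)
print()
```
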