openai api for llm (#13648)

* openai api for llm

* responds to simple request

* schedule cache needs to unbind

* stream works

* share stream code

* 20k

* one print

* cid
George Hotz
2025-12-12 08:25:33 -05:00
committed by GitHub
parent 93ad1f7732
commit f0fa9bcd98
4 changed files with 83 additions and 29 deletions

View File

@@ -289,8 +289,8 @@ jobs:
python extra/optimization/extract_dataset.py
gzip -c /tmp/sops > extra/datasets/sops.gz
#DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
- name: Repo line count < 19150 lines
run: MAX_LINE_COUNT=19150 python sz.py
- name: Repo line count < 20000 lines
run: MAX_LINE_COUNT=20000 python sz.py
spec:
strategy:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata
from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
import sys, argparse, typing, re, unicodedata, json, uuid
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@@ -24,7 +25,7 @@ class SimpleTokenizer:
# https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
def _encode_word(self, word:bytes) -> list[int]:
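
For reference, partition (now imported from tinygrad.helpers directly instead of via the helpers module) splits an iterable by a predicate into (matching, rest). A minimal sketch of the vocab split above, using made-up tokens and token types:

from tinygrad.helpers import partition
# toy GGUF-style vocab: token_type 1 means a normal token, everything else is treated as special
token_type = [1, 1, 3]
vocab = [("hello", 0), ("world", 1), ("<|eot_id|>", 2)]
normal_tokens, special_tokens = partition(vocab, lambda e: token_type[e[1]] == 1)
print(dict(normal_tokens))   # {'hello': 0, 'world': 1}
print(dict(special_tokens))  # {'<|eot_id|>': 2}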
@@ -139,7 +140,7 @@ class Transformer:
return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
@staticmethod
def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
def from_gguf(gguf:Tensor, max_context:int|None=None, realize=True) -> tuple[Transformer, dict]:
# TODO: remove the need for copy to default device
kv, state_dict = nn.state.gguf_load(gguf.to(None))
@@ -156,7 +157,8 @@ class Transformer:
norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
# NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
for s in nn.state.get_parameters(model): s.replace(s.contiguous())
for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
if realize: Tensor.realize(*params)
return model, kv
def generate(self, tokens:list[int], start_pos=0):
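
The new realize flag only decides whether the contiguous weights are loaded up front. A hedged sketch of the two call patterns, assuming gguf_tensor already holds the GGUF file as a Tensor:

# sketch: gguf_tensor is assumed to be a Tensor containing the GGUF bytes
model, kv = Transformer.from_gguf(gguf_tensor, max_context=4096)                     # realize=True (default): weights realized now
lazy_model, _ = Transformer.from_gguf(gguf_tensor, max_context=4096, realize=False)  # weights realized lazily on first use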
@@ -178,10 +180,55 @@ models = {
"8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
}
# *** simple OpenAI compatible server on 11434 to match ollama ***
# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
class Handler(HTTPRequestHandler):
def run_model(self, ids:list[int], include_usage=False):
tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
out = []
for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
if next_id == eos_id: break
out.append(next_id)
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
if DEBUG >= 1:
print(self.path)
print(json.dumps(body, indent=2))
if self.path == "/v1/chat/completions":
assert body["stream"], "we only support stream mode"
# extract tokens
ids = [bos_id]
for msg in body["messages"]:
ids += tok.role(msg["role"])
# content can be a str or a list
content = msg["content"]
if isinstance(content, str): ids += tok.encode(content)
elif isinstance(content, list):
for c in content:
if c["type"] == "text": ids += tok.encode(c["text"])
else: raise RuntimeError(f"unhandled type: {c['type']}")
else: raise RuntimeError(f"unknown content type: {type(content)}")
ids += tok.role("assistant")
# stream reply
self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
else:
raise RuntimeError(f"unhandled path {self.path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--size", choices=list(models.keys()), default=list(models.keys())[0], help="Model size")
parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
parser.add_argument("--serve", action="store_true", help="Run OpenAI compatible API")
args = parser.parse_args()
# load the model
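
As a quick sanity check of the new endpoint, a streaming request with the openai Python client should work once the process is started with --serve (sketch; the model field is ignored by this handler):

from openai import OpenAI
client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
stream = client.chat.completions.create(model="llama", stream=True,
                                        messages=[{"role": "user", "content": "hello"}])
for chunk in stream:
  # content arrives as deltas; the final chunk carries finish_reason "stop"
  if chunk.choices and chunk.choices[0].delta.content: print(chunk.choices[0].delta.content, end="", flush=True)
print()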
@@ -192,6 +239,9 @@ if __name__ == "__main__":
bos_id: int = kv['tokenizer.ggml.bos_token_id']
eos_id: int = kv['tokenizer.ggml.eos_token_id']
# start server
if args.serve: TCPServerWithReuse(('', 11434), Handler).serve_forever()
ids: list[int] = [bos_id]
while 1:
start_pos = len(ids) - 1
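
For debugging without a client library, the reply is plain SSE: one "data: <chat.completion.chunk JSON>" line per event, terminated by "data: [DONE]". A minimal sketch that just dumps the wire format (server assumed up on 11434):

import json, urllib.request
body = {"model": "llama", "stream": True, "messages": [{"role": "user", "content": "hello"}]}
req = urllib.request.Request("http://localhost:11434/v1/chat/completions", data=json.dumps(body).encode(),
                             headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
  for line in resp:
    # prints "data: {...}" events followed by "data: [DONE]"
    if line.strip(): print(line.decode().rstrip())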

View File

@@ -1,8 +1,9 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json
from dataclasses import dataclass, field
from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload
from http.server import BaseHTTPRequestHandler
T = TypeVar("T")
U = TypeVar("U")
@@ -404,6 +405,27 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
return fp
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
allow_reuse_address = True
def __init__(self, server_address, RequestHandlerClass):
print(f"*** started server on http://127.0.0.1:{server_address[1]}")
super().__init__(server_address, RequestHandlerClass)
class HTTPRequestHandler(BaseHTTPRequestHandler):
def stream_json(self, source:Generator):
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
# *** Exec helpers
def system(cmd:str, **kwargs) -> str:
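
The two helpers moved here are generic rather than LLM-specific. A minimal sketch (toy handler, not part of this diff) of reusing them for any server-sent-events endpoint:

from tinygrad.helpers import TCPServerWithReuse, HTTPRequestHandler

class CountHandler(HTTPRequestHandler):
  def do_GET(self):
    # stream_json sends the SSE headers, one "data: <json>" event per yielded dict, then "data: [DONE]"
    self.stream_json({"count": i} for i in range(3))

if __name__ == "__main__":
  TCPServerWithReuse(('', 8123), CountHandler).serve_forever()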

View File

@@ -1,13 +1,12 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
import ctypes, pathlib, traceback, itertools
from contextlib import redirect_stdout, redirect_stderr, contextmanager
from decimal import Decimal
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, system
from tinygrad.helpers import printable, system, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
@@ -431,7 +430,7 @@ def get_render(i:int, j:int, fmt:str) -> dict:
def get_int(query:dict[str, list[str]], k:str) -> int: return int(query.get(k,["0"])[0])
class Handler(BaseHTTPRequestHandler):
class Handler(HTTPRequestHandler):
def do_GET(self):
ret, status_code, content_type = b"", 200, "text/html"
@@ -465,19 +464,6 @@ class Handler(BaseHTTPRequestHandler):
self.end_headers()
return self.wfile.write(ret)
def stream_json(self, source:Generator):
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: END\n\n".encode("utf-8"))
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
# ** main loop
def reloader():
@@ -493,9 +479,6 @@ def load_pickle(path:pathlib.Path, default:T) -> T:
if not path.exists(): return default
with path.open("rb") as f: return pickle.load(f)
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
@@ -516,7 +499,6 @@ if __name__ == "__main__":
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)
reloader_thread.start()
print(f"*** started viz on {HOST}:{PORT}")
print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
try: server.serve_forever()