mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00
openai api for llm (#13648)
* openai api for llm
* responds to simple request
* schedule cache needs to unbind
* stream works
* share stream code
* 20k
* one print
* cid
.github/workflows/test.yml
@@ -289,8 +289,8 @@ jobs:
         python extra/optimization/extract_dataset.py
         gzip -c /tmp/sops > extra/datasets/sops.gz
         #DEBUG=1 MIN_ASTS=1 python extra/optimization/get_action_space.py
-    - name: Repo line count < 19150 lines
-      run: MAX_LINE_COUNT=19150 python sz.py
+    - name: Repo line count < 20000 lines
+      run: MAX_LINE_COUNT=20000 python sz.py
 
   spec:
     strategy:
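sz.py itself is not part of this diff; the only CI change is raising the repo line-count cap from 19150 to 20000 lines. As a rough illustration only, a stand-in for such a MAX_LINE_COUNT guard might look like the sketch below (the file globbing and counting rules are assumptions, not the real sz.py):

# NOT the real sz.py (not shown in this diff): a minimal stand-in for a MAX_LINE_COUNT guard
import os, pathlib, sys
max_count = int(os.getenv("MAX_LINE_COUNT", "-1"))
# count lines in the core package; the real script may weight or exclude files differently
total = sum(len(p.read_text().splitlines()) for p in pathlib.Path("tinygrad").rglob("*.py"))
print(f"total line count: {total}")
if max_count != -1 and total > max_count: sys.exit(f"over line count limit: {total} > {max_count}")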
@@ -1,6 +1,7 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata
-from tinygrad import Tensor, nn, UOp, TinyJit, getenv, helpers
+import sys, argparse, typing, re, unicodedata, json, uuid
+from tinygrad import Tensor, nn, UOp, TinyJit, getenv
+from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG
 
 class SimpleTokenizer:
   def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int]):
@@ -24,7 +25,7 @@ class SimpleTokenizer:
     # https://github.com/ggml-org/llama.cpp/blob/94933c8c2eeaa9a7983e3f6c08af76bd86724094/src/llama-vocab.cpp#L1818-L1820
     if kv["tokenizer.ggml.pre"] not in ("llama3","llama-v3","llama-bpe"): raise ValueError(f"Invalid tokenizer preset '{kv['tokenizer.ggml.pre']}'")
     vocab: typing.Iterable[tuple[str, int]] = ((tok, idx) for idx, tok in enumerate(kv["tokenizer.ggml.tokens"]))
-    normal_tokens, special_tokens = helpers.partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
+    normal_tokens, special_tokens = partition(vocab, lambda e: kv["tokenizer.ggml.token_type"][e[1]] == 1)
     return SimpleTokenizer(dict(normal_tokens), dict(special_tokens))
 
   def _encode_word(self, word:bytes) -> list[int]:
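For reference, a minimal sketch of how the partition helper used above behaves, assuming it splits an iterable into (matching, non-matching) lists by a predicate, which is how the code above uses it to separate normal tokens from special ones:

# minimal sketch of tinygrad.helpers.partition as used above; assumed to return
# (items where the predicate holds, items where it does not)
from tinygrad.helpers import partition

evens, odds = partition(range(6), lambda x: x % 2 == 0)
print(evens, odds)  # expected: [0, 2, 4] [1, 3, 5]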
@@ -139,7 +140,7 @@ class Transformer:
     return (self.forward_jit if getenv("JIT", 1) and tokens.shape[1] == 1 and isinstance(start_pos, UOp) else self.forward)(tokens, start_pos)
 
   @staticmethod
-  def from_gguf(gguf:Tensor, max_context:int|None=None) -> tuple[Transformer, dict]:
+  def from_gguf(gguf:Tensor, max_context:int|None=None, realize=True) -> tuple[Transformer, dict]:
     # TODO: remove the need for copy to default device
     kv, state_dict = nn.state.gguf_load(gguf.to(None))
 
@@ -156,7 +157,8 @@ class Transformer:
       norm_eps=kv[f'{arch}.attention.layer_norm_rms_epsilon'], vocab_size=len(kv['tokenizer.ggml.tokens']), max_context=max_context)
     nn.state.load_state_dict(model, state_dict, verbose=False, consume=True, realize=False) # NOTE: rope_freqs.weight (32,) is unused
     # NOTE: without this contiguous, it unpacks the weights from the model every time. we shouldn't need this, but for now it's faster
-    for s in nn.state.get_parameters(model): s.replace(s.contiguous())
+    for s in (params:=nn.state.get_parameters(model)): s.replace(s.contiguous())
+    if realize: Tensor.realize(*params)
     return model, kv
 
   def generate(self, tokens:list[int], start_pos=0):
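The new realize flag keeps the old behavior by default (weights are materialized at load time in one batch) while letting a caller defer that work. A minimal sketch of the same contiguous-then-realize pattern on toy tensors rather than actual model weights:

# minimal sketch of the contiguous + batched realize pattern from from_gguf above,
# applied to toy tensors instead of GGUF weights
from tinygrad import Tensor

params = [Tensor.ones(4), Tensor.arange(4)]
for s in params: s.replace(s.contiguous())  # same replace-with-contiguous trick as the loop above
Tensor.realize(*params)                     # realize every buffer in one batched call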
@@ -178,10 +180,55 @@ models = {
   "8B": "https://huggingface.co/bartowski/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/main/Meta-Llama-3.1-8B-Instruct-Q8_0.gguf",
 }
 
+# *** simple OpenAI compatible server on 11434 to match ollama ***
+# OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
+
+class Handler(HTTPRequestHandler):
+  def run_model(self, ids:list[int], include_usage=False):
+    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
+    yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
+    out = []
+    for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
+      if next_id == eos_id: break
+      out.append(next_id)
+      yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
+    yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
+    if include_usage:
+      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
+
+  def do_POST(self):
+    raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
+    body: dict[str, typing.Any] = json.loads(raw_body.decode("utf-8"))
+    if DEBUG >= 1:
+      print(self.path)
+      print(json.dumps(body, indent=2))
+    if self.path == "/v1/chat/completions":
+      assert body["stream"], "we only support stream mode"
+
+      # extract tokens
+      ids = [bos_id]
+      for msg in body["messages"]:
+        ids += tok.role(msg["role"])
+        # content can be a str or a list
+        content = msg["content"]
+        if isinstance(content, str): ids += tok.encode(content)
+        elif isinstance(content, list):
+          for c in content:
+            if c["type"] == "text": ids += tok.encode(c["text"])
+            else: raise RuntimeError(f"unhandled type: {c['type']}")
+        else: raise RuntimeError(f"unknown content type: {type(content)}")
+      ids += tok.role("assistant")
+
+      # stream reply
+      self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
+    else:
+      raise RuntimeError(f"unhandled path {self.path}")
+
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument("--size", choices=list(models.keys()), default=list(models.keys())[0], help="Model size")
   parser.add_argument("--max_context", type=int, default=4096, help="Max Context Length")
+  parser.add_argument("--serve", action="store_true", help="Run OpenAI compatible API")
   args = parser.parse_args()
 
   # load the model
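As the OPENAI_BASE_URL comment above suggests, the endpoint is meant to be usable from any OpenAI-compatible client pointed at http://localhost:11434/v1. A minimal sketch of a streaming client using the third-party openai package (an assumption of this sketch, not something this commit ships; the handler only reads messages, stream, and stream_options, so the model and api_key values appear to be ignored):

# minimal client sketch; assumes "pip install openai" and the server above running on port 11434
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
stream = client.chat.completions.create(
  model="tinygrad-llm",  # placeholder name: do_POST above never reads this field
  messages=[{"role": "user", "content": "hello"}],
  stream=True,
  stream_options={"include_usage": True},
)
for chunk in stream:
  if chunk.choices and chunk.choices[0].delta.content: print(chunk.choices[0].delta.content, end="", flush=True)
  if chunk.usage: print(f"\n[{chunk.usage.prompt_tokens} prompt + {chunk.usage.completion_tokens} completion tokens]")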
@@ -192,6 +239,9 @@ if __name__ == "__main__":
   bos_id: int = kv['tokenizer.ggml.bos_token_id']
   eos_id: int = kv['tokenizer.ggml.eos_token_id']
 
+  # start server
+  if args.serve: TCPServerWithReuse(('', 11434), Handler).serve_forever()
+
   ids: list[int] = [bos_id]
   while 1:
     start_pos = len(ids) - 1
@@ -1,8 +1,9 @@
 from __future__ import annotations
 import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
-import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
+import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json
 from dataclasses import dataclass, field
 from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload
+from http.server import BaseHTTPRequestHandler
 
 T = TypeVar("T")
 U = TypeVar("U")
@@ -404,6 +405,27 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
   if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
   return fp
 
+# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
+class TCPServerWithReuse(socketserver.TCPServer):
+  allow_reuse_address = True
+  def __init__(self, server_address, RequestHandlerClass):
+    print(f"*** started server on http://127.0.0.1:{server_address[1]}")
+    super().__init__(server_address, RequestHandlerClass)
+
+class HTTPRequestHandler(BaseHTTPRequestHandler):
+  def stream_json(self, source:Generator):
+    try:
+      self.send_response(200)
+      self.send_header("Content-Type", "text/event-stream")
+      self.send_header("Cache-Control", "no-cache")
+      self.end_headers()
+      for r in source:
+        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
+        self.wfile.flush()
+      self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
+    # pass if client closed connection
+    except (BrokenPipeError, ConnectionResetError): return
+
 # *** Exec helpers
 
 def system(cmd:str, **kwargs) -> str:
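The two classes added here are generic: TCPServerWithReuse is an address-reusing TCPServer that prints where it is listening, and HTTPRequestHandler.stream_json turns any generator of JSON-serializable objects into a server-sent-event stream terminated by data: [DONE]. A minimal sketch of reusing them from another script (the EchoHandler name and port 8000 are made up for illustration, not part of this change):

# minimal sketch reusing the helpers added above; EchoHandler and port 8000 are illustrative only
from tinygrad.helpers import TCPServerWithReuse, HTTPRequestHandler

class EchoHandler(HTTPRequestHandler):
  def do_GET(self):
    # stream three JSON events over SSE; stream_json sends the headers and the trailing "data: [DONE]"
    self.stream_json({"i": i} for i in range(3))

if __name__ == "__main__":
  TCPServerWithReuse(('', 8000), EchoHandler).serve_forever()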
@@ -1,13 +1,12 @@
 #!/usr/bin/env python3
-import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, socketserver, functools, codecs, io, struct
+import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
 import ctypes, pathlib, traceback, itertools
 from contextlib import redirect_stdout, redirect_stderr, contextmanager
 from decimal import Decimal
-from http.server import BaseHTTPRequestHandler
 from urllib.parse import parse_qs, urlparse
 from typing import Any, TypedDict, TypeVar, Generator, Callable
 from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
-from tinygrad.helpers import printable, system
+from tinygrad.helpers import printable, system, TCPServerWithReuse, HTTPRequestHandler
 from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
 from tinygrad.uop.ops import print_uops, range_start, multirange_str
 from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
@@ -431,7 +430,7 @@ def get_render(i:int, j:int, fmt:str) -> dict:
 
 def get_int(query:dict[str, list[str]], k:str) -> int: return int(query.get(k,["0"])[0])
 
-class Handler(BaseHTTPRequestHandler):
+class Handler(HTTPRequestHandler):
   def do_GET(self):
     ret, status_code, content_type = b"", 200, "text/html"
 
@@ -465,19 +464,6 @@ class Handler(BaseHTTPRequestHandler):
     self.end_headers()
     return self.wfile.write(ret)
 
-  def stream_json(self, source:Generator):
-    try:
-      self.send_response(200)
-      self.send_header("Content-Type", "text/event-stream")
-      self.send_header("Cache-Control", "no-cache")
-      self.end_headers()
-      for r in source:
-        self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
-        self.wfile.flush()
-      self.wfile.write("data: END\n\n".encode("utf-8"))
-    # pass if client closed connection
-    except (BrokenPipeError, ConnectionResetError): return
-
 # ** main loop
 
 def reloader():
@@ -493,9 +479,6 @@ def load_pickle(path:pathlib.Path, default:T) -> T:
   if not path.exists(): return default
   with path.open("rb") as f: return pickle.load(f)
 
-# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
-class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
-
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
@@ -516,7 +499,6 @@ if __name__ == "__main__":
   server = TCPServerWithReuse(('', PORT), Handler)
   reloader_thread = threading.Thread(target=reloader)
   reloader_thread.start()
-  print(f"*** started viz on {HOST}:{PORT}")
   print(colored(f"*** ready in {(time.perf_counter()-st)*1e3:4.2f}ms", "green"), flush=True)
   if len(getenv("BROWSER", "")) > 0: webbrowser.open(f"{HOST}:{PORT}")
   try: server.serve_forever()