mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00

llm: add created/model fields, non-streaming support, and tests (#13660)

* llm: add created/model fields, non-streaming support, and tests

  - Add `created` timestamp and `model` fields to response (required by OpenAI spec)
  - Add non-streaming mode support for /v1/chat/completions
  - Add `send_data` helper to HTTPRequestHandler for responses with Content-Length
  - Refactor viz/serve.py to use send_data
  - Add integration tests using real OpenAI client

  🤖 Generated with [Claude Code](https://claude.com/claude-code)

  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* add openai to testing

* toml

* Remove 'openai' from dependencies

  Removed 'openai' from the dependencies list.

* bump cache

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
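For context, a minimal client-side sketch of what the new non-streaming mode enables, assuming a locally running server on the port shown in the OPENAI_BASE_URL comment in tinygrad/apps/llm.py (11434) and the OpenAI Python client that the new tests depend on; the model name and API key below are placeholders, as in the tests:

# Sketch only: exercise the new non-streaming /v1/chat/completions path with
# the official OpenAI client. Assumes the tinygrad LLM server is already
# running on localhost:11434; "test-model" and the API key are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")
resp = client.chat.completions.create(
  model="test-model",
  messages=[{"role": "user", "content": "Hello"}],
  stream=False,  # non-streaming mode added in this commit
)
print(resp.object)   # "chat.completion"
print(resp.model)    # echoes the requested model name (new field)
print(resp.created)  # unix timestamp (new field)
print(resp.choices[0].message.content)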
.github/workflows/test.yml (vendored, 2 changes)

@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  CACHE_VERSION: '14'
+  CACHE_VERSION: '15'
   CAPTURE_PROCESS_REPLAY: 1
   GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
   PYTHONPATH: ${{ github.workspace }}
pyproject.toml

@@ -69,17 +69,14 @@ testing_minimal = [
   "hypothesis",
   "z3-solver",
 ]
-testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate"]
+testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai"]
 testing = [
   "tinygrad[testing_minimal]",
   "tinygrad[testing_unit]",
   "pillow",
   "onnx==1.19.0",
   "onnx2torch",
   "onnxruntime",
   "opencv-python",
   "tabulate",
   "tqdm",
   "safetensors",
   "transformers",
   "sentencepiece",
   "tiktoken",
test/unit/test_llm_server.py (new file, 136 lines)
@@ -0,0 +1,136 @@
import unittest, threading, time
from unittest.mock import Mock

class TestLLMServer(unittest.TestCase):
  """Integration tests using the real OpenAI client."""

  @classmethod
  def setUpClass(cls):
    cls.mock_tok = Mock()
    cls.mock_tok.role = Mock(return_value=[100, 101])
    cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
    cls.mock_tok.decode = Mock(return_value="Hello")

    cls.mock_model = Mock()
    cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))

    cls.bos_id = 1
    cls.eos_id = 999

    import tinygrad.apps.llm as llm_module
    llm_module.model = cls.mock_model
    llm_module.tok = cls.mock_tok
    llm_module.bos_id = cls.bos_id
    llm_module.eos_id = cls.eos_id

    from tinygrad.apps.llm import Handler
    from tinygrad.helpers import TCPServerWithReuse

    cls.port = 11435
    cls.server = TCPServerWithReuse(('127.0.0.1', cls.port), Handler)
    cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
    cls.server_thread.start()
    time.sleep(0.1)

    from openai import OpenAI
    cls.client = OpenAI(base_url=f"http://127.0.0.1:{cls.port}/v1", api_key="test")

  @classmethod
  def tearDownClass(cls):
    cls.server.shutdown()
    cls.server.server_close()

  def test_chat_completion_stream(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )

    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[0].choices[0].delta.role, "assistant")
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")

  def test_openai_response_structure(self):
    stream = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Test"}],
      stream=True
    )

    for chunk in stream:
      self.assertTrue(chunk.id.startswith("chatcmpl-"))
      self.assertEqual(chunk.object, "chat.completion.chunk")
      self.assertIsNotNone(chunk.choices)
      self.assertIsNotNone(chunk.created)
      self.assertIsInstance(chunk.created, int)
      self.assertEqual(chunk.model, "test-model")

  def test_stream_with_usage(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True,
      stream_options={"include_usage": True}
    )

    chunks = list(stream)
    last_chunk = chunks[-1]

    self.assertIsNotNone(last_chunk.usage)
    self.assertIsNotNone(last_chunk.usage.prompt_tokens)
    self.assertIsNotNone(last_chunk.usage.completion_tokens)
    self.assertIsNotNone(last_chunk.usage.total_tokens)

  def test_multi_turn_conversation(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "How are you?"}
      ],
      stream=True
    )

    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")

  def test_content_is_streamed(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )

    contents = []
    for chunk in stream:
      if chunk.choices and chunk.choices[0].delta.content:
        contents.append(chunk.choices[0].delta.content)

    self.assertGreater(len(contents), 0)

  def test_non_streaming(self):
    resp = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Hello"}],
      stream=False
    )

    self.assertTrue(resp.id.startswith("chatcmpl-"))
    self.assertEqual(resp.object, "chat.completion")
    self.assertEqual(resp.model, "test-model")
    self.assertIsNotNone(resp.created)
    self.assertEqual(len(resp.choices), 1)
    self.assertEqual(resp.choices[0].message.role, "assistant")
    self.assertIsNotNone(resp.choices[0].message.content)
    self.assertEqual(resp.choices[0].finish_reason, "stop")
    self.assertIsNotNone(resp.usage)
    self.assertIsNotNone(resp.usage.prompt_tokens)
    self.assertIsNotNone(resp.usage.completion_tokens)

if __name__ == '__main__':
  unittest.main()
tinygrad/apps/llm.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid
+import sys, argparse, typing, re, unicodedata, json, uuid, time
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv
 from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG

@@ -184,8 +184,8 @@ models = {
 # OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt

 class Handler(HTTPRequestHandler):
-  def run_model(self, ids:list[int], include_usage=False):
-    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
+  def run_model(self, ids:list[int], model_name:str, include_usage=False):
+    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
     out = []
     for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):

@@ -194,7 +194,7 @@ class Handler(HTTPRequestHandler):
       yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
     yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
     if include_usage:
-      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
+      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}

   def do_POST(self):
     raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
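For reference, with created and model folded into tmpl, a representative mid-stream chunk yielded by run_model now looks like this (illustrative values only):

# Illustrative chunk shape after this change; the id, timestamp, model name,
# and token text are made-up values.
chunk = {
  "id": "chatcmpl-0123456789abcdef01234567",
  "object": "chat.completion.chunk",
  "created": 1736000000,   # new: int(time.time()) taken when the request starts
  "model": "test-model",   # new: echoed from the request body
  "choices": [{"index": 0, "delta": {"content": "Hello"}, "finish_reason": None}],
}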
@@ -203,8 +203,6 @@ class Handler(HTTPRequestHandler):
       print(self.path)
       print(json.dumps(body, indent=2))
     if self.path == "/v1/chat/completions":
-      assert body["stream"], "we only support stream mode"
-
       # extract tokens
       ids = [bos_id]
       for msg in body["messages"]:

@@ -219,8 +217,14 @@ class Handler(HTTPRequestHandler):
         else: raise RuntimeError(f"unknown content type: {type(content)}")
       ids += tok.role("assistant")

-      # stream reply
-      self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
+      # reply
+      chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
+      if body.get("stream"): self.stream_json(chunks)
+      else:
+        out = []
+        for c in chunks: out.append(c["choices"][0]["delta"].get("content", "") if c["choices"] else "")
+        self.send_data(json.dumps({**c, "object":"chat.completion",
+          "choices":[{"index":0, "message":{"role":"assistant","content":"".join(out)}, "finish_reason":"stop"}]}).encode())
     else:
       raise RuntimeError(f"unhandled path {self.path}")
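Tying the non-streaming branch together: since include_usage is forced on whenever stream is false, the final chunk c carries both the template fields and the usage counts, so the body handed to send_data comes out roughly like this (illustrative values only):

# Illustrative non-streaming response body assembled in do_POST; all values
# here are made up.
response = {
  "id": "chatcmpl-0123456789abcdef01234567",
  "object": "chat.completion",  # overridden from "chat.completion.chunk"
  "created": 1736000000,
  "model": "test-model",
  "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello there"},
               "finish_reason": "stop"}],
  "usage": {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8},
}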
tinygrad/helpers.py

@@ -413,6 +413,12 @@ class TCPServerWithReuse(socketserver.TCPServer):
     super().__init__(server_address, RequestHandlerClass)

 class HTTPRequestHandler(BaseHTTPRequestHandler):
+  def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
+    self.send_response(status_code)
+    self.send_header("Content-Type", content_type)
+    self.send_header("Content-Length", str(len(data)))
+    self.end_headers()
+    return self.wfile.write(data)
   def stream_json(self, source:Generator):
     try:
       self.send_response(200)
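Outside of this diff, a minimal sketch of how a handler subclass could use the new send_data helper for a one-shot JSON reply; EchoHandler and the port are hypothetical:

# Hypothetical example (not part of the commit): a handler that answers every
# GET with a small JSON body via the new send_data helper, which sets
# Content-Type and Content-Length before writing the payload.
import json
from tinygrad.helpers import HTTPRequestHandler, TCPServerWithReuse

class EchoHandler(HTTPRequestHandler):
  def do_GET(self):
    payload = json.dumps({"path": self.path}).encode()
    return self.send_data(payload, content_type="application/json", status_code=200)

if __name__ == "__main__":
  TCPServerWithReuse(("127.0.0.1", 8080), EchoHandler).serve_forever()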
viz/serve.py

@@ -457,12 +457,7 @@ class Handler(HTTPRequestHandler):
     elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
     else: status_code = 404

     # send response
-    self.send_response(status_code)
-    self.send_header('Content-Type', content_type)
-    self.send_header('Content-Length', str(len(ret)))
-    self.end_headers()
-    return self.wfile.write(ret)
+    return self.send_data(ret, content_type, status_code)

 # ** main loop