llm: add created/model fields, non-streaming support, and tests (#13660)

* llm: add created/model fields, non-streaming support, and tests

- Add `created` timestamp and `model` fields to response (required by OpenAI spec)
- Add non-streaming mode support for /v1/chat/completions
- Add `send_data` helper to HTTPRequestHandler for responses with Content-Length
- Refactor viz/serve.py to use send_data
- Add integration tests using real OpenAI client
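As a quick hedged sketch of the new behavior (not part of the diff): a client can now omit streaming and get a single `chat.completion` object carrying the new `created` and `model` fields. This assumes the server is already running at the `OPENAI_BASE_URL` used in the llm.py comment below (`http://localhost:11434/v1`), an arbitrary API key (the tests use `"test"`), and that the `openai` package (now in `testing_unit`) is installed:

```python
# Illustrative only: the base URL, model name, and port are assumptions taken from the
# OPENAI_BASE_URL comment in llm.py; the server does not validate the API key.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1", api_key="unused")
resp = client.chat.completions.create(
  model="some-model-name",                              # echoed back in resp.model
  messages=[{"role": "user", "content": "Hello"}],
  stream=False,                                         # new: non-streaming mode
)
print(resp.object)                      # "chat.completion"
print(resp.created, resp.model)         # new fields: unix timestamp + echoed model name
print(resp.choices[0].message.content)  # full reply as one string, no SSE chunks
print(resp.usage)                       # prompt/completion/total token counts
```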

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

* add openai to testing

* toml

* Remove 'openai' from dependencies

Removed 'openai' from the dependencies list.

* bump cache

---------

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
George Hotz committed 2025-12-12 14:50:36 -05:00 (committed by GitHub)
parent 9604773e45, commit 316da9f7ff
6 changed files with 158 additions and 20 deletions


@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  CACHE_VERSION: '14'
+  CACHE_VERSION: '15'
   CAPTURE_PROCESS_REPLAY: 1
   GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
   PYTHONPATH: ${{ github.workspace }}


@@ -69,17 +69,14 @@ testing_minimal = [
"hypothesis",
"z3-solver",
]
-testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate"]
+testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai"]
testing = [
"tinygrad[testing_minimal]",
"tinygrad[testing_unit]",
"pillow",
"onnx==1.19.0",
"onnx2torch",
"onnxruntime",
"opencv-python",
"tabulate",
"tqdm",
"safetensors",
"transformers",
"sentencepiece",
"tiktoken",


@@ -0,0 +1,136 @@
import unittest, threading, time
from unittest.mock import Mock

class TestLLMServer(unittest.TestCase):
  """Integration tests using the real OpenAI client."""
  @classmethod
  def setUpClass(cls):
    cls.mock_tok = Mock()
    cls.mock_tok.role = Mock(return_value=[100, 101])
    cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
    cls.mock_tok.decode = Mock(return_value="Hello")
    cls.mock_model = Mock()
    cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
    cls.bos_id = 1
    cls.eos_id = 999
    import tinygrad.apps.llm as llm_module
    llm_module.model = cls.mock_model
    llm_module.tok = cls.mock_tok
    llm_module.bos_id = cls.bos_id
    llm_module.eos_id = cls.eos_id
    from tinygrad.apps.llm import Handler
    from tinygrad.helpers import TCPServerWithReuse
    cls.port = 11435
    cls.server = TCPServerWithReuse(('127.0.0.1', cls.port), Handler)
    cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
    cls.server_thread.start()
    time.sleep(0.1)
    from openai import OpenAI
    cls.client = OpenAI(base_url=f"http://127.0.0.1:{cls.port}/v1", api_key="test")

  @classmethod
  def tearDownClass(cls):
    cls.server.shutdown()
    cls.server.server_close()

  def test_chat_completion_stream(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )
    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[0].choices[0].delta.role, "assistant")
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")

  def test_openai_response_structure(self):
    stream = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Test"}],
      stream=True
    )
    for chunk in stream:
      self.assertTrue(chunk.id.startswith("chatcmpl-"))
      self.assertEqual(chunk.object, "chat.completion.chunk")
      self.assertIsNotNone(chunk.choices)
      self.assertIsNotNone(chunk.created)
      self.assertIsInstance(chunk.created, int)
      self.assertEqual(chunk.model, "test-model")

  def test_stream_with_usage(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True,
      stream_options={"include_usage": True}
    )
    chunks = list(stream)
    last_chunk = chunks[-1]
    self.assertIsNotNone(last_chunk.usage)
    self.assertIsNotNone(last_chunk.usage.prompt_tokens)
    self.assertIsNotNone(last_chunk.usage.completion_tokens)
    self.assertIsNotNone(last_chunk.usage.total_tokens)

  def test_multi_turn_conversation(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "How are you?"}
      ],
      stream=True
    )
    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")

  def test_content_is_streamed(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )
    contents = []
    for chunk in stream:
      if chunk.choices and chunk.choices[0].delta.content:
        contents.append(chunk.choices[0].delta.content)
    self.assertGreater(len(contents), 0)

  def test_non_streaming(self):
    resp = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Hello"}],
      stream=False
    )
    self.assertTrue(resp.id.startswith("chatcmpl-"))
    self.assertEqual(resp.object, "chat.completion")
    self.assertEqual(resp.model, "test-model")
    self.assertIsNotNone(resp.created)
    self.assertEqual(len(resp.choices), 1)
    self.assertEqual(resp.choices[0].message.role, "assistant")
    self.assertIsNotNone(resp.choices[0].message.content)
    self.assertEqual(resp.choices[0].finish_reason, "stop")
    self.assertIsNotNone(resp.usage)
    self.assertIsNotNone(resp.usage.prompt_tokens)
    self.assertIsNotNone(resp.usage.completion_tokens)

if __name__ == '__main__':
  unittest.main()


@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid
+import sys, argparse, typing, re, unicodedata, json, uuid, time
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv
 from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG
@@ -184,8 +184,8 @@ models = {
 # OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
 class Handler(HTTPRequestHandler):
-  def run_model(self, ids:list[int], include_usage=False):
-    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
+  def run_model(self, ids:list[int], model_name:str, include_usage=False):
+    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
     out = []
     for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
@@ -194,7 +194,7 @@ class Handler(HTTPRequestHandler):
yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
if include_usage:
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
def do_POST(self):
raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
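For orientation (values below are invented, not from the diff): every streamed chunk already carried the shared template fields, and with the hunk above the trailing usage chunk now carries them too, instead of being a bare usage dict:

```python
# Hypothetical final two chunks of a stream with include_usage=True; the id, timestamp,
# token counts, and model name are made up for illustration.
stop_chunk = {"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
              "id": "chatcmpl-0123456789abcdef01234567", "object": "chat.completion.chunk",
              "created": 1765569036, "model": "some-model-name"}
usage_chunk = {"choices": [], "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
               "id": "chatcmpl-0123456789abcdef01234567", "object": "chat.completion.chunk",
               "created": 1765569036, "model": "some-model-name"}
```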
@@ -203,8 +203,6 @@ class Handler(HTTPRequestHandler):
     print(self.path)
     print(json.dumps(body, indent=2))
     if self.path == "/v1/chat/completions":
-      assert body["stream"], "we only support stream mode"
       # extract tokens
       ids = [bos_id]
       for msg in body["messages"]:
@@ -219,8 +217,14 @@ class Handler(HTTPRequestHandler):
         else: raise RuntimeError(f"unknown content type: {type(content)}")
       ids += tok.role("assistant")
-      # stream reply
-      self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
+      # reply
+      chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
+      if body.get("stream"): self.stream_json(chunks)
+      else:
+        out = []
+        for c in chunks: out.append(c["choices"][0]["delta"].get("content", "") if c["choices"] else "")
+        self.send_data(json.dumps({**c, "object":"chat.completion",
+          "choices":[{"index":0, "message":{"role":"assistant","content":"".join(out)}, "finish_reason":"stop"}]}).encode())
     else:
       raise RuntimeError(f"unhandled path {self.path}")


@@ -413,6 +413,12 @@ class TCPServerWithReuse(socketserver.TCPServer):
     super().__init__(server_address, RequestHandlerClass)
 class HTTPRequestHandler(BaseHTTPRequestHandler):
+  def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
+    self.send_response(status_code)
+    self.send_header("Content-Type", content_type)
+    self.send_header("Content-Length", str(len(data)))
+    self.end_headers()
+    return self.wfile.write(data)
   def stream_json(self, source:Generator):
     try:
       self.send_response(200)
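As a usage sketch (the handler class, route, and port below are invented for illustration), any `HTTPRequestHandler` subclass can now return a complete response with one call instead of the send_response/send_header/end_headers/write sequence:

```python
import json
from tinygrad.helpers import HTTPRequestHandler, TCPServerWithReuse

# Hypothetical handler demonstrating the new send_data helper.
class PingHandler(HTTPRequestHandler):
  def do_GET(self):
    if self.path == "/ping":
      self.send_data(json.dumps({"ok": True}).encode())  # Content-Type defaults to application/json
    else:
      self.send_data(b"not found", content_type="text/plain", status_code=404)

if __name__ == "__main__":
  TCPServerWithReuse(("127.0.0.1", 8123), PingHandler).serve_forever()
```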


@@ -457,12 +457,7 @@ class Handler(HTTPRequestHandler):
     elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
     else: status_code = 404
     # send response
-    self.send_response(status_code)
-    self.send_header('Content-Type', content_type)
-    self.send_header('Content-Length', str(len(ret)))
-    self.end_headers()
-    return self.wfile.write(ret)
+    return self.send_data(ret, content_type, status_code)
# ** main loop