# tinygrad/test/null/test_llm_server.py

import unittest, threading, time
from unittest.mock import Mock, call
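# These tests start the real LLMServer on an ephemeral port, swap in a mocked
# model and tokenizer, then exercise the server end to end through the
# official OpenAI Python client.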
class TestLLMServer(unittest.TestCase):
  """Integration tests using the real OpenAI client."""
  @classmethod
  def setUpClass(cls):
    cls.mock_tok = Mock()
    cls.mock_tok.role = Mock(return_value=[100, 101])
    cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
    cls.mock_tok.decode = Mock(return_value="Hello")
    cls.mock_tok.stream_decoder = Mock(return_value=lambda tid=None: "Hello" if tid is not None else "")
    cls.mock_tok.end_turn = Mock(return_value=[998])
    cls.mock_tok.prefix = Mock(return_value=[1])
    cls.mock_tok.preset = "llama3"
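    # Model mock: generate() yields a fixed stream of token ids ending in the
    # EOS id (999), so generation terminates without a real model.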
    cls.mock_model = Mock()
    cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
    cls.mock_model.get_start_pos = Mock(return_value=0)
    cls.bos_id = 1
    cls.eos_id = 999
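    # Bind to port 0 so the OS assigns a free port, then serve on a daemon
    # thread; tearDownClass shuts the server down explicitly.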
    from tinygrad.llm.cli import Handler, LLMServer
    cls.server = LLMServer(('127.0.0.1', 0), Handler)
    cls.server.model = cls.mock_model
    cls.server.model_name = "test-model"
    cls.server.tok = cls.mock_tok
    cls.server.bos_id = cls.bos_id
    cls.server.eos_id = cls.eos_id
    cls.server.eot_id = None
    cls.port = cls.server.server_address[1]
    cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
    cls.server_thread.start()
    time.sleep(0.1)
    from openai import OpenAI
    cls.client = OpenAI(base_url=f"http://127.0.0.1:{cls.port}/v1", api_key="test")
  @classmethod
  def tearDownClass(cls):
    cls.server.shutdown()
    cls.server.server_close()
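  # A streamed completion should open with an assistant role delta and close
  # with finish_reason "stop".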
  def test_chat_completion_stream(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )
    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[0].choices[0].delta.role, "assistant")
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")
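  # Every chunk should carry the standard OpenAI envelope: a "chatcmpl-" id,
  # the chunk object type, choices, an integer created timestamp, and the
  # model name.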
  def test_openai_response_structure(self):
    stream = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Test"}],
      stream=True
    )
    for chunk in stream:
      self.assertTrue(chunk.id.startswith("chatcmpl-"))
      self.assertEqual(chunk.object, "chat.completion.chunk")
      self.assertIsNotNone(chunk.choices)
      self.assertIsNotNone(chunk.created)
      self.assertIsInstance(chunk.created, int)
      self.assertEqual(chunk.model, "test-model")
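  # With stream_options={"include_usage": True}, the final chunk reports
  # prompt, completion, and total token counts.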
  def test_stream_with_usage(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True,
      stream_options={"include_usage": True}
    )
    chunks = list(stream)
    last_chunk = chunks[-1]
    self.assertIsNotNone(last_chunk.usage)
    self.assertIsNotNone(last_chunk.usage.prompt_tokens)
    self.assertIsNotNone(last_chunk.usage.completion_tokens)
    self.assertIsNotNone(last_chunk.usage.total_tokens)
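  # A full multi-turn history (system/user/assistant/user) should stream and
  # finish normally.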
  def test_multi_turn_conversation(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[
        {"role": "system", "content": "You are helpful."},
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "How are you?"}
      ],
      stream=True
    )
    chunks = list(stream)
    self.assertGreater(len(chunks), 0)
    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")
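  # At least one chunk should carry an actual content delta.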
  def test_content_is_streamed(self):
    stream = self.client.chat.completions.create(
      model="test",
      messages=[{"role": "user", "content": "Hello"}],
      stream=True
    )
    contents = []
    for chunk in stream:
      if chunk.choices and chunk.choices[0].delta.content:
        contents.append(chunk.choices[0].delta.content)
    self.assertGreater(len(contents), 0)
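  # Non-streaming responses use the "chat.completion" object shape, with a
  # single choice and usage included in the body.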
  def test_non_streaming(self):
    resp = self.client.chat.completions.create(
      model="test-model",
      messages=[{"role": "user", "content": "Hello"}],
      stream=False
    )
    self.assertTrue(resp.id.startswith("chatcmpl-"))
    self.assertEqual(resp.object, "chat.completion")
    self.assertEqual(resp.model, "test-model")
    self.assertIsNotNone(resp.created)
    self.assertEqual(len(resp.choices), 1)
    self.assertEqual(resp.choices[0].message.role, "assistant")
    self.assertIsNotNone(resp.choices[0].message.content)
    self.assertEqual(resp.choices[0].finish_reason, "stop")
    self.assertIsNotNone(resp.usage)
    self.assertIsNotNone(resp.usage.prompt_tokens)
    self.assertIsNotNone(resp.usage.completion_tokens)
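  # max_tokens should truncate generation: the mock would emit four content
  # tokens plus EOS, but only two chunks arrive, with finish_reason "length".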
  def test_max_tokens_streaming(self):
    self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 302, 303, 999]))
    stream = self.client.chat.completions.create(
      model="test", messages=[{"role": "user", "content": "Hello"}], stream=True, max_tokens=2
    )
    chunks = list(stream)
    content_chunks = [c for c in chunks if c.choices and c.choices[0].delta.content]
    self.assertEqual(len(content_chunks), 2)
    self.assertEqual(chunks[-1].choices[0].finish_reason, "length")
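  # Same cutoff in non-streaming mode, reflected in usage.completion_tokens.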
  def test_max_tokens_non_streaming(self):
    self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 302, 303, 999]))
    resp = self.client.chat.completions.create(
      model="test", messages=[{"role": "user", "content": "Hello"}], stream=False, max_tokens=2
    )
    self.assertEqual(resp.choices[0].finish_reason, "length")
    self.assertEqual(resp.usage.completion_tokens, 2)
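  # Note: the call-count assertions below see cumulative counts on the
  # class-level mocks; this test appears to rely on running first under
  # unittest's default alphabetical method ordering, so the counts are fresh.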
  def test_assistant_prefill(self):
    """Last assistant message should be treated as prefill (not a completed turn)."""
    self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
    captured_ids = []
    orig_generate = self.mock_model.generate.side_effect
    def capture_generate(ids, **kwargs):
      captured_ids.extend(ids)
      return orig_generate(ids, **kwargs)
    self.mock_model.generate = Mock(side_effect=capture_generate)
    resp = self.client.chat.completions.create(
      model="test", messages=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Sure"}
      ], stream=False
    )
    # prefill tokens should be in ids: role("assistant") + encode("Sure") but NO end_turn after it
    # and NO extra role("assistant") appended
    role_tokens = self.mock_tok.role.call_args_list
    # last role() call should be for "assistant" (the prefill message), not an extra one
    self.assertEqual(role_tokens[-1], call("assistant"))
    # end_turn should be called once less than role(): the prefill assistant msg doesn't get end_turn
    self.assertEqual(self.mock_tok.end_turn.call_count, self.mock_tok.role.call_count - 1)
    self.assertIsNotNone(resp.choices[0].message.content)
  def test_assistant_prefill_not_last(self):
    """Assistant message that's NOT last should be a normal completed turn."""
    self.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 999]))
    self.mock_tok.role.reset_mock()
    self.mock_tok.end_turn.reset_mock()
    self.client.chat.completions.create(
      model="test", messages=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Sure"},
        {"role": "user", "content": "Continue"}
      ], stream=False
    )
    # all messages get end_turn, plus an extra role("assistant") at the end
    # roles: user, assistant, user, assistant(generation prompt) = 4 role calls
    # end_turns: user, assistant, user = 3 end_turn calls (one per message)
    self.assertEqual(self.mock_tok.end_turn.call_count, 3)
    self.assertEqual(self.mock_tok.role.call_count, 4)
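  # GET /v1/models is a plain HTTP endpoint, so hit it with requests directly
  # rather than through the OpenAI client.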
  def test_models_endpoint(self):
    import requests as req
    resp = req.get(f"http://127.0.0.1:{self.port}/v1/models")
    self.assertEqual(resp.status_code, 200)
    data = resp.json()
    self.assertEqual(data["object"], "list")
    self.assertEqual(len(data["data"]), 1)
    self.assertEqual(data["data"][0]["id"], "test-model")
    self.assertEqual(data["data"][0]["object"], "model")

if __name__ == '__main__':
  unittest.main()