diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3148bc8a6a..9250366dee 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -1,7 +1,7 @@
 name: Unit Tests
 env:
   # increment this when downloads substantially change to avoid the internet
-  CACHE_VERSION: '14'
+  CACHE_VERSION: '15'
   CAPTURE_PROCESS_REPLAY: 1
   GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
   PYTHONPATH: ${{ github.workspace }}
diff --git a/pyproject.toml b/pyproject.toml
index e0c2f24d02..555300c803 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,17 +69,14 @@ testing_minimal = [
   "hypothesis",
   "z3-solver",
 ]
-testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate"]
+testing_unit = ["tinygrad[testing_minimal]", "tqdm", "safetensors", "tabulate", "openai"]
 testing = [
-  "tinygrad[testing_minimal]",
+  "tinygrad[testing_unit]",
   "pillow",
   "onnx==1.19.0",
   "onnx2torch",
   "onnxruntime",
   "opencv-python",
-  "tabulate",
-  "tqdm",
-  "safetensors",
   "transformers",
   "sentencepiece",
   "tiktoken",
diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py
new file mode 100644
index 0000000000..90e0089c24
--- /dev/null
+++ b/test/unit/test_llm_server.py
@@ -0,0 +1,136 @@
+import unittest, threading, time
+from unittest.mock import Mock
+
+class TestLLMServer(unittest.TestCase):
+  """Integration tests using the real OpenAI client."""
+
+  @classmethod
+  def setUpClass(cls):
+    cls.mock_tok = Mock()
+    cls.mock_tok.role = Mock(return_value=[100, 101])
+    cls.mock_tok.encode = Mock(return_value=[200, 201, 202])
+    cls.mock_tok.decode = Mock(return_value="Hello")
+
+    cls.mock_model = Mock()
+    cls.mock_model.generate = Mock(side_effect=lambda ids, **kwargs: iter([300, 301, 999]))
+
+    cls.bos_id = 1
+    cls.eos_id = 999
+
+    import tinygrad.apps.llm as llm_module
+    llm_module.model = cls.mock_model
+    llm_module.tok = cls.mock_tok
+    llm_module.bos_id = cls.bos_id
+    llm_module.eos_id = cls.eos_id
+
+    from tinygrad.apps.llm import Handler
+    from tinygrad.helpers import TCPServerWithReuse
+
+    cls.port = 11435
+    cls.server = TCPServerWithReuse(('127.0.0.1', cls.port), Handler)
+    cls.server_thread = threading.Thread(target=cls.server.serve_forever, daemon=True)
+    cls.server_thread.start()
+    time.sleep(0.1)
+
+    from openai import OpenAI
+    cls.client = OpenAI(base_url=f"http://127.0.0.1:{cls.port}/v1", api_key="test")
+
+  @classmethod
+  def tearDownClass(cls):
+    cls.server.shutdown()
+    cls.server.server_close()
+
+  def test_chat_completion_stream(self):
+    stream = self.client.chat.completions.create(
+      model="test",
+      messages=[{"role": "user", "content": "Hello"}],
+      stream=True
+    )
+
+    chunks = list(stream)
+    self.assertGreater(len(chunks), 0)
+    self.assertEqual(chunks[0].choices[0].delta.role, "assistant")
+    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")
+
+  def test_openai_response_structure(self):
+    stream = self.client.chat.completions.create(
+      model="test-model",
+      messages=[{"role": "user", "content": "Test"}],
+      stream=True
+    )
+
+    for chunk in stream:
+      self.assertTrue(chunk.id.startswith("chatcmpl-"))
+      self.assertEqual(chunk.object, "chat.completion.chunk")
+      self.assertIsNotNone(chunk.choices)
+      self.assertIsNotNone(chunk.created)
+      self.assertIsInstance(chunk.created, int)
+      self.assertEqual(chunk.model, "test-model")
+
+  def test_stream_with_usage(self):
+    stream = self.client.chat.completions.create(
+      model="test",
+      messages=[{"role": "user", "content": "Hello"}],
+      stream=True,
+      stream_options={"include_usage": True}
+    )
+
+    chunks = list(stream)
+    last_chunk = chunks[-1]
+
+    self.assertIsNotNone(last_chunk.usage)
+    self.assertIsNotNone(last_chunk.usage.prompt_tokens)
+    self.assertIsNotNone(last_chunk.usage.completion_tokens)
+    self.assertIsNotNone(last_chunk.usage.total_tokens)
+
+  def test_multi_turn_conversation(self):
+    stream = self.client.chat.completions.create(
+      model="test",
+      messages=[
+        {"role": "system", "content": "You are helpful."},
+        {"role": "user", "content": "Hello"},
+        {"role": "assistant", "content": "Hi!"},
+        {"role": "user", "content": "How are you?"}
+      ],
+      stream=True
+    )
+
+    chunks = list(stream)
+    self.assertGreater(len(chunks), 0)
+    self.assertEqual(chunks[-1].choices[0].finish_reason, "stop")
+
+  def test_content_is_streamed(self):
+    stream = self.client.chat.completions.create(
+      model="test",
+      messages=[{"role": "user", "content": "Hello"}],
+      stream=True
+    )
+
+    contents = []
+    for chunk in stream:
+      if chunk.choices and chunk.choices[0].delta.content:
+        contents.append(chunk.choices[0].delta.content)
+
+    self.assertGreater(len(contents), 0)
+
+  def test_non_streaming(self):
+    resp = self.client.chat.completions.create(
+      model="test-model",
+      messages=[{"role": "user", "content": "Hello"}],
+      stream=False
+    )
+
+    self.assertTrue(resp.id.startswith("chatcmpl-"))
+    self.assertEqual(resp.object, "chat.completion")
+    self.assertEqual(resp.model, "test-model")
+    self.assertIsNotNone(resp.created)
+    self.assertEqual(len(resp.choices), 1)
+    self.assertEqual(resp.choices[0].message.role, "assistant")
+    self.assertIsNotNone(resp.choices[0].message.content)
+    self.assertEqual(resp.choices[0].finish_reason, "stop")
+    self.assertIsNotNone(resp.usage)
+    self.assertIsNotNone(resp.usage.prompt_tokens)
+    self.assertIsNotNone(resp.usage.completion_tokens)
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py
index e11cec5222..f3d84b2028 100644
--- a/tinygrad/apps/llm.py
+++ b/tinygrad/apps/llm.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import sys, argparse, typing, re, unicodedata, json, uuid
+import sys, argparse, typing, re, unicodedata, json, uuid, time
 from tinygrad import Tensor, nn, UOp, TinyJit, getenv
 from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, tqdm, DEBUG
 
@@ -184,8 +184,8 @@ models = {
 
 # OPENAI_BASE_URL=http://localhost:11434/v1 OPENAI_API_KEY=ollama uvx --from gpt-command-line gpt
 class Handler(HTTPRequestHandler):
-  def run_model(self, ids:list[int], include_usage=False):
-    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk"}
+  def run_model(self, ids:list[int], model_name:str, include_usage=False):
+    tmpl = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}
     yield {"choices": [{"index":0, "delta":{"role":"assistant","content":""}, "finish_reason":None}], **tmpl}
     out = []
     for next_id in tqdm(model.generate(ids), disable=not DEBUG>=1):
@@ -194,7 +194,7 @@ class Handler(HTTPRequestHandler):
       yield {"choices": [{"index":0, "delta":{"content":tok.decode([next_id])}, "finish_reason":None}], **tmpl}
     yield {"choices": [{"index":0, "delta":{},"finish_reason":"stop"}], **tmpl}
     if include_usage:
-      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}}
+      yield {"choices": [], "usage": {"prompt_tokens": len(ids), "completion_tokens": len(out), "total_tokens": len(ids) + len(out)}, **tmpl}
 
   def do_POST(self):
     raw_body = self.rfile.read(int(self.headers.get("Content-Length", "0")))
@@ -203,8 +203,6 @@ class Handler(HTTPRequestHandler):
       print(self.path)
       print(json.dumps(body, indent=2))
     if self.path == "/v1/chat/completions":
-      assert body["stream"], "we only support stream mode"
-
       # extract tokens
       ids = [bos_id]
       for msg in body["messages"]:
@@ -219,8 +217,14 @@ class Handler(HTTPRequestHandler):
         else: raise RuntimeError(f"unknown content type: {type(content)}")
       ids += tok.role("assistant")
-      # stream reply
-      self.stream_json(self.run_model(ids, include_usage=body.get("stream_options",{}).get("include_usage", False)))
+      # reply
+      chunks = self.run_model(ids, body["model"], not body.get("stream") or body.get("stream_options",{}).get("include_usage", False))
+      if body.get("stream"): self.stream_json(chunks)
+      else:
+        out = []
+        for c in chunks: out.append(c["choices"][0]["delta"].get("content", "") if c["choices"] else "")
+        self.send_data(json.dumps({**c, "object":"chat.completion",
+          "choices":[{"index":0, "message":{"role":"assistant","content":"".join(out)}, "finish_reason":"stop"}]}).encode())
     else: raise RuntimeError(f"unhandled path {self.path}")
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 61974d5b4e..905fd08144 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -413,6 +413,12 @@ class TCPServerWithReuse(socketserver.TCPServer):
     super().__init__(server_address, RequestHandlerClass)
 
 class HTTPRequestHandler(BaseHTTPRequestHandler):
+  def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
+    self.send_response(status_code)
+    self.send_header("Content-Type", content_type)
+    self.send_header("Content-Length", str(len(data)))
+    self.end_headers()
+    return self.wfile.write(data)
   def stream_json(self, source:Generator):
     try:
       self.send_response(200)
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 5576e4879d..b9a32252a6 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -457,12 +457,7 @@ class Handler(HTTPRequestHandler):
     elif url.path == "/get_profile" and profile_ret: ret, content_type = profile_ret, "application/octet-stream"
     else: status_code = 404
 
-    # send response
-    self.send_response(status_code)
-    self.send_header('Content-Type', content_type)
-    self.send_header('Content-Length', str(len(ret)))
-    self.end_headers()
-    return self.wfile.write(ret)
+    return self.send_data(ret, content_type, status_code)
 
 # ** main loop
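
Usage sketch (illustrative, not part of the patch): the new non-streaming path in do_POST can be exercised with the stock openai client, mirroring test_non_streaming above. This assumes a model server is already listening on port 11434 (the port from the OPENAI_BASE_URL comment in tinygrad/apps/llm.py); the model name and api_key below are placeholders, since the handler only echoes "model" back and never checks the key.

from openai import OpenAI

# api_key is a placeholder: the handler never validates it
client = OpenAI(base_url="http://localhost:11434/v1", api_key="unused")

# with stream=False the handler drains run_model() itself, joins the
# content deltas, and replies with one chat.completion via send_data()
resp = client.chat.completions.create(
  model="test-model",  # placeholder; only echoed back as resp.model
  messages=[{"role": "user", "content": "Hello"}],
  stream=False,
)
print(resp.choices[0].message.content)
print(resp.usage)  # prompt_tokens / completion_tokens / total_tokens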