diff --git a/test/unit/test_llm_server.py b/test/unit/test_llm_server.py index 38f6877890..942baea061 100644 --- a/test/unit/test_llm_server.py +++ b/test/unit/test_llm_server.py @@ -25,7 +25,7 @@ class TestLLMServer(unittest.TestCase): llm_module.eos_id = cls.eos_id from tinygrad.apps.llm import Handler - from tinygrad.helpers import TCPServerWithReuse + from tinygrad.viz.serve import TCPServerWithReuse cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler) cls.port = cls.server.server_address[1] diff --git a/tinygrad/apps/llm.py b/tinygrad/apps/llm.py index 52e33c7c1c..0a2880dd83 100644 --- a/tinygrad/apps/llm.py +++ b/tinygrad/apps/llm.py @@ -1,7 +1,8 @@ from __future__ import annotations import sys, argparse, typing, re, unicodedata, json, uuid, time, functools from tinygrad import Tensor, nn, UOp, TinyJit, getenv -from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored +from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored +from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler class SimpleTokenizer: def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 068538accc..e75e5ae5b1 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -1,9 +1,8 @@ from __future__ import annotations import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc -import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json +import subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools from dataclasses import dataclass, field from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload -from http.server import BaseHTTPRequestHandler T = TypeVar("T") U = TypeVar("U") @@ -380,6 +379,7 @@ def _ensure_downloads_dir() -> pathlib.Path: def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False, allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}) -> pathlib.Path: + import urllib.request if url.startswith(("/", ".")): return pathlib.Path(url) if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name) else: @@ -400,33 +400,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}") return fp -# NOTE: using HTTPServer forces a potentially slow socket.getfqdn -class TCPServerWithReuse(socketserver.TCPServer): - allow_reuse_address = True - def __init__(self, server_address, RequestHandlerClass): - print(f"*** started server on http://127.0.0.1:{server_address[1]}") - super().__init__(server_address, RequestHandlerClass) - -class HTTPRequestHandler(BaseHTTPRequestHandler): - def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200): - self.send_response(status_code) - self.send_header("Content-Type", content_type) - self.send_header("Content-Length", str(len(data))) - self.end_headers() - return self.wfile.write(data) - def stream_json(self, source:Generator): - try: - self.send_response(200) - self.send_header("Content-Type", "text/event-stream") - self.send_header("Cache-Control", "no-cache") - self.end_headers() - for r in source: - self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8")) - self.wfile.flush() - self.wfile.write("data: [DONE]\n\n".encode("utf-8")) - # pass if client closed connection - except (BrokenPipeError, ConnectionResetError): return - # *** Exec helpers def system(cmd:str, **kwargs) -> str: diff --git a/tinygrad/uop/ops.py b/tinygrad/uop/ops.py index 1c7737453a..e350a9a939 100644 --- a/tinygrad/uop/ops.py +++ b/tinygrad/uop/ops.py @@ -888,8 +888,9 @@ def print_uops(uops:list[UOp]): def get_location() -> tuple[str, int]: frm = sys._getframe(1) # skip over ops.py and anything in mixin - while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \ - not frm.f_back.f_code.co_filename.startswith("