minor import speedups (#14244)

* minor import speedups

* server stuff in server places

* pre-commit

* fix
This commit is contained in:
George Hotz
2026-01-20 15:05:36 +09:00
committed by GitHub
parent d60a155e48
commit 5e24643889
5 changed files with 39 additions and 35 deletions

View File

@@ -25,7 +25,7 @@ class TestLLMServer(unittest.TestCase):
llm_module.eos_id = cls.eos_id
from tinygrad.apps.llm import Handler
from tinygrad.helpers import TCPServerWithReuse
from tinygrad.viz.serve import TCPServerWithReuse
cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler)
cls.port = cls.server.server_address[1]

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
class SimpleTokenizer:
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):

View File

@@ -1,9 +1,8 @@
from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json
import subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
from dataclasses import dataclass, field
from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload
from http.server import BaseHTTPRequestHandler
T = TypeVar("T")
U = TypeVar("U")
@@ -380,6 +379,7 @@ def _ensure_downloads_dir() -> pathlib.Path:
def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}) -> pathlib.Path:
import urllib.request
if url.startswith(("/", ".")): return pathlib.Path(url)
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
else:
@@ -400,33 +400,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
return fp
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
allow_reuse_address = True
def __init__(self, server_address, RequestHandlerClass):
print(f"*** started server on http://127.0.0.1:{server_address[1]}")
super().__init__(server_address, RequestHandlerClass)
class HTTPRequestHandler(BaseHTTPRequestHandler):
def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
self.send_response(status_code)
self.send_header("Content-Type", content_type)
self.send_header("Content-Length", str(len(data)))
self.end_headers()
return self.wfile.write(data)
def stream_json(self, source:Generator):
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
# *** Exec helpers
def system(cmd:str, **kwargs) -> str:

View File

@@ -888,8 +888,9 @@ def print_uops(uops:list[UOp]):
def get_location() -> tuple[str, int]:
frm = sys._getframe(1)
# skip over ops.py and anything in mixin
while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \
not frm.f_back.f_code.co_filename.startswith("<frozen"):
while frm.f_back is not None and not frm.f_back.f_code.co_filename.startswith("<frozen"):
fn = frm.f_code.co_filename.replace("\\", "/")
if not (fn.endswith("/ops.py") or "/mixin/" in fn): break
frm = frm.f_back
return frm.f_code.co_filename, frm.f_lineno

View File

@@ -1,12 +1,41 @@
#!/usr/bin/env python3
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
import ctypes, pathlib, traceback, itertools
import ctypes, pathlib, traceback, itertools, socketserver
from contextlib import redirect_stdout, redirect_stderr, contextmanager
from decimal import Decimal
from urllib.parse import parse_qs, urlparse
from http.server import BaseHTTPRequestHandler
from typing import Any, TypedDict, TypeVar, Generator, Callable
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
from tinygrad.helpers import printable
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
class TCPServerWithReuse(socketserver.TCPServer):
allow_reuse_address = True
def __init__(self, server_address, RequestHandlerClass):
print(f"*** started server on http://127.0.0.1:{server_address[1]}")
super().__init__(server_address, RequestHandlerClass)
class HTTPRequestHandler(BaseHTTPRequestHandler):
def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
self.send_response(status_code)
self.send_header("Content-Type", content_type)
self.send_header("Content-Length", str(len(data)))
self.end_headers()
return self.wfile.write(data)
def stream_json(self, source:Generator):
try:
self.send_response(200)
self.send_header("Content-Type", "text/event-stream")
self.send_header("Cache-Control", "no-cache")
self.end_headers()
for r in source:
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
self.wfile.flush()
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
# pass if client closed connection
except (BrokenPipeError, ConnectionResetError): return
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.uop.ops import print_uops, range_start, multirange_str
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent