mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
minor import speedups (#14244)
* minor import speedups * server stuff in server places * pre-commit * fix
This commit is contained in:
@@ -25,7 +25,7 @@ class TestLLMServer(unittest.TestCase):
|
||||
llm_module.eos_id = cls.eos_id
|
||||
|
||||
from tinygrad.apps.llm import Handler
|
||||
from tinygrad.helpers import TCPServerWithReuse
|
||||
from tinygrad.viz.serve import TCPServerWithReuse
|
||||
|
||||
cls.server = TCPServerWithReuse(('127.0.0.1', 0), Handler)
|
||||
cls.port = cls.server.server_address[1]
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import sys, argparse, typing, re, unicodedata, json, uuid, time, functools
|
||||
from tinygrad import Tensor, nn, UOp, TinyJit, getenv
|
||||
from tinygrad.helpers import partition, TCPServerWithReuse, HTTPRequestHandler, DEBUG, Timing, GlobalCounters, stderr_log, colored
|
||||
from tinygrad.helpers import partition, DEBUG, Timing, GlobalCounters, stderr_log, colored
|
||||
from tinygrad.viz.serve import TCPServerWithReuse, HTTPRequestHandler
|
||||
|
||||
class SimpleTokenizer:
|
||||
def __init__(self, normal_tokens:dict[str, int], special_tokens:dict[str, int], preset:str="llama3"):
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
from __future__ import annotations
|
||||
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, tempfile, pathlib, string, ctypes, sys, gzip, getpass, gc
|
||||
import urllib.request, subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools, socketserver, json
|
||||
import subprocess, shutil, math, types, copyreg, inspect, importlib, decimal, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Iterable, Any, TypeVar, Callable, Sequence, TypeGuard, Iterator, Generic, Generator, cast, overload
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
|
||||
T = TypeVar("T")
|
||||
U = TypeVar("U")
|
||||
@@ -380,6 +379,7 @@ def _ensure_downloads_dir() -> pathlib.Path:
|
||||
|
||||
def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
|
||||
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}) -> pathlib.Path:
|
||||
import urllib.request
|
||||
if url.startswith(("/", ".")): return pathlib.Path(url)
|
||||
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
||||
else:
|
||||
@@ -400,33 +400,6 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
|
||||
if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
|
||||
return fp
|
||||
|
||||
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
|
||||
class TCPServerWithReuse(socketserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
def __init__(self, server_address, RequestHandlerClass):
|
||||
print(f"*** started server on http://127.0.0.1:{server_address[1]}")
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
|
||||
class HTTPRequestHandler(BaseHTTPRequestHandler):
|
||||
def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
|
||||
self.send_response(status_code)
|
||||
self.send_header("Content-Type", content_type)
|
||||
self.send_header("Content-Length", str(len(data)))
|
||||
self.end_headers()
|
||||
return self.wfile.write(data)
|
||||
def stream_json(self, source:Generator):
|
||||
try:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.end_headers()
|
||||
for r in source:
|
||||
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
|
||||
self.wfile.flush()
|
||||
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
|
||||
# pass if client closed connection
|
||||
except (BrokenPipeError, ConnectionResetError): return
|
||||
|
||||
# *** Exec helpers
|
||||
|
||||
def system(cmd:str, **kwargs) -> str:
|
||||
|
||||
@@ -888,8 +888,9 @@ def print_uops(uops:list[UOp]):
|
||||
def get_location() -> tuple[str, int]:
|
||||
frm = sys._getframe(1)
|
||||
# skip over ops.py and anything in mixin
|
||||
while ((codepath:=pathlib.Path(frm.f_code.co_filename)).name == "ops.py" or codepath.parent.name == "mixin") and frm.f_back is not None and \
|
||||
not frm.f_back.f_code.co_filename.startswith("<frozen"):
|
||||
while frm.f_back is not None and not frm.f_back.f_code.co_filename.startswith("<frozen"):
|
||||
fn = frm.f_code.co_filename.replace("\\", "/")
|
||||
if not (fn.endswith("/ops.py") or "/mixin/" in fn): break
|
||||
frm = frm.f_back
|
||||
return frm.f_code.co_filename, frm.f_lineno
|
||||
|
||||
|
||||
@@ -1,12 +1,41 @@
|
||||
#!/usr/bin/env python3
|
||||
import multiprocessing, pickle, difflib, os, threading, json, time, sys, webbrowser, socket, argparse, functools, codecs, io, struct
|
||||
import ctypes, pathlib, traceback, itertools
|
||||
import ctypes, pathlib, traceback, itertools, socketserver
|
||||
from contextlib import redirect_stdout, redirect_stderr, contextmanager
|
||||
from decimal import Decimal
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from typing import Any, TypedDict, TypeVar, Generator, Callable
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.helpers import printable, TCPServerWithReuse, HTTPRequestHandler
|
||||
from tinygrad.helpers import printable
|
||||
|
||||
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
|
||||
class TCPServerWithReuse(socketserver.TCPServer):
|
||||
allow_reuse_address = True
|
||||
def __init__(self, server_address, RequestHandlerClass):
|
||||
print(f"*** started server on http://127.0.0.1:{server_address[1]}")
|
||||
super().__init__(server_address, RequestHandlerClass)
|
||||
|
||||
class HTTPRequestHandler(BaseHTTPRequestHandler):
|
||||
def send_data(self, data:bytes, content_type:str="application/json", status_code:int=200):
|
||||
self.send_response(status_code)
|
||||
self.send_header("Content-Type", content_type)
|
||||
self.send_header("Content-Length", str(len(data)))
|
||||
self.end_headers()
|
||||
return self.wfile.write(data)
|
||||
def stream_json(self, source:Generator):
|
||||
try:
|
||||
self.send_response(200)
|
||||
self.send_header("Content-Type", "text/event-stream")
|
||||
self.send_header("Cache-Control", "no-cache")
|
||||
self.end_headers()
|
||||
for r in source:
|
||||
self.wfile.write(f"data: {json.dumps(r)}\n\n".encode("utf-8"))
|
||||
self.wfile.flush()
|
||||
self.wfile.write("data: [DONE]\n\n".encode("utf-8"))
|
||||
# pass if client closed connection
|
||||
except (BrokenPipeError, ConnectionResetError): return
|
||||
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
||||
from tinygrad.uop.ops import print_uops, range_start, multirange_str
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device, ProfileProgramEvent
|
||||
|
||||
Reference in New Issue
Block a user