viz: count unpickle in server startup time (#12715)

* viz: count unpickle in server startup time

* type checking
This commit is contained in:
qazal
2025-10-16 13:07:46 +08:00
committed by GitHub
parent 7c19db00f1
commit b77bdbbc62

View File

@@ -5,7 +5,7 @@ from contextlib import redirect_stdout
from decimal import Decimal
from http.server import BaseHTTPRequestHandler
from urllib.parse import parse_qs, urlparse
from typing import Any, TypedDict, Generator
from typing import Any, TypedDict, TypeVar, Generator
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
@@ -294,8 +294,9 @@ def reloader():
os.execv(sys.executable, [sys.executable] + sys.argv)
time.sleep(0.1)
def load_pickle(fp:str) -> list:
if not (path:=pathlib.Path(fp)).exists(): return []
T = TypeVar("T")
def load_pickle(path:pathlib.Path, default:T) -> T:
if not path.exists(): return default
with path.open("rb") as f: return pickle.load(f)
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
@@ -303,8 +304,8 @@ class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--kernels', type=load_pickle, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
parser.add_argument('--profile', type=load_pickle, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
args = parser.parse_args()
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -315,8 +316,8 @@ if __name__ == "__main__":
st = time.perf_counter()
print("*** viz is starting")
ctxs = get_rewrites(trace:=args.kernels)
profile_ret = get_profile(args.profile)
ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
profile_ret = get_profile(load_pickle(args.profile, default=[]))
server = TCPServerWithReuse(('', PORT), Handler)
reloader_thread = threading.Thread(target=reloader)