From b77bdbbc62984e9c41f8482a92211fd72daba6de Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 16 Oct 2025 13:07:46 +0800 Subject: [PATCH] viz: count unpickle in server startup time (#12715) * viz: count unpickle in server startup time * type checking --- tinygrad/viz/serve.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a0161ba52c..f153cf6fa5 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -5,7 +5,7 @@ from contextlib import redirect_stdout from decimal import Decimal from http.server import BaseHTTPRequestHandler from urllib.parse import parse_qs, urlparse -from typing import Any, TypedDict, Generator +from typing import Any, TypedDict, TypeVar, Generator from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device @@ -294,8 +294,9 @@ def reloader(): os.execv(sys.executable, [sys.executable] + sys.argv) time.sleep(0.1) -def load_pickle(fp:str) -> list: - if not (path:=pathlib.Path(fp)).exists(): return [] +T = TypeVar("T") +def load_pickle(path:pathlib.Path, default:T) -> T: + if not path.exists(): return default with path.open("rb") as f: return pickle.load(f) # NOTE: using HTTPServer forces a potentially slow socket.getfqdn @@ -303,8 +304,8 @@ class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--kernels', type=load_pickle, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True))) - parser.add_argument('--profile', type=load_pickle, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True))) + parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True))) + parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True))) args = parser.parse_args() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: @@ -315,8 +316,8 @@ if __name__ == "__main__": st = time.perf_counter() print("*** viz is starting") - ctxs = get_rewrites(trace:=args.kernels) - profile_ret = get_profile(args.profile) + ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {}))) + profile_ret = get_profile(load_pickle(args.profile, default=[])) server = TCPServerWithReuse(('', PORT), Handler) reloader_thread = threading.Thread(target=reloader)