mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
viz: count unpickle in server startup time (#12715)
* viz: count unpickle in server startup time * type checking
This commit is contained in:
@@ -5,7 +5,7 @@ from contextlib import redirect_stdout
|
||||
from decimal import Decimal
|
||||
from http.server import BaseHTTPRequestHandler
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from typing import Any, TypedDict, Generator
|
||||
from typing import Any, TypedDict, TypeVar, Generator
|
||||
from tinygrad.helpers import colored, getenv, tqdm, unwrap, word_wrap, TRACEMETA, ProfileEvent, ProfileRangeEvent, TracingKey, ProfilePointEvent, temp
|
||||
from tinygrad.uop.ops import TrackedGraphRewrite, RewriteTrace, UOp, Ops, printable, GroupOp, srender, sint, sym_infer, range_str, pyrender
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileGraphEvent, ProfileGraphEntry, Device
|
||||
@@ -294,8 +294,9 @@ def reloader():
|
||||
os.execv(sys.executable, [sys.executable] + sys.argv)
|
||||
time.sleep(0.1)
|
||||
|
||||
def load_pickle(fp:str) -> list:
|
||||
if not (path:=pathlib.Path(fp)).exists(): return []
|
||||
T = TypeVar("T")
|
||||
def load_pickle(path:pathlib.Path, default:T) -> T:
|
||||
if not path.exists(): return default
|
||||
with path.open("rb") as f: return pickle.load(f)
|
||||
|
||||
# NOTE: using HTTPServer forces a potentially slow socket.getfqdn
|
||||
@@ -303,8 +304,8 @@ class TCPServerWithReuse(socketserver.TCPServer): allow_reuse_address = True
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--kernels', type=load_pickle, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
|
||||
parser.add_argument('--profile', type=load_pickle, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
parser.add_argument('--kernels', type=pathlib.Path, help='Path to kernels', default=pathlib.Path(temp("rewrites.pkl", append_user=True)))
|
||||
parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
|
||||
args = parser.parse_args()
|
||||
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
@@ -315,8 +316,8 @@ if __name__ == "__main__":
|
||||
st = time.perf_counter()
|
||||
print("*** viz is starting")
|
||||
|
||||
ctxs = get_rewrites(trace:=args.kernels)
|
||||
profile_ret = get_profile(args.profile)
|
||||
ctxs = get_rewrites(trace:=load_pickle(args.kernels, default=RewriteTrace([], [], {})))
|
||||
profile_ret = get_profile(load_pickle(args.profile, default=[]))
|
||||
|
||||
server = TCPServerWithReuse(('', PORT), Handler)
|
||||
reloader_thread = threading.Thread(target=reloader)
|
||||
|
||||
Reference in New Issue
Block a user