From 3bb232eb2909e782ab91822f8a7bf562fdad728c Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 27 Jul 2025 19:51:47 +0800 Subject: [PATCH] viz: query path in rewrite steps (#11391) --- tinygrad/viz/js/index.js | 9 +++++---- tinygrad/viz/serve.py | 10 ++++++---- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index e4a8ec3ce6..e415443aa8 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -470,11 +470,13 @@ async function main() { const { currentCtx, currentStep, currentRewrite, expandSteps } = state; if (currentCtx == -1) return; const ctx = ctxs[currentCtx]; - const ckey = `ctx=${currentCtx-1}&idx=${currentStep}`; + const step = ctx.steps[currentStep]; + const ckey = step?.query; // close any pending event sources let activeSrc = null; for (const e of evtSources) { - if (e.url.split("?")[1] !== ckey) e.close(); + const url = new URL(e.url); + if (url.pathname+url.search !== ckey) e.close(); else if (e.readyState === EventSource.OPEN) activeSrc = e; } if (ctx.name === "Profiler") return renderProfiler(); @@ -482,11 +484,10 @@ async function main() { ret = cache[ckey]; } // if we don't have a complete cache yet we start streaming rewrites in this step - const step = ctx.steps[currentStep]; if (!(ckey in cache) || (cache[ckey].length !== step.match_count+1 && activeSrc == null)) { ret = []; cache[ckey] = ret; - const eventSource = new EventSource(`/ctxs?${ckey}`); + const eventSource = new EventSource(ckey); evtSources.push(eventSource); eventSource.onmessage = (e) => { if (e.data === "END") return eventSource.close(); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index a301d706c9..aa19d3cbc0 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -25,8 +25,10 @@ ref_map:dict[Any, int] = {} def get_metadata(keys:list[TracingKey], contexts:list[list[TrackedGraphRewrite]]) -> list[dict]: ret = [] for i,(k,v) in enumerate(zip(keys, contexts)): - steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc)} for s in v] + steps = [{"name":s.name, "loc":s.loc, "depth":s.depth, "match_count":len(s.matches), "code_line":printable(s.loc), + "query":f"/ctxs?ctx={i}&idx={j}"} for j,s in enumerate(v)] ret.append(r:={"name":k.display_name, "fmt":k.fmt, "steps":steps}) + # use the first key to get runtime profiling data about this context if getenv("PROFILE_VALUE") >= 2 and k.keys: r["runtime_stats"] = get_runtime_stats(k.keys[0]) for key in k.keys: ref_map[key] = i return ret @@ -194,9 +196,9 @@ class Handler(BaseHTTPRequestHandler): if url.path.endswith(".js"): content_type = "application/javascript" if url.path.endswith(".css"): content_type = "text/css" except FileNotFoundError: status_code = 404 - elif url.path == "/ctxs": - if "ctx" in (q:=parse_qs(url.query)): return self.stream_json(get_details(contexts[1][int(q["ctx"][0])][int(q["idx"][0])])) - ret, content_type = json.dumps(ctxs).encode(), "application/json" + elif (query:=parse_qs(url.query)): + return self.stream_json(get_details(contexts[1][int(query["ctx"][0])][int(query["idx"][0])])) + elif url.path == "/ctxs": ret, content_type = json.dumps(ctxs).encode(), "application/json" elif url.path == "/get_profile" and profile_ret is not None: ret, content_type = profile_ret, "application/json" else: status_code = 404