diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 2c40dad0ec..275ccb0c37 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -109,12 +109,12 @@ async function initWorker() { workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" })); } -function renderDag(graph, additions, recenter, layoutOpts) { +function renderDag(layoutSpec, { recenter }) { // start calculating the new layout (non-blocking) updateProgress(Status.STARTED, "Rendering new graph..."); if (worker != null) worker.terminate(); worker = new Worker(workerUrl); - worker.postMessage({graph, additions, opts:layoutOpts }); + worker.postMessage(layoutSpec); worker.onmessage = (e) => { displaySelection("#graph"); updateProgress(Status.COMPLETE); @@ -837,7 +837,7 @@ async function main() { if (ret.length === 0) return; // ** center graph const data = ret[currentRewrite]; - const render = (opts) => renderDag(data.graph, data.changed_nodes ?? [], currentRewrite === 0, opts); + const render = (opts) => renderDag({ graph:data.graph, change:data.change, opts }, { recenter:currentRewrite === 0 }); render({ showIndexing:toggle.checked }); toggle.onchange = (e) => render({ showIndexing:e.target.checked }); // ** right sidebar metadata diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 8669244de8..033a0a130f 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -5,11 +5,11 @@ const ctx = canvas.getContext("2d"); ctx.font = `350 ${LINE_HEIGHT}px sans-serif`; onmessage = (e) => { - const { graph, additions, opts } = e.data; + const { graph, change, opts } = e.data; const g = new dagre.graphlib.Graph({ compound: true }); g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); - if (additions.length !== 0) g.setNode("addition", {label:"", labelWidth:0, labelHeight:0, className:"overlay"}); - for (let [k, {label, src, ref, ...rest }] of Object.entries(graph)) { + if (change?.length) g.setNode("overlay", {label:"", labelWidth:0, labelHeight:0, className:"overlay"}); + for (const [k, {label, src, ref, ...rest }] of Object.entries(graph)) { // adjust node dims by label size (excluding escape codes) + add padding let [width, height] = [0, 0]; for (line of label.replace(/\u001B\[(?:K|.*?m)/g, "").split("\n")) { @@ -18,10 +18,10 @@ onmessage = (e) => { } g.setNode(k, {width:width+NODE_PADDING*2, height:height+NODE_PADDING*2, label, labelHeight:height, labelWidth:width, ref, id:k, ...rest}); // add edges - const edgeCounts = {} + const edgeCounts = {}; for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1; for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}}); - if (additions.includes(parseInt(k))) g.setParent(k, "addition"); + if (change?.includes(parseInt(k))) g.setParent(k, "overlay"); } // optionally hide nodes from the layuot if (!opts.showIndexing) { @@ -31,8 +31,8 @@ onmessage = (e) => { } } dagre.layout(g); - // remove additions overlay if it's empty - if (!g.node("addition")?.width) g.removeNode("addition"); + // remove overlay node if it's empty + if (!g.node("overlay")?.width) g.removeNode("overlay"); postMessage(dagre.graphlib.json.write(g)); self.close(); } diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 459273ae7b..b98b5d939b 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -53,7 +53,7 @@ class GraphRewriteDetails(TypedDict): graph: dict # JSON serialized UOp for this rewrite step uop: str # strigified UOp for this rewrite step diff: list[str]|None # diff of the single UOp that changed - changed_nodes: list[int]|None # the changed UOp id + all its parents ids + change: list[int]|None # the new UOp id + all its parents ids upat: tuple[tuple[str, int], str]|None # [loc, source_code] of the matched UPat def shape_to_str(s:tuple[sint, ...]): return "(" + ','.join(srender(x) for x in s) + ")" @@ -115,14 +115,14 @@ def _reconstruct(a:int): def get_full_rewrite(ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails, None, None]: next_sink = _reconstruct(ctx.sink) # in the schedule graph we don't show indexing ops (unless it's in a kernel AST or rewriting dtypes.index sink) - yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "changed_nodes":None, "diff":None, "upat":None} + yield {"graph":uop_to_json(next_sink), "uop":pystr(next_sink), "change":None, "diff":None, "upat":None} replaces: dict[UOp, UOp] = {} for u0_num,u1_num,upat_loc,dur in tqdm(ctx.matches): replaces[u0:=_reconstruct(u0_num)] = u1 = _reconstruct(u1_num) try: new_sink = next_sink.substitute(replaces) except RuntimeError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) match_repr = f"# {dur*1e6:.2f} us\n"+printable(upat_loc) - yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json], + yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":pystr(new_sink), "change":[id(x) for x in u1.toposort() if id(x) in sink_json], "diff":list(difflib.unified_diff(pystr(u0).splitlines(), pystr(u1).splitlines())), "upat":(upat_loc, match_repr)} if not ctx.bottom_up: next_sink = new_sink