From 42cbf7aed41d4f9bd35212d37f978347edf6f419 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sat, 3 May 2025 10:57:56 +0300 Subject: [PATCH] more viz cleanups + notes [pr] (#10145) --- tinygrad/viz/index.html | 2 +- tinygrad/viz/js/index.js | 38 ++++++++++++++++++-------------------- tinygrad/viz/serve.py | 2 +- 3 files changed, 20 insertions(+), 22 deletions(-) diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 6cc8f5b650..a6d847a41b 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -78,7 +78,7 @@ position: relative; height: 100%; } - .container > * + *, .rewrite-container > * + *, .kernel-list > * + * { + .metadata > * + *, .rewrite-container > * + *, .kernel-list > * + * { margin-top: 12px; } .kernel-list > ul > * + * { diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 85c3747c9e..dabedcf638 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -38,7 +38,7 @@ async function renderDag(graph, additions, recenter=false) { const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g") .attr("transform", d => `translate(${d.x},${d.y})`); nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color) - .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => "rx:8; ry:8; stroke:#4a4b57; stroke-width:1.4px;"+d.style); + .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => "stroke:#4a4b57; stroke-width:1.4px;"+d.style); nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => { const x = (d.width-d.padding*2)/2; const y = (d.height-d.padding*2)/2; @@ -71,8 +71,8 @@ async function renderDag(graph, additions, recenter=false) { }); edgeLabels.selectAll("circle").data(e => [g.edge(e).label]).join("circle").attr("r", 5).attr("fill", "#FFD700").attr("stroke", "#B8860B") .attr("stroke-width", 0.8); - edgeLabels.selectAll("text").data(e => [g.edge(e).label]).join("text").text(d => d).attr("text-anchor", "middle").attr("dy", "0.35em"). - attr("font-size", "6px").attr("fill", "black"); + edgeLabels.selectAll("text").data(e => [g.edge(e).label]).join("text").text(d => d).attr("text-anchor", "middle").attr("dy", "0.35em") + .attr("font-size", "6px").attr("fill", "black"); if (recenter) document.getElementById("zoom-to-fit-btn").click(); }; @@ -412,16 +412,22 @@ appendResizer(document.querySelector(".metadata-parent"), { minWidth: 20, maxWid document.addEventListener("keydown", async function(event) { const { currentKernel, currentUOp, currentRewrite, expandKernel } = state; // up and down change the UOp or kernel from the list - if (!expandKernel) { - if (event.key == "ArrowUp") { - event.preventDefault() - return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.max(0, currentKernel-1) }); - } - if (event.key == "ArrowDown") { - event.preventDefault() - return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.min(kernels.length-1, currentKernel+1) }); + if (event.key == "ArrowUp") { + event.preventDefault(); + if (expandKernel) { + return setState({ currentRewrite:0, currentUOp:Math.max(0, currentUOp-1) }); } + return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.max(0, currentKernel-1) }); } + if (event.key == "ArrowDown") { + event.preventDefault(); + if (expandKernel) { + const totalUOps = kernels[currentKernel][1].length-1; + return setState({ currentRewrite:0, currentUOp:Math.min(totalUOps, currentUOp+1) }); + } + return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.min(kernels.length-1, currentKernel+1) }); + } + // enter toggles focus on a single rewrite stage if (event.key == "Enter") { event.preventDefault() if (state.currentKernel === -1) { @@ -429,15 +435,6 @@ document.addEventListener("keydown", async function(event) { } return setState({ currentUOp:0, currentRewrite:0, expandKernel:!expandKernel }); } - if (event.key == "ArrowUp") { - event.preventDefault() - return setState({ currentRewrite:0, currentUOp:Math.max(0, currentUOp-1) }); - } - if (event.key == "ArrowDown") { - event.preventDefault() - const totalUOps = kernels[currentKernel][1].length-1; - return setState({ currentRewrite:0, currentUOp:Math.min(totalUOps, currentUOp+1) }); - } // left and right go through rewrites in a single UOp if (event.key == "ArrowLeft") { event.preventDefault() @@ -448,6 +445,7 @@ document.addEventListener("keydown", async function(event) { const totalRewrites = ret.length-1; return setState({ currentRewrite:Math.min(totalRewrites, currentRewrite+1) }); } + // space recenters the graph if (event.key == " ") { event.preventDefault() document.getElementById("zoom-to-fit-btn").click(); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index e79371630e..3cc215a363 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -91,7 +91,7 @@ def get_details(k:Any, ctx:TrackedGraphRewrite) -> Generator[GraphRewriteDetails replaces[u0] = u1 try: new_sink = next_sink.substitute(replaces) except RecursionError as e: new_sink = UOp(Ops.NOOP, arg=str(e)) - yield {"graph": (sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json], + yield {"graph":(sink_json:=uop_to_json(new_sink)), "uop":str(new_sink), "changed_nodes":[id(x) for x in u1.toposort() if id(x) in sink_json], "diff":list(difflib.unified_diff(pcall(str, u0).splitlines(), pcall(str, u1).splitlines())), "upat":(upat.location, upat.printable())} if not ctx.bottom_up: next_sink = new_sink