From 88d650d60688d88ced23fc936102cf308645a7c4 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 26 Feb 2026 18:57:56 +0800 Subject: [PATCH] viz: clean up call node detection check (#15025) --- tinygrad/viz/js/index.js | 4 ++-- tinygrad/viz/js/worker.js | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index e82ca5f7f4..81a0ea0da1 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -62,7 +62,7 @@ const drawGraph = (data) => { const callCount = g.graph().callCount; const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g").attr("class", d => d.className ?? "node") .attr("transform", d => `translate(${d.x},${d.y})`).on("click", (e,d) => { - if (d.label.startsWith("CALL")) { + if (d.callNode) { if (state.callSrcMask.has(d.id)) state.callSrcMask.delete(d.id); else state.callSrcMask.add(d.id); if (state.callSrcMask.size >= callCount) { showCallSrc.toggle.checked = !showCallSrc.toggle.checked; state.callSrcMask.clear(); } return setState({}); @@ -110,7 +110,7 @@ const drawGraph = (data) => { }); addTags(nodes.selectAll("g.tag").data(d => d.tag != null ? [d] : []).join("g").attr("class", "tag") .attr("transform", d => `translate(${-d.width/2+8}, ${-d.height/2+8})`).datum(e => e.tag)); - addTags(nodes.selectAll("g.type").data(d => d.label.startsWith("CALL\n") ? [d] : []).join("g") + addTags(nodes.selectAll("g.type").data(d => d.callNode ? [d] : []).join("g") .attr("class", d => `tag ${d.collapsed ? 'collapsed' : 'expanded'}`) .attr("transform", d => `translate(${-d.width/2}, ${0})`).datum(d => d.collapsed ? "+" : "−")); // draw edges diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index 1c1c00056d..2ca5b6298f 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -54,13 +54,14 @@ const layoutUOp = (g, { graph, change }, opts) => { width = Math.max(width, ctx.measureText(line).width); height += lineHeight; } - if (label.startsWith("CALL\n")) callCount++; - g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag}); + const callNode = label.startsWith("CALL\n"); + if (callNode) callCount++; + g.setNode(k, {...rectDims(width, height), label, ref, id:k, color, tag, callNode}); // add edges const edgeCounts = {}; for (const [_, s] of src) edgeCounts[s] = (edgeCounts[s] || 0)+1; for (const [port, s] of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? {type:"tag", text:edgeCounts[s]} : {type:"port", text:port}, - ...(label.startsWith("CALL\n") && port === 0 && {color:"#a0a1b8"})}); + ...(callNode && port === 0 && {color:"#a0a1b8"})}); if (change?.includes(parseInt(k))) g.setParent(k, "overlay"); } // optionally hide nodes from the layout @@ -81,7 +82,7 @@ const layoutUOp = (g, { graph, change }, opts) => { const disconnected = new Set(); for (const n of g.nodes()) { const node = g.node(n); - if (node.label.startsWith("CALL\n") && (opts.showCallSrc ? opts.callSrcMask.has(n) : !opts.callSrcMask.has(n))) { + if (node.callNode && (opts.showCallSrc ? opts.callSrcMask.has(n) : !opts.callSrcMask.has(n))) { node.collapsed = true; for (const pred of (g.predecessors(n) || [])) { const edge = g.edge(pred, n);