From 5265f25088872c68c220b362dcc840c48db2402e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:14:14 +0300 Subject: [PATCH] add counter for incoming edges in viz (#9907) --- tinygrad/viz/js/index.js | 19 +++++++++++++++++++ tinygrad/viz/js/worker.js | 6 +++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index f395ebcc59..6af8eec67c 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -54,6 +54,25 @@ async function renderDag(graph, additions, recenter=false) { points.push(intersectRect(g.node(e.w), points[points.length-1])); return line(points); }).attr("marker-end", "url(#arrowhead)"); + const edgeLabels = d3.select("#edges").selectAll("g").data(g.edges().filter(e => g.edge(e).label != null)).join("g").attr("transform", (e) => { + // get a point near the end + const [p1, p2] = g.edge(e).points.slice(-2); + const dx = p2.x-p1.x; + const dy = p2.y-p1.y; + // normalize to the unit vector + const len = Math.sqrt(dx*dx + dy*dy); + const ux = dx / len; + const uy = dy / len; + // avoid overlap with the arrowhead + const offset = 17; + const x = p2.x - ux * offset; + const y = p2.y - uy * offset; + return `translate(${x}, ${y})` + }); + edgeLabels.selectAll("circle").data(e => [g.edge(e).label]).join("circle").attr("r", 5).attr("fill", "#FFD700").attr("stroke", "#B8860B") + .attr("stroke-width", 0.8); + edgeLabels.selectAll("text").data(e => [g.edge(e).label]).join("text").text(d => d).attr("text-anchor", "middle").attr("dy", "0.35em"). + attr("font-size", "6px").attr("fill", "black"); if (recenter) document.getElementById("zoom-to-fit-btn").click(); }; diff --git a/tinygrad/viz/js/worker.js b/tinygrad/viz/js/worker.js index c184082fc9..2030e28bfb 100644 --- a/tinygrad/viz/js/worker.js +++ b/tinygrad/viz/js/worker.js @@ -23,7 +23,11 @@ onmessage = (e) => { // adjust node dims by label size + add padding const [labelWidth, labelHeight] = getTextDims(label); g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING}); - for (const s of src) g.setEdge(s, k); + const edgeCounts = {} + for (const s of src) { + edgeCounts[s] = (edgeCounts[s] || 0)+1; + } + for (const s of src) g.setEdge(s, k, { label: edgeCounts[s] > 1 ? edgeCounts[s] : null }); if (additions.includes(parseInt(k))) g.setParent(k, "addition"); } dagre.layout(g);