From 32af1ff84b578226cf76e76735b58094074889b0 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 21 Oct 2025 15:51:32 +0800 Subject: [PATCH] viz graph drawing small cleanups (#12830) * viz graph drawing small cleanups * str literal --- test/unit/test_viz.py | 8 ++++++++ tinygrad/viz/js/index.js | 9 ++++----- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/test/unit/test_viz.py b/test/unit/test_viz.py index 9810de38d4..4edee1323b 100644 --- a/test/unit/test_viz.py +++ b/test/unit/test_viz.py @@ -148,6 +148,14 @@ class TestViz(BaseTestViz): a2 = uop_to_json(a)[id(a)] self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')") + def test_colored_label_multiline(self): + arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta") + src = [Tensor.empty(1).uop for _ in range(10)] + a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg) + exec_rewrite(a, [PatternMatcher([])]) + a2 = next(get_viz_details(0, 0))["graph"][id(a)] + self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw") + def test_inf_loop(self): a = UOp.variable('a', 0, 10, dtype=dtypes.int) b = a.replace(op=Ops.CONST) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 4c4888fc48..dd803bda5b 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) { if (parents == null && children == null) return; const src = [...parents, ...children, d.id]; nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id)); - const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; + const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : ""; d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath"); d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port"); e.stopPropagation(); @@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) { }).selectAll("text").data(d => { const ret = [[]]; for (const { st, color } of parseColors(d.label, defaultColor="initial")) { - for (const [i, l] of st.split("\n").entries()) { - if (i > 0) ret.push([]); - ret.at(-1).push({ st:l, color }); - } + const lines = st.split("\n"); + ret.at(-1).push({ st:lines[0], color }); + for (let i=1; i d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")