viz graph drawing small cleanups (#12830)

* viz graph drawing small cleanups

* str literal
This commit is contained in:
qazal
2025-10-21 15:51:32 +08:00
committed by GitHub
parent 367fbabc30
commit 32af1ff84b
2 changed files with 12 additions and 5 deletions

View File

@@ -148,6 +148,14 @@ class TestViz(BaseTestViz):
a2 = uop_to_json(a)[id(a)]
self.assertEqual(ansistrip(a2["label"]), f"CUSTOM\n{TestStruct.__qualname__}(colored_field='xyz12345')")
def test_colored_label_multiline(self):
arg = colored("x", "green")+"\n"+colored("y", "red")+colored("z", "yellow")+colored("ww\nw", "magenta")
src = [Tensor.empty(1).uop for _ in range(10)]
a = UOp(Ops.CUSTOM, src=tuple(src), arg=arg)
exec_rewrite(a, [PatternMatcher([])])
a2 = next(get_viz_details(0, 0))["graph"][id(a)]
self.assertEqual(ansistrip(a2["label"]), "CUSTOM\nx\nyzww\nw")
def test_inf_loop(self):
a = UOp.variable('a', 0, 10, dtype=dtypes.int)
b = a.replace(op=Ops.CONST)

View File

@@ -78,7 +78,7 @@ function renderDag(graph, additions, recenter) {
if (parents == null && children == null) return;
const src = [...parents, ...children, d.id];
nodes.classed("highlight", n => src.includes(n.id)).classed("child", n => children.includes(n.id));
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
const matchEdge = (v, w) => (v===d.id && children.includes(w)) ? "highlight child " : (parents.includes(v) && w===d.id) ? "highlight " : "";
d3.select("#edges").selectAll("path.edgePath").attr("class", e => matchEdge(e.v, e.w)+"edgePath");
d3.select("#edge-labels").selectAll("g.port").attr("class", (_, i, n) => matchEdge(...n[i].id.split("-"))+"port");
e.stopPropagation();
@@ -92,10 +92,9 @@ function renderDag(graph, additions, recenter) {
}).selectAll("text").data(d => {
const ret = [[]];
for (const { st, color } of parseColors(d.label, defaultColor="initial")) {
for (const [i, l] of st.split("\n").entries()) {
if (i > 0) ret.push([]);
ret.at(-1).push({ st:l, color });
}
const lines = st.split("\n");
ret.at(-1).push({ st:lines[0], color });
for (let i=1; i<lines.length; i++) ret.push([{ st:lines[i], color }]);
}
return [ret];
}).join("text").selectAll("tspan").data(d => d).join("tspan").attr("x", "0").attr("dy", 14).selectAll("tspan").data(d => d).join("tspan")