do not lockup VIZ when rendering big graphs [pr] (#8795)

* new viz renderer

* aesthetics

* progress message

* pruning + timeout at 2s
This commit is contained in:
qazal
2025-02-26 10:15:26 +02:00
committed by GitHub
parent e162aa862d
commit 941559098b
8 changed files with 915 additions and 4851 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -5,8 +5,8 @@ fetch() {
rmdir assets/$1
curl -o assets/$1 https://$1
}
fetch "d3js.org/d3.v5.min.js"
fetch "dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"
fetch "d3js.org/d3.v7.min.js"
fetch "dagrejs.github.io/project/dagre/latest/dagre.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"

View File

@@ -4,8 +4,8 @@
<title>tinygrad viz</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
<script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
<script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>
<script src="assets/dagrejs.github.io/project/dagre/latest/dagre.min.js"></script>
<script src="lib/graph.js"></script>
<link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
@@ -59,7 +59,7 @@
cursor: default;
user-select: none;
}
.node rect {
rect {
stroke: #4a4b57;
stroke-width: 1.4px;
rx: 8px;
@@ -69,9 +69,9 @@
color: #08090e;
font-weight: 350;
}
.edgePath path {
.edgePath {
stroke: #4a4b57;
fill: #4a4b57;
fill: none;
stroke-width: 1.4px;
}
.main-container {
@@ -198,6 +198,14 @@
border-radius: 8px;
padding: 8px;
}
.progress-message {
position: absolute;
z-index: 2;
left: 50%;
top: 2%;
color: #ffd230;
display: none;
}
</style>
</head>
<body>
@@ -210,8 +218,12 @@
</div>
<div class="container kernel-list-parent"><div class="container kernel-list"></div></div>
<div class="graph">
<svg id="graph-svg">
<g id="render"></g>
<div class="progress-message">Rendering new layout...</div>
<svg id="graph-svg" preserveAspectRatio="xMidYMid meet">
<g id="render">
<g id="edges"></g>
<g id="nodes"></g>
</g>
</svg>
</div>
<div class="container metadata"></div>

View File

@@ -10,39 +10,72 @@ function recenterRects(svg, zoom) {
svg.call(zoom.transform, d3.zoomIdentity)
}
function intersectRect(r1, r2) {
const dx = r2.x-r1.x;
const dy = r2.y-r1.y;
if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap");
const scaleX = dx !== 0 ? (r1.width/2)/Math.abs(dx) : Infinity;
const scaleY = dy !== 0 ? (r1.height/2)/Math.abs(dy) : Infinity;
const scale = Math.min(scaleX, scaleY);
return {x:r1.x+dx*scale, y:r1.y+dy*scale};
}
const allWorkers = [];
window.renderGraph = function(graph, additions) {
// ** initialize the graph
const g = new dagreD3.dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
for (const [k,v] of Object.entries(graph)) {
const [label, src, color] = v;
const node = {label, labelType:"text", style:`fill: ${color};`};
// for PROGRAM UOp we render the node with a code block
if (label.includes("PROGRAM")) {
const [name, ...rest] = label.split("\n");
const labelEl = Object.assign(document.createElement("div"));
labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }));
labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
node.label = labelEl;
node.labelType = "html";
}
g.setNode(k, node);
for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis});
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
while (allWorkers.length) {
const { worker, timeout } = allWorkers.pop();
worker.terminate();
clearTimeout(timeout);
}
// ** start calculating the new layout (non-blocking)
worker = new Worker("/lib/worker.js");
const progressMessage = document.querySelector(".progress-message");
const timeout = setTimeout(() => {
progressMessage.style.display = "block";
}, 2000);
allWorkers.push({worker, timeout});
worker.postMessage({graph, additions});
// ** select svg render
const svg = d3.select("#graph-svg");
const inner = svg.select("g");
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => {
const transform = d3.event.transform;
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", ({ transform }) => {
inner.attr("transform", transform);
});
recenterRects(svg, zoom);
svg.call(zoom);
// ** calculate layout + render
const render = new dagreD3.render();
render(inner, g);
worker.onmessage = (e) => {
progressMessage.style.display = "none";
clearTimeout(timeout);
const g = dagre.graphlib.json.read(e.data);
// ** draw nodes
const nodeRender = inner.select("#nodes");
const nodes = nodeRender.selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g")
.attr("transform", d => `translate(${d.x},${d.y})`);
nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color)
.attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style);
// +labels
nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => {
const x = (d.width-d.padding*2)/2;
const y = (d.height-d.padding*2)/2;
return `translate(-${x}, -${y})`;
}).selectAll("text").data(d => [d.label.split("\n")]).join("text").selectAll("tspan").data(d => d).join("tspan").text(d => d).attr("x", "1")
.attr("dy", 14).attr("xml:space", "preserve");
// ** draw edges
const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis);
const edgeRender = inner.select("#edges");
edgeRender.selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => {
const edge = g.edge(e);
const points = edge.points.slice(1, edge.points.length-1);
points.unshift(intersectRect(g.node(e.v), points[0]));
points.push(intersectRect(g.node(e.w), points[points.length-1]));
return line(points);
}).attr("marker-end", "url(#arrowhead)");
// +arrow heads
inner.append("defs").append("marker").attr("id", "arrowhead").attr("viewBox", "0 -5 10 10").attr("refX", 10).attr("refY", 0)
.attr("markerWidth", 6).attr("markerHeight", 6).attr("orient", "auto").append("path").attr("d", "M0,-5L10,0L0,5").attr("fill", "#4a4b57");
};
}

View File

@@ -0,0 +1,34 @@
importScripts("../assets/dagrejs.github.io/project/dagre/latest/dagre.min.js");
const NODE_PADDING = 10;
const LINE_HEIGHT = 14;
const canvas = new OffscreenCanvas(0, 0);
const ctx = canvas.getContext('2d');
ctx.font = `${LINE_HEIGHT}px sans-serif`;
function getTextDims(text) {
let [maxWidth, height] = [0, 0];
for (line of text.split("\n")) {
const { width } = ctx.measureText(line);
if (width > maxWidth) maxWidth = width;
height += LINE_HEIGHT;
}
return [maxWidth, height];
}
onmessage = (e) => {
const { graph, additions } = e.data;
const g = new dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
for (const [k, [label, src, color]] of Object.entries(graph)) {
// adjust node dims by label size + add padding
const [labelWidth, labelHeight] = getTextDims(label);
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING, labelWidth, labelHeight});
for (const s of src) g.setEdge(s, k);
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
}
dagre.layout(g);
postMessage(dagre.graphlib.json.write(g));
self.close();
}