mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 23:48:01 -05:00
do not lockup VIZ when rendering big graphs [pr] (#8795)
* new viz renderer * aesthetics * progress message * pruning + timeout at 2s
This commit is contained in:
2
tinygrad/viz/assets/d3js.org/d3.v5.min.js
vendored
2
tinygrad/viz/assets/d3js.org/d3.v5.min.js
vendored
File diff suppressed because one or more lines are too long
2
tinygrad/viz/assets/d3js.org/d3.v7.min.js
vendored
Normal file
2
tinygrad/viz/assets/d3js.org/d3.v7.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
801
tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js
vendored
Normal file
801
tinygrad/viz/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js
vendored
Normal file
File diff suppressed because one or more lines are too long
@@ -5,8 +5,8 @@ fetch() {
|
||||
rmdir assets/$1
|
||||
curl -o assets/$1 https://$1
|
||||
}
|
||||
fetch "d3js.org/d3.v5.min.js"
|
||||
fetch "dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"
|
||||
fetch "d3js.org/d3.v7.min.js"
|
||||
fetch "dagrejs.github.io/project/dagre/latest/dagre.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"
|
||||
fetch "cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
<title>tinygrad viz</title>
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
|
||||
<script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
|
||||
<script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
|
||||
<script src="assets/d3js.org/d3.v7.min.js" charset="utf-8"></script>
|
||||
<script src="assets/dagrejs.github.io/project/dagre/latest/dagre.min.js"></script>
|
||||
<script src="lib/graph.js"></script>
|
||||
<link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
|
||||
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
|
||||
@@ -59,7 +59,7 @@
|
||||
cursor: default;
|
||||
user-select: none;
|
||||
}
|
||||
.node rect {
|
||||
rect {
|
||||
stroke: #4a4b57;
|
||||
stroke-width: 1.4px;
|
||||
rx: 8px;
|
||||
@@ -69,9 +69,9 @@
|
||||
color: #08090e;
|
||||
font-weight: 350;
|
||||
}
|
||||
.edgePath path {
|
||||
.edgePath {
|
||||
stroke: #4a4b57;
|
||||
fill: #4a4b57;
|
||||
fill: none;
|
||||
stroke-width: 1.4px;
|
||||
}
|
||||
.main-container {
|
||||
@@ -198,6 +198,14 @@
|
||||
border-radius: 8px;
|
||||
padding: 8px;
|
||||
}
|
||||
.progress-message {
|
||||
position: absolute;
|
||||
z-index: 2;
|
||||
left: 50%;
|
||||
top: 2%;
|
||||
color: #ffd230;
|
||||
display: none;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
@@ -210,8 +218,12 @@
|
||||
</div>
|
||||
<div class="container kernel-list-parent"><div class="container kernel-list"></div></div>
|
||||
<div class="graph">
|
||||
<svg id="graph-svg">
|
||||
<g id="render"></g>
|
||||
<div class="progress-message">Rendering new layout...</div>
|
||||
<svg id="graph-svg" preserveAspectRatio="xMidYMid meet">
|
||||
<g id="render">
|
||||
<g id="edges"></g>
|
||||
<g id="nodes"></g>
|
||||
</g>
|
||||
</svg>
|
||||
</div>
|
||||
<div class="container metadata"></div>
|
||||
|
||||
@@ -10,39 +10,72 @@ function recenterRects(svg, zoom) {
|
||||
svg.call(zoom.transform, d3.zoomIdentity)
|
||||
}
|
||||
|
||||
function intersectRect(r1, r2) {
|
||||
const dx = r2.x-r1.x;
|
||||
const dy = r2.y-r1.y;
|
||||
if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap");
|
||||
const scaleX = dx !== 0 ? (r1.width/2)/Math.abs(dx) : Infinity;
|
||||
const scaleY = dy !== 0 ? (r1.height/2)/Math.abs(dy) : Infinity;
|
||||
const scale = Math.min(scaleX, scaleY);
|
||||
return {x:r1.x+dx*scale, y:r1.y+dy*scale};
|
||||
}
|
||||
|
||||
const allWorkers = [];
|
||||
window.renderGraph = function(graph, additions) {
|
||||
// ** initialize the graph
|
||||
const g = new dagreD3.dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
|
||||
for (const [k,v] of Object.entries(graph)) {
|
||||
const [label, src, color] = v;
|
||||
const node = {label, labelType:"text", style:`fill: ${color};`};
|
||||
// for PROGRAM UOp we render the node with a code block
|
||||
if (label.includes("PROGRAM")) {
|
||||
const [name, ...rest] = label.split("\n");
|
||||
const labelEl = Object.assign(document.createElement("div"));
|
||||
labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }));
|
||||
labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
|
||||
node.label = labelEl;
|
||||
node.labelType = "html";
|
||||
}
|
||||
g.setNode(k, node);
|
||||
for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis});
|
||||
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
|
||||
while (allWorkers.length) {
|
||||
const { worker, timeout } = allWorkers.pop();
|
||||
worker.terminate();
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
|
||||
// ** start calculating the new layout (non-blocking)
|
||||
worker = new Worker("/lib/worker.js");
|
||||
const progressMessage = document.querySelector(".progress-message");
|
||||
const timeout = setTimeout(() => {
|
||||
progressMessage.style.display = "block";
|
||||
}, 2000);
|
||||
allWorkers.push({worker, timeout});
|
||||
worker.postMessage({graph, additions});
|
||||
|
||||
// ** select svg render
|
||||
const svg = d3.select("#graph-svg");
|
||||
const inner = svg.select("g");
|
||||
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => {
|
||||
const transform = d3.event.transform;
|
||||
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", ({ transform }) => {
|
||||
inner.attr("transform", transform);
|
||||
});
|
||||
recenterRects(svg, zoom);
|
||||
svg.call(zoom);
|
||||
|
||||
// ** calculate layout + render
|
||||
const render = new dagreD3.render();
|
||||
render(inner, g);
|
||||
worker.onmessage = (e) => {
|
||||
progressMessage.style.display = "none";
|
||||
clearTimeout(timeout);
|
||||
const g = dagre.graphlib.json.read(e.data);
|
||||
// ** draw nodes
|
||||
const nodeRender = inner.select("#nodes");
|
||||
const nodes = nodeRender.selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g")
|
||||
.attr("transform", d => `translate(${d.x},${d.y})`);
|
||||
nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color)
|
||||
.attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style);
|
||||
// +labels
|
||||
nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => {
|
||||
const x = (d.width-d.padding*2)/2;
|
||||
const y = (d.height-d.padding*2)/2;
|
||||
return `translate(-${x}, -${y})`;
|
||||
}).selectAll("text").data(d => [d.label.split("\n")]).join("text").selectAll("tspan").data(d => d).join("tspan").text(d => d).attr("x", "1")
|
||||
.attr("dy", 14).attr("xml:space", "preserve");
|
||||
|
||||
// ** draw edges
|
||||
const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis);
|
||||
const edgeRender = inner.select("#edges");
|
||||
edgeRender.selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => {
|
||||
const edge = g.edge(e);
|
||||
const points = edge.points.slice(1, edge.points.length-1);
|
||||
points.unshift(intersectRect(g.node(e.v), points[0]));
|
||||
points.push(intersectRect(g.node(e.w), points[points.length-1]));
|
||||
return line(points);
|
||||
}).attr("marker-end", "url(#arrowhead)");
|
||||
// +arrow heads
|
||||
inner.append("defs").append("marker").attr("id", "arrowhead").attr("viewBox", "0 -5 10 10").attr("refX", 10).attr("refY", 0)
|
||||
.attr("markerWidth", 6).attr("markerHeight", 6).attr("orient", "auto").append("path").attr("d", "M0,-5L10,0L0,5").attr("fill", "#4a4b57");
|
||||
};
|
||||
}
|
||||
|
||||
34
tinygrad/viz/lib/worker.js
Normal file
34
tinygrad/viz/lib/worker.js
Normal file
@@ -0,0 +1,34 @@
|
||||
importScripts("../assets/dagrejs.github.io/project/dagre/latest/dagre.min.js");
|
||||
|
||||
const NODE_PADDING = 10;
|
||||
const LINE_HEIGHT = 14;
|
||||
const canvas = new OffscreenCanvas(0, 0);
|
||||
const ctx = canvas.getContext('2d');
|
||||
ctx.font = `${LINE_HEIGHT}px sans-serif`;
|
||||
|
||||
function getTextDims(text) {
|
||||
let [maxWidth, height] = [0, 0];
|
||||
for (line of text.split("\n")) {
|
||||
const { width } = ctx.measureText(line);
|
||||
if (width > maxWidth) maxWidth = width;
|
||||
height += LINE_HEIGHT;
|
||||
}
|
||||
return [maxWidth, height];
|
||||
}
|
||||
|
||||
onmessage = (e) => {
|
||||
const { graph, additions } = e.data;
|
||||
const g = new dagre.graphlib.Graph({ compound: true });
|
||||
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
|
||||
if (additions.length !== 0) g.setNode("addition", {label: "", style: "fill: rgba(26, 27, 38, 0.5); stroke: none;", padding:0});
|
||||
for (const [k, [label, src, color]] of Object.entries(graph)) {
|
||||
// adjust node dims by label size + add padding
|
||||
const [labelWidth, labelHeight] = getTextDims(label);
|
||||
g.setNode(k, {label, color, width:labelWidth+NODE_PADDING*2, height:labelHeight+NODE_PADDING*2, padding:NODE_PADDING, labelWidth, labelHeight});
|
||||
for (const s of src) g.setEdge(s, k);
|
||||
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
|
||||
}
|
||||
dagre.layout(g);
|
||||
postMessage(dagre.graphlib.json.write(g));
|
||||
self.close();
|
||||
}
|
||||
Reference in New Issue
Block a user