move viz graph to lib/graph [pr] (#9196)

* move viz graph to lib/graph [pr]

* add package

* share with program
This commit is contained in:
qazal
2025-02-21 22:04:07 +02:00
committed by GitHub
parent 6587c7879b
commit 1db4341e9f
4 changed files with 51 additions and 49 deletions

View File

@@ -25,7 +25,7 @@ setup(name='tinygrad',
packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.runtime.autogen.am', 'tinygrad.codegen', 'tinygrad.nn',
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support',
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape'],
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*']},
package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'lib/**/*']},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"

View File

@@ -6,6 +6,7 @@
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<script src="assets/d3js.org/d3.v5.min.js" charset="utf-8"></script>
<script src="assets/dagrejs.github.io/project/dagre-d3/latest/dagre-d3.min.js"></script>
<script src="lib/graph.js"></script>
<link rel="stylesheet" href="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/styles/default.min.css">
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/highlight.min.js"></script>
<script src="assets/cdnjs.cloudflare.com/ajax/libs/highlight.js/11.10.0/languages/python.min.js"></script>
@@ -233,53 +234,6 @@
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
// **** D3
function recenterRects(svg, zoom) {
const svgBounds = svg.node().getBoundingClientRect();
for (const rect of svg.node().querySelectorAll("rect")) {
const rectBounds = rect.getBoundingClientRect();
const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
// if there's at least one rect in view we don't do anything
if (!outOfBounds) return;
}
svg.call(zoom.transform, d3.zoomIdentity)
}
function renderGraph(graph, additions) {
const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
for (const [k,u] of Object.entries(graph)) {
let node = {label: u[0], labelType: "text", style: `fill: ${u[2]};`};
// for PROGRAM UOp we render the node with a code block
if (u[0].includes("PROGRAM")) {
const [name, ...rest] = u[0].split("\n");
const label = Object.assign(document.createElement("div"));
label.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }))
label.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
node = {label, labelType: "html", style: `fill: ${u[2]}`};
}
g.setNode(k, node);
for (const src of u[1]) {
g.setEdge(src, k, {curve: d3.curveBasis})
}
if (additions.includes(parseInt(k))) {
g.setParent(k, "addition");
}
}
const svg = d3.select("#graph-svg");
const inner = svg.select("g");
var zoom = d3.zoom()
.scaleExtent([0.05, 2])
.on("zoom", () => {
const transform = d3.event.transform;
inner.attr("transform", transform);
});
recenterRects(svg, zoom);
svg.call(zoom);
const render = new dagreD3.render();
render(inner, g);
}
// **** extra helpers
const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`;
const vsCodeOpener = (parts) => Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n",

48
tinygrad/viz/lib/graph.js Normal file
View File

@@ -0,0 +1,48 @@
function recenterRects(svg, zoom) {
const svgBounds = svg.node().getBoundingClientRect();
for (const rect of svg.node().querySelectorAll("rect")) {
const rectBounds = rect.getBoundingClientRect();
const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
// if there's at least one rect in view we don't do anything
if (!outOfBounds) return;
}
svg.call(zoom.transform, d3.zoomIdentity)
}
window.renderGraph = function(graph, additions) {
// ** initialize the graph
const g = new dagreD3.dagre.graphlib.Graph({ compound: true });
g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
for (const [k,v] of Object.entries(graph)) {
const [label, src, color] = v;
const node = {label, labelType:"text", style:`fill: ${color};`};
// for PROGRAM UOp we render the node with a code block
if (label.includes("PROGRAM")) {
const [name, ...rest] = label.split("\n");
const labelEl = Object.assign(document.createElement("div"));
labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }));
labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
node.label = labelEl;
node.labelType = "html";
}
g.setNode(k, node);
for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis});
if (additions.includes(parseInt(k))) g.setParent(k, "addition");
}
// ** select svg render
const svg = d3.select("#graph-svg");
const inner = svg.select("g");
const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => {
const transform = d3.event.transform;
inner.attr("transform", transform);
});
recenterRects(svg, zoom);
svg.call(zoom);
// ** calculate layout + render
const render = new dagreD3.render();
render(inner, g);
}

View File

@@ -124,7 +124,7 @@ class Handler(BaseHTTPRequestHandler):
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
elif (url:=urlparse(self.path)).path == "/profiler":
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
elif self.path.startswith("/assets/") and '/..' not in self.path:
elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path:
try:
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
if url.path.endswith(".js"): content_type = "application/javascript"