diff --git a/setup.py b/setup.py
index 2fe0ed9ebe..5618d77c90 100644
--- a/setup.py
+++ b/setup.py
@@ -25,7 +25,7 @@ setup(name='tinygrad',
packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.runtime.autogen.am', 'tinygrad.codegen', 'tinygrad.nn',
'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support',
'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape'],
- package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*']},
+ package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'lib/**/*']},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html
index d1b7181633..a177d66f10 100644
--- a/tinygrad/viz/index.html
+++ b/tinygrad/viz/index.html
@@ -6,6 +6,7 @@
+
@@ -233,53 +234,6 @@
contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains]
}));
- // **** D3
- function recenterRects(svg, zoom) {
- const svgBounds = svg.node().getBoundingClientRect();
- for (const rect of svg.node().querySelectorAll("rect")) {
- const rectBounds = rect.getBoundingClientRect();
- const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
- rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
- // if there's at least one rect in view we don't do anything
- if (!outOfBounds) return;
- }
- svg.call(zoom.transform, d3.zoomIdentity)
- }
- function renderGraph(graph, additions) {
- const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
- g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
- for (const [k,u] of Object.entries(graph)) {
- let node = {label: u[0], labelType: "text", style: `fill: ${u[2]};`};
- // for PROGRAM UOp we render the node with a code block
- if (u[0].includes("PROGRAM")) {
- const [name, ...rest] = u[0].split("\n");
- const label = Object.assign(document.createElement("div"));
- label.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }))
- label.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
- node = {label, labelType: "html", style: `fill: ${u[2]}`};
- }
- g.setNode(k, node);
- for (const src of u[1]) {
- g.setEdge(src, k, {curve: d3.curveBasis})
- }
- if (additions.includes(parseInt(k))) {
- g.setParent(k, "addition");
- }
- }
- const svg = d3.select("#graph-svg");
- const inner = svg.select("g");
- var zoom = d3.zoom()
- .scaleExtent([0.05, 2])
- .on("zoom", () => {
- const transform = d3.event.transform;
- inner.attr("transform", transform);
- });
- recenterRects(svg, zoom);
- svg.call(zoom);
- const render = new dagreD3.render();
- render(inner, g);
- }
-
// **** extra helpers
const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`;
const vsCodeOpener = (parts) => Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n",
diff --git a/tinygrad/viz/lib/graph.js b/tinygrad/viz/lib/graph.js
new file mode 100644
index 0000000000..73a4c8d275
--- /dev/null
+++ b/tinygrad/viz/lib/graph.js
@@ -0,0 +1,48 @@
+function recenterRects(svg, zoom) {
+ const svgBounds = svg.node().getBoundingClientRect();
+ for (const rect of svg.node().querySelectorAll("rect")) {
+ const rectBounds = rect.getBoundingClientRect();
+ const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left ||
+ rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right;
+ // if there's at least one rect in view we don't do anything
+ if (!outOfBounds) return;
+ }
+ svg.call(zoom.transform, d3.zoomIdentity)
+}
+
+window.renderGraph = function(graph, additions) {
+ // ** initialize the graph
+ const g = new dagreD3.dagre.graphlib.Graph({ compound: true });
+ g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; });
+ g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"});
+ for (const [k,v] of Object.entries(graph)) {
+ const [label, src, color] = v;
+ const node = {label, labelType:"text", style:`fill: ${color};`};
+ // for PROGRAM UOp we render the node with a code block
+ if (label.includes("PROGRAM")) {
+ const [name, ...rest] = label.split("\n");
+ const labelEl = Object.assign(document.createElement("div"));
+ labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" }));
+ labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true));
+ node.label = labelEl;
+ node.labelType = "html";
+ }
+ g.setNode(k, node);
+ for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis});
+ if (additions.includes(parseInt(k))) g.setParent(k, "addition");
+ }
+
+ // ** select svg render
+ const svg = d3.select("#graph-svg");
+ const inner = svg.select("g");
+ const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => {
+ const transform = d3.event.transform;
+ inner.attr("transform", transform);
+ });
+ recenterRects(svg, zoom);
+ svg.call(zoom);
+
+ // ** calculate layout + render
+ const render = new dagreD3.render();
+ render(inner, g);
+}
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 32bf84c418..55d7ee970a 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -124,7 +124,7 @@ class Handler(BaseHTTPRequestHandler):
with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read()
elif (url:=urlparse(self.path)).path == "/profiler":
with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read()
- elif self.path.startswith("/assets/") and '/..' not in self.path:
+ elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path:
try:
with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read()
if url.path.endswith(".js"): content_type = "application/javascript"