diff --git a/setup.py b/setup.py index 2fe0ed9ebe..5618d77c90 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ setup(name='tinygrad', packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.runtime.autogen.am', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer', 'tinygrad.engine', 'tinygrad.viz', 'tinygrad.runtime', 'tinygrad.runtime.support', 'tinygrad.runtime.support.am', 'tinygrad.runtime.graph', 'tinygrad.shape'], - package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*']}, + package_data = {'tinygrad': ['py.typed'], 'tinygrad.viz': ['index.html', 'perfetto.html', 'assets/**/*', 'lib/**/*']}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License" diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index d1b7181633..a177d66f10 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -6,6 +6,7 @@ + @@ -233,53 +234,6 @@ contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains] })); - // **** D3 - function recenterRects(svg, zoom) { - const svgBounds = svg.node().getBoundingClientRect(); - for (const rect of svg.node().querySelectorAll("rect")) { - const rectBounds = rect.getBoundingClientRect(); - const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left || - rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right; - // if there's at least one rect in view we don't do anything - if (!outOfBounds) return; - } - svg.call(zoom.transform, d3.zoomIdentity) - } - function renderGraph(graph, additions) { - const g = new dagreD3.graphlib.Graph({ compound: true }).setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); - g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"}); - for (const [k,u] of Object.entries(graph)) { - let node = {label: u[0], labelType: "text", style: `fill: ${u[2]};`}; - // for PROGRAM UOp we render the node with a code block - if (u[0].includes("PROGRAM")) { - const [name, ...rest] = u[0].split("\n"); - const label = Object.assign(document.createElement("div")); - label.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" })) - label.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true)); - node = {label, labelType: "html", style: `fill: ${u[2]}`}; - } - g.setNode(k, node); - for (const src of u[1]) { - g.setEdge(src, k, {curve: d3.curveBasis}) - } - if (additions.includes(parseInt(k))) { - g.setParent(k, "addition"); - } - } - const svg = d3.select("#graph-svg"); - const inner = svg.select("g"); - var zoom = d3.zoom() - .scaleExtent([0.05, 2]) - .on("zoom", () => { - const transform = d3.event.transform; - inner.attr("transform", transform); - }); - recenterRects(svg, zoom); - svg.call(zoom); - const render = new dagreD3.render(); - render(inner, g); - } - // **** extra helpers const toPath = ([fp, lineno]) => `${fp.replaceAll("\\", "/").split("/").pop()}:${lineno}`; const vsCodeOpener = (parts) => Object.assign(document.createElement("a"), { textContent: parts[parts.length-1]+"\n\n", diff --git a/tinygrad/viz/lib/graph.js b/tinygrad/viz/lib/graph.js new file mode 100644 index 0000000000..73a4c8d275 --- /dev/null +++ b/tinygrad/viz/lib/graph.js @@ -0,0 +1,48 @@ +function recenterRects(svg, zoom) { + const svgBounds = svg.node().getBoundingClientRect(); + for (const rect of svg.node().querySelectorAll("rect")) { + const rectBounds = rect.getBoundingClientRect(); + const outOfBounds = rectBounds.top < svgBounds.top || rectBounds.left < svgBounds.left || + rectBounds.bottom > svgBounds.bottom || rectBounds.right > svgBounds.right; + // if there's at least one rect in view we don't do anything + if (!outOfBounds) return; + } + svg.call(zoom.transform, d3.zoomIdentity) +} + +window.renderGraph = function(graph, additions) { + // ** initialize the graph + const g = new dagreD3.dagre.graphlib.Graph({ compound: true }); + g.setGraph({ rankdir: "LR" }).setDefaultEdgeLabel(function() { return {}; }); + g.setNode("addition", {label: "", clusterLabelPos: "top", style: additions.length !== 0 ? "fill: rgba(26, 27, 38, 0.5);" : "display: none;"}); + for (const [k,v] of Object.entries(graph)) { + const [label, src, color] = v; + const node = {label, labelType:"text", style:`fill: ${color};`}; + // for PROGRAM UOp we render the node with a code block + if (label.includes("PROGRAM")) { + const [name, ...rest] = label.split("\n"); + const labelEl = Object.assign(document.createElement("div")); + labelEl.appendChild(Object.assign(document.createElement("p"), {innerText: name, className: "label", style: "margin-bottom: 2px" })); + labelEl.appendChild(highlightedCodeBlock(rest.join("\n"), "cpp", true)); + node.label = labelEl; + node.labelType = "html"; + } + g.setNode(k, node); + for (const s of src) g.setEdge(s, k, {curve: d3.curveBasis}); + if (additions.includes(parseInt(k))) g.setParent(k, "addition"); + } + + // ** select svg render + const svg = d3.select("#graph-svg"); + const inner = svg.select("g"); + const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", () => { + const transform = d3.event.transform; + inner.attr("transform", transform); + }); + recenterRects(svg, zoom); + svg.call(zoom); + + // ** calculate layout + render + const render = new dagreD3.render(); + render(inner, g); +} diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 32bf84c418..55d7ee970a 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -124,7 +124,7 @@ class Handler(BaseHTTPRequestHandler): with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read() elif (url:=urlparse(self.path)).path == "/profiler": with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read() - elif self.path.startswith("/assets/") and '/..' not in self.path: + elif (self.path.startswith("/assets/") or self.path.startswith("/lib/")) and '/..' not in self.path: try: with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read() if url.path.endswith(".js"): content_type = "application/javascript"