From eee0dcc37ac5f109180bb6ab50d8abc2b51d8989 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 1 Apr 2025 19:52:02 +0800 Subject: [PATCH] merge viz back into one file (#9672) * merge viz back into one file * work * rename lib to js directory * fix diff * less indenting * memory graph is back * viz_sz.py --- extra/viz_sz.py | 9 + tinygrad/viz/index.html | 249 +---------------- tinygrad/viz/js/index.js | 432 +++++++++++++++++++++++++++++ tinygrad/viz/{lib => js}/worker.js | 0 tinygrad/viz/lib/graph.js | 191 ------------- tinygrad/viz/serve.py | 2 +- 6 files changed, 444 insertions(+), 439 deletions(-) create mode 100644 extra/viz_sz.py create mode 100644 tinygrad/viz/js/index.js rename tinygrad/viz/{lib => js}/worker.js (100%) delete mode 100644 tinygrad/viz/lib/graph.js diff --git a/extra/viz_sz.py b/extra/viz_sz.py new file mode 100644 index 0000000000..e02e73ba40 --- /dev/null +++ b/extra/viz_sz.py @@ -0,0 +1,9 @@ +files = ["./tinygrad/viz/js/index.js", "./tinygrad/viz/js/worker.js"] +for fp in files: + with open(fp) as f: content = f.read() + cnt = 0 + for i,line in enumerate(content.splitlines()): + if not (line:=line.strip()) or line.startswith("//"): continue + #print(f"{i} {line}") + cnt += 1 + print(f"{fp.split('/')[-1]} - {cnt} lines") diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 3b3a89f6d4..a38f561701 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -6,7 +6,6 @@ - @@ -210,250 +209,6 @@
- - + + diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js new file mode 100644 index 0000000000..238ac77a38 --- /dev/null +++ b/tinygrad/viz/js/index.js @@ -0,0 +1,432 @@ +// **** graph renderers + +// ** UOp graph + +function intersectRect(r1, r2) { + const dx = r2.x-r1.x; + const dy = r2.y-r1.y; + if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap"); + const scaleX = dx !== 0 ? (r1.width/2)/Math.abs(dx) : Infinity; + const scaleY = dy !== 0 ? (r1.height/2)/Math.abs(dy) : Infinity; + const scale = Math.min(scaleX, scaleY); + return {x:r1.x+dx*scale, y:r1.y+dy*scale}; +} + +let [workerUrl, worker, timeout] = [null, null, null]; +async function renderDag(graph, additions, recenter=false) { + // start calculating the new layout (non-blocking) + if (worker == null) { + const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/js/worker.js"].map(u => fetch(u))); + workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" })); + worker = new Worker(workerUrl); + } else { + worker.terminate(); + worker = new Worker(workerUrl); + } + if (timeout != null) clearTimeout(timeout); + const progressMessage = document.querySelector(".progress-message"); + timeout = setTimeout(() => {progressMessage.style.display = "block"}, 2000); + worker.postMessage({graph, additions}); + worker.onmessage = (e) => { + progressMessage.style.display = "none"; + clearTimeout(timeout); + d3.select("#bars").html(""); + const g = dagre.graphlib.json.read(e.data); + // draw nodes + const nodes = d3.select("#nodes").selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g") + .attr("transform", d => `translate(${d.x},${d.y})`); + nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color) + .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style); + nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => { + const x = (d.width-d.padding*2)/2; + const y = (d.height-d.padding*2)/2; + return `translate(-${x}, -${y})`; + }).selectAll("text").data(d => [d.label.split("\n")]).join("text").selectAll("tspan").data(d => d).join("tspan").text(d => d).attr("x", "1") + .attr("dy", 14).attr("xml:space", "preserve"); + // draw edges + const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis); + d3.select("#edges").selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => { + const edge = g.edge(e); + const points = edge.points.slice(1, edge.points.length-1); + points.unshift(intersectRect(g.node(e.v), points[0])); + points.push(intersectRect(g.node(e.w), points[points.length-1])); + return line(points); + }).attr("marker-end", "url(#arrowhead)"); + if (recenter) document.getElementById("zoom-to-fit-btn").click(); + }; + +} + +// ** Memory graph (WIP) + +DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4, + "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8} +function getBuffer(e) { + const [_, size, dtype, device, num] = e.label.split("\n"); + return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])}; +} + +function pluralize(num, name, alt=null) { + return num === 1 ? `${num} ${name}` : `${num} ${alt ?? name+'s'}` +} + +function renderMemoryGraph(graph) { + // ** construct alloc/free traces + // we can map reads/writes from the kernel graph + const actions = []; + const children = new Map(); // {buffer: [...assign]} + for (const [k,v] of Object.entries(graph)) { + if (!v.label.startsWith("ASSIGN")) continue; + actions.push({ op: "write", buffer: v.src[0] }); + for (const ks of graph[v.src[1]].src) { + const node = graph[ks]; + const s = node.label.startsWith("ASSIGN") ? node.src[0] : ks; + if (!children.has(s)) children.set(s, []); + children.get(s).push(v); + if (s !== v.src[0]) actions.push({ op: "read", buffer: s }); + } + } + const prealloc = new Set(); + const traces = []; + for (const a of actions) { + // a buffer is allocated immediately before the first write + // TODO: we don't know the buffer is preallocated if there's only an assign in the graph + if (a.op === "write") { + traces.push({ type: "alloc", buffer: a.buffer }); + } + else { + if (traces.find(t => t.buffer === a.buffer && t.type === "alloc") == null) { + prealloc.add(a.buffer); + } + else if (a === actions.findLast(({ buffer }) => buffer === a.buffer)) { + traces.push({type: "free", buffer: a.buffer }); + } + } + } + // ** get coordinates and layout for each buffer + const ret = {}; + let timestep = 0; // x + let memUsed = 0; // y + for (const id of prealloc) { + const buf = getBuffer(graph[id]); + ret[id] = { x: [timestep], y: [memUsed], buf, id }; + memUsed += buf.nbytes; + } + let peak = memUsed; + const liveBufs = [...prealloc]; + for (const t of traces) { + const buf = getBuffer(graph[t.buffer]); + const idx = liveBufs.findLastIndex(b => t.buffer === b); + // alloc + if (idx === -1) { + liveBufs.push(t.buffer); + ret[t.buffer] = { x: [timestep], y: [memUsed], buf, id: t.buffer }; + memUsed += buf.nbytes; + peak = Math.max(memUsed, peak); + timestep += 1; + } // free + else { + memUsed -= buf.nbytes; + timestep += 1; + const removed = ret[liveBufs.splice(idx, 1)[0]]; + removed.x.push(timestep); + removed.y.push(removed.y.at(-1)); + if (idx < liveBufs.length) { + for (let j=idx; j d3.format(".3~s")(d)+"B"; + axesGroup.append("g").call(d3.axisLeft(yscale).tickFormat(nbytes_format)); + axesGroup.append("g").attr("transform", `translate(0, ${yscale.range()[0]})`).call(d3.axisBottom(xscale).tickFormat(() => "")); + const polygonGroup = render.append("g").attr("id", "polygons"); + const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"]; + const polygons = polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => { + const xs = d.x.map(t => xscale(t)); + const y1 = d.y.map(t => yscale(t)); + const y2 = d.y.map(t => yscale(t+d.buf.nbytes)); + const p0 = xs.map((x, i) => `${x},${y1[i]}`); + const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse(); + return `${p0.join(' ')} ${p1.join(' ')}`; + }).attr("fill", d => `#${colors[d.buf.num % colors.length]}`).on("mouseover", (e, { id, buf, x }) => { + d3.select(e.currentTarget).attr("stroke", "rgba(26, 27, 38, 0.8)").attr("stroke-width", 0.8); + const metadata = document.querySelector(".container.metadata"); + document.getElementById("current-buf")?.remove(); + const { num, dtype, nbytes, ...rest } = buf; + let label = `\nalive for ${pluralize(x[x.length-1]-x[0], 'timestep')}`; + label += '\n'+Object.entries(rest).map(([k, v]) => `${k}=${v}`).join('\n'); + const buf_children = children.get(id); + if (buf_children) { + label += `\n${pluralize(buf_children.length, 'child', 'children')}\n`; + label += buf_children.map((c,i) => `[${i+1}] `+graph[c.src[1]].label.split("\n")[1]).join("\n"); + } + metadata.appendChild(Object.assign(document.createElement("pre"), { innerText: label, id: "current-buf", className: "wrap" })); + }).on("mouseout", (e, _) => { + d3.select(e.currentTarget).attr("stroke", null).attr("stroke-width", null); + document.getElementById("current-buf")?.remove() + }); + // TODO: add the toposort graph here + document.querySelector(".progress-message").style.display = "none"; + d3.select("#nodes").html(""); + d3.select("#edges").html(""); + document.getElementById("zoom-to-fit-btn").click(); +} + +// ** zoom and recentering + +const zoom = d3.zoom().scaleExtent([0.05, 2]).on("zoom", (e) => d3.select("#render").attr("transform", e.transform)); +d3.select("#graph-svg").call(zoom); +// zoom to fit into view +document.getElementById("zoom-to-fit-btn").addEventListener("click", () => { + const svg = d3.select("#graph-svg"); + svg.call(zoom.transform, d3.zoomIdentity); + const mainRect = document.querySelector(".main-container").getBoundingClientRect(); + const x0 = document.querySelector(".kernel-list-parent").getBoundingClientRect().right; + const x1 = document.querySelector(".metadata").getBoundingClientRect().left; + const pad = 16; + const R = { x: x0+pad, y: mainRect.top+pad, width: (x1>0 ? x1-x0 : mainRect.width)-2*pad, height: mainRect.height-2*pad }; + const r = document.querySelector("#render").getBoundingClientRect(); + if (r.width === 0) return; + const scale = Math.min(R.width/r.width, R.height/r.height); + const [tx, ty] = [R.x+(R.width-r.width*scale)/2, R.y+(R.height-r.height*scale)/2]; + svg.call(zoom.transform, d3.zoomIdentity.translate(tx, ty).scale(scale)); +}); + +// **** main VIZ interfacae + +function codeBlock(st, language, { loc, wrap }) { + const code = document.createElement("code"); + code.innerHTML = hljs.highlight(st, { language }).value; + code.className = "hljs"; + const ret = document.createElement("pre"); + if (wrap) ret.className = "wrap"; + if (loc != null) { + const link = ret.appendChild(document.createElement("a")); + link.href = "vscode://file"+loc.join(":"); + link.textContent = `${loc[0].split("/").at(-1)}:${loc[1]}`+"\n\n"; + } + ret.appendChild(code); + return ret; +} + +// ** hljs extra definitions for UOps and float4 +hljs.registerLanguage("python", (hljs) => ({ + ...hljs.getLanguage("python"), + case_insensitive: false, + contains: [ + { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-]*(\\.[a-zA-Z_][a-zA-Z0-9_-]*)*' + '(?=[.\\s\\n[:,(])', className: "type" }, + { begin: 'dtypes\\.[a-zA-Z_][a-zA-Z0-9_-].vec*' + '(?=[.\\s\\n[:,(])', className: "type" }, + { begin: '[a-zA-Z_][a-zA-Z0-9_-]*\\.[a-zA-Z_][a-zA-Z0-9_-]*' + '(?=[.\\s\\n[:,()])', className: "operator" }, + { begin: '[A-Z][a-zA-Z0-9_]*(?=\\()', className: "section", ignoreEnd: true }, + ...hljs.getLanguage("python").contains, + ] +})); +hljs.registerLanguage("cpp", (hljs) => ({ + ...hljs.getLanguage('cpp'), + contains: [{ begin: '\\b(?:float|half)[0-9]+\\b', className: 'type' }, ...hljs.getLanguage('cpp').contains] +})); + +var ret = []; +var cache = {}; +var kernels = null; +const evtSources = []; +const state = {currentKernel:-1, currentUOp:0, currentRewrite:0, expandKernel:false}; +function setState(ns) { + Object.assign(state, ns); + main(); +} +async function main() { + const { currentKernel, currentUOp, currentRewrite, expandKernel } = state; + // ** left sidebar kernel list + if (kernels == null) { + kernels = await (await fetch("/kernels")).json(); + setState({ currentKernel:-1 }); + } + const kernelList = document.querySelector(".kernel-list"); + kernelList.innerHTML = ""; + for (const [i,k] of kernels.entries()) { + const ul = kernelList.appendChild(document.createElement("ul")); + if (i === currentKernel) { + ul.className = "active"; + requestAnimationFrame(() => ul.scrollIntoView({ behavior: "auto", block: "nearest" })); + } + const p = ul.appendChild(document.createElement("p")); + p.innerHTML = k[0].replace(/\u001b\[(\d+)m(.*?)\u001b\[0m/g, (_, code, st) => { + const colors = ['gray','red','green','yellow','blue','magenta','cyan','white']; + return `${st}`; + }); + p.onclick = () => { + setState(i === currentKernel ? { expandKernel:!expandKernel } : { expandKernel:true, currentKernel:i, currentUOp:0, currentRewrite:0 }); + } + for (const [j,u] of k[1].entries()) { + const inner = ul.appendChild(document.createElement("ul")); + if (i === currentKernel && j === currentUOp) inner.className = "active"; + inner.innerText = `${u.name ?? u.loc[0].replaceAll("\\", "/").split("/").pop()+':'+u.loc[1]} - ${u.match_count}`; + inner.style.display = i === currentKernel && expandKernel ? "block" : "none"; + inner.onclick = (e) => { + e.stopPropagation(); + setState({ currentUOp:j, currentKernel:i, currentRewrite:0 }); + } + } + } + // ** center graph + if (currentKernel == -1) return; + const kernel = kernels[currentKernel][1][currentUOp]; + const cacheKey = `kernel=${currentKernel}&idx=${currentUOp}`; + // close any pending event sources + let activeSrc = null; + for (const e of evtSources) { + if (e.url.split("?")[1] !== cacheKey) e.close(); + else if (e.readyState === EventSource.OPEN) activeSrc = e; + } + if (cacheKey in cache) { + ret = cache[cacheKey]; + } + // if we don't have a complete cache yet we start streaming this kernel + if (!(cacheKey in cache) || (cache[cacheKey].length !== kernel.match_count+1 && activeSrc == null)) { + ret = []; + cache[cacheKey] = ret; + const eventSource = new EventSource(`/kernels?kernel=${currentKernel}&idx=${currentUOp}`); + evtSources.push(eventSource); + eventSource.onmessage = (e) => { + if (e.data === "END") return eventSource.close(); + const chunk = JSON.parse(e.data); + ret.push(chunk); + // if it's the first one render this new rgaph + if (ret.length === 1) return main(); + // otherwise just enable the graph selector + const ul = document.getElementById(`rewrite-${ret.length-1}`); + if (ul != null) ul.classList.remove("disabled"); + }; + } + if (ret.length === 0) return; + if (kernel.name == "View Memory Graph") { + renderMemoryGraph(ret[currentRewrite].graph); + } else { + renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0); + } + // ** right sidebar code blocks + const metadata = document.querySelector(".metadata"); + const [code, lang] = kernel.kernel_code != null ? [kernel.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"]; + metadata.replaceChildren(codeBlock(kernel.code_line, "python", { loc:kernel.loc, wrap:true }), codeBlock(code, lang, { wrap:false })); + appendResizer(metadata, { minWidth: 20, maxWidth: 50 }); + // ** rewrite steps + if (kernel.match_count >= 1) { + const rewriteList = metadata.appendChild(document.createElement("div")); + rewriteList.className = "rewrite-list"; + for (let s=0; s<=kernel.match_count; s++) { + const ul = rewriteList.appendChild(document.createElement("ul")); + ul.innerText = s; + ul.id = `rewrite-${s}`; + ul.onclick = () => setState({ currentRewrite:s }); + ul.className = s > ret.length-1 ? "disabled" : s === currentRewrite ? "active" : ""; + if (s > 0 && s === currentRewrite) { + const { upat, diff } = ret[s]; + metadata.appendChild(codeBlock(upat[1], "python", { loc:upat[0], wrap:true })); + const diffCode = metadata.appendChild(document.createElement("pre")); + diffCode.innerHTML = ``+diff.map((line) => { + const color = line.startsWith("+") ? "#3aa56d" : line.startsWith("-") ? "#d14b4b" : "#f0f0f5"; + return `${line}`; + }).join("
")+`
`; + diffCode.className = "wrap"; + } + } + } +} + +// **** collapse/expand + +let isCollapsed = false; +const mainContainer = document.querySelector('.main-container'); +document.querySelector(".collapse-btn").addEventListener("click", (e) => { + isCollapsed = !isCollapsed; + mainContainer.classList.toggle("collapsed", isCollapsed); + e.currentTarget.blur(); + e.currentTarget.style.transform = isCollapsed ? "rotate(180deg)" : "rotate(0deg)"; +}); + +// **** resizer + +function appendResizer(element, { minWidth, maxWidth }, left=false) { + const handle = Object.assign(document.createElement("div"), { className: "resize-handle", style: left ? "right: 0" : "left: 0; margin-top: 0" }); + element.appendChild(handle); + const resize = (e) => { + const change = e.clientX - element.dataset.startX; + let newWidth = ((Number(element.dataset.startWidth)+(left ? change : -change))/Number(element.dataset.containerWidth))*100; + element.style.width = `${Math.max(minWidth, Math.min(maxWidth, newWidth))}%`; + }; + handle.addEventListener("mousedown", (e) => { + e.preventDefault(); + element.dataset.startX = e.clientX; + element.dataset.containerWidth = mainContainer.getBoundingClientRect().width; + element.dataset.startWidth = element.getBoundingClientRect().width; + document.documentElement.addEventListener("mousemove", resize, false); + document.documentElement.addEventListener("mouseup", () => { + document.documentElement.removeEventListener("mousemove", resize, false); + element.style.userSelect = "initial"; + }, { once: true }); + }); +} +appendResizer(document.querySelector(".kernel-list-parent"), { minWidth: 15, maxWidth: 50 }, left=true); + +// **** keyboard shortcuts + +document.addEventListener("keydown", async function(event) { + const { currentKernel, currentUOp, currentRewrite, expandKernel } = state; + // up and down change the UOp or kernel from the list + if (!expandKernel) { + if (event.key == "ArrowUp") { + event.preventDefault() + return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.max(0, currentKernel-1) }); + } + if (event.key == "ArrowDown") { + event.preventDefault() + return setState({ currentUOp:0, currentRewrite:0, currentKernel:Math.min(kernels.length-1, currentKernel+1) }); + } + } + if (event.key == "Enter") { + event.preventDefault() + if (state.currentKernel === -1) { + return setState({ currentKernel:0, expandKernel:true }); + } + return setState({ currentUOp:0, currentRewrite:0, expandKernel:!expandKernel }); + } + if (event.key == "ArrowUp") { + event.preventDefault() + return setState({ currentRewrite:0, currentUOp:Math.max(0, currentUOp-1) }); + } + if (event.key == "ArrowDown") { + event.preventDefault() + const totalUOps = kernels[currentKernel][1].length-1; + return setState({ currentRewrite:0, currentUOp:Math.min(totalUOps, currentUOp+1) }); + } + // left and right go through rewrites in a single UOp + if (event.key == "ArrowLeft") { + event.preventDefault() + return setState({ currentRewrite:Math.max(0, currentRewrite-1) }); + } + if (event.key == "ArrowRight") { + event.preventDefault() + const totalRewrites = ret.length-1; + return setState({ currentRewrite:Math.min(totalRewrites, currentRewrite+1) }); + } + if (event.key == " ") { + document.getElementById("zoom-to-fit-btn").click(); + } +}); + +main() diff --git a/tinygrad/viz/lib/worker.js b/tinygrad/viz/js/worker.js similarity index 100% rename from tinygrad/viz/lib/worker.js rename to tinygrad/viz/js/worker.js diff --git a/tinygrad/viz/lib/graph.js b/tinygrad/viz/lib/graph.js deleted file mode 100644 index 48fcab9fbc..0000000000 --- a/tinygrad/viz/lib/graph.js +++ /dev/null @@ -1,191 +0,0 @@ -function intersectRect(r1, r2) { - const dx = r2.x-r1.x; - const dy = r2.y-r1.y; - if (dx === 0 && dy === 0) throw new Error("Invalid node coordinates, rects must not overlap"); - const scaleX = dx !== 0 ? (r1.width/2)/Math.abs(dx) : Infinity; - const scaleY = dy !== 0 ? (r1.height/2)/Math.abs(dy) : Infinity; - const scale = Math.min(scaleX, scaleY); - return {x:r1.x+dx*scale, y:r1.y+dy*scale}; -} - -let [workerUrl, worker, timeout] = [null, null, null]; -window.renderGraph = async function(graph, additions, name, recenter=false) { - if (name === "View Memory Graph") { - return renderMemoryGraph(graph); - } - d3.select("#bars").html(""); - - // ** start calculating the new layout (non-blocking) - if (worker == null) { - const resp = await Promise.all(["/assets/dagrejs.github.io/project/dagre/latest/dagre.min.js","/lib/worker.js"].map(u => fetch(u))); - workerUrl = URL.createObjectURL(new Blob([(await Promise.all(resp.map((r) => r.text()))).join("\n")], { type: "application/javascript" })); - worker = new Worker(workerUrl); - } else { - worker.terminate(); - worker = new Worker(workerUrl); - } - if (timeout != null) clearTimeout(timeout); - const progressMessage = document.querySelector(".progress-message"); - timeout = setTimeout(() => { - progressMessage.style.display = "block"; - }, 2000); - worker.postMessage({graph, additions}); - - worker.onmessage = (e) => { - progressMessage.style.display = "none"; - clearTimeout(timeout); - const g = dagre.graphlib.json.read(e.data); - // ** draw nodes - const nodeRender = d3.select("#nodes"); - const nodes = nodeRender.selectAll("g").data(g.nodes().map(id => g.node(id)), d => d).join("g") - .attr("transform", d => `translate(${d.x},${d.y})`); - nodes.selectAll("rect").data(d => [d]).join("rect").attr("width", d => d.width).attr("height", d => d.height).attr("fill", d => d.color) - .attr("x", d => -d.width/2).attr("y", d => -d.height/2).attr("style", d => d.style); - // +labels - nodes.selectAll("g.label").data(d => [d]).join("g").attr("class", "label").attr("transform", d => { - const x = (d.width-d.padding*2)/2; - const y = (d.height-d.padding*2)/2; - return `translate(-${x}, -${y})`; - }).selectAll("text").data(d => [d.label.split("\n")]).join("text").selectAll("tspan").data(d => d).join("tspan").text(d => d).attr("x", "1") - .attr("dy", 14).attr("xml:space", "preserve"); - - // ** draw edges - const line = d3.line().x(d => d.x).y(d => d.y).curve(d3.curveBasis); - const edgeRender = d3.select("#edges"); - edgeRender.selectAll("path.edgePath").data(g.edges()).join("path").attr("class", "edgePath").attr("d", (e) => { - const edge = g.edge(e); - const points = edge.points.slice(1, edge.points.length-1); - points.unshift(intersectRect(g.node(e.v), points[0])); - points.push(intersectRect(g.node(e.w), points[points.length-1])); - return line(points); - }).attr("marker-end", "url(#arrowhead)"); - if (recenter) document.getElementById("zoom-to-fit-btn").click(); - }; -} - - -DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4, - "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8} -function getBuffer(e) { - const [_, size, dtype, device, num] = e.label.split("\n"); - return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])}; -} -function pluralize(num, name, alt=null) { - return num === 1 ? `${num} ${name}` : `${num} ${alt ?? name+'s'}` -} - -function renderMemoryGraph(graph) { - // ** construct alloc/free traces - // we can map reads/writes from the kernel graph - const actions = []; - const children = new Map(); // {buffer: [...assign]} - for (const [k,v] of Object.entries(graph)) { - if (!v.label.startsWith("ASSIGN")) continue; - actions.push({ op: "write", buffer: v.src[0] }); - for (const ks of graph[v.src[1]].src) { - const node = graph[ks]; - const s = node.label.startsWith("ASSIGN") ? node.src[0] : ks; - if (!children.has(s)) children.set(s, []); - children.get(s).push(v); - if (s !== v.src[0]) actions.push({ op: "read", buffer: s }); - } - } - const prealloc = new Set(); - const traces = []; - for (const a of actions) { - // a buffer is allocated immediately before the first write - // TODO: we don't know the buffer is preallocated if there's only an assign in the graph - if (a.op === "write") { - traces.push({ type: "alloc", buffer: a.buffer }); - } - else { - if (traces.find(t => t.buffer === a.buffer && t.type === "alloc") == null) { - prealloc.add(a.buffer); - } - else if (a === actions.findLast(({ buffer }) => buffer === a.buffer)) { - traces.push({type: "free", buffer: a.buffer }); - } - } - } - // ** get coordinates and layout for each buffer - const ret = {}; - let timestep = 0; // x - let memUsed = 0; // y - for (const id of prealloc) { - const buf = getBuffer(graph[id]); - ret[id] = { x: [timestep], y: [memUsed], buf, id }; - memUsed += buf.nbytes; - } - let peak = memUsed; - const liveBufs = [...prealloc]; - for (const t of traces) { - const buf = getBuffer(graph[t.buffer]); - const idx = liveBufs.findLastIndex(b => t.buffer === b); - // alloc - if (idx === -1) { - liveBufs.push(t.buffer); - ret[t.buffer] = { x: [timestep], y: [memUsed], buf, id: t.buffer }; - memUsed += buf.nbytes; - peak = Math.max(memUsed, peak); - timestep += 1; - } // free - else { - memUsed -= buf.nbytes; - timestep += 1; - const removed = ret[liveBufs.splice(idx, 1)[0]]; - removed.x.push(timestep); - removed.y.push(removed.y.at(-1)); - if (idx < liveBufs.length) { - for (let j=idx; j d3.format(".3~s")(d)+"B"; - axesGroup.append("g").call(d3.axisLeft(yscale).tickFormat(nbytes_format)); - axesGroup.append("g").attr("transform", `translate(0, ${yscale.range()[0]})`).call(d3.axisBottom(xscale).tickFormat(() => "")); - const polygonGroup = render.append("g").attr("id", "polygons"); - const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"]; - const polygons = polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => { - const xs = d.x.map(t => xscale(t)); - const y1 = d.y.map(t => yscale(t)); - const y2 = d.y.map(t => yscale(t+d.buf.nbytes)); - const p0 = xs.map((x, i) => `${x},${y1[i]}`); - const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse(); - return `${p0.join(' ')} ${p1.join(' ')}`; - }).attr("fill", d => `#${colors[d.buf.num % colors.length]}`).on("mouseover", (e, { id, buf, x }) => { - d3.select(e.currentTarget).attr("stroke", "rgba(26, 27, 38, 0.8)").attr("stroke-width", 0.8); - const metadata = document.querySelector(".container.metadata"); - document.getElementById("current-buf")?.remove(); - const { num, dtype, nbytes, ...rest } = buf; - let label = `\nalive for ${pluralize(x[x.length-1]-x[0], 'timestep')}`; - label += '\n'+Object.entries(rest).map(([k, v]) => `${k}=${v}`).join('\n'); - const buf_children = children.get(id); - if (buf_children) { - label += `\n${pluralize(buf_children.length, 'child', 'children')}\n`; - label += buf_children.map((c,i) => `[${i+1}] `+graph[c.src[1]].label.split("\n")[1]).join("\n"); - } - metadata.appendChild(Object.assign(document.createElement("pre"), { innerText: label, id: "current-buf", className: "wrap" })); - }).on("mouseout", (e, _) => { - d3.select(e.currentTarget).attr("stroke", null).attr("stroke-width", null); - document.getElementById("current-buf")?.remove() - }); - // TODO: add the toposort graph here - document.querySelector(".progress-message").style.display = "none"; - d3.select("#nodes").html(""); - d3.select("#edges").html(""); -} diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 025fc924f3..0b858f4d34 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -123,7 +123,7 @@ class Handler(BaseHTTPRequestHandler): with open(os.path.join(os.path.dirname(__file__), "index.html"), "rb") as f: ret = f.read() elif (url:=urlparse(self.path)).path == "/profiler": with open(os.path.join(os.path.dirname(__file__), "perfetto.html"), "rb") as f: ret = f.read() - elif self.path.startswith(("/assets/", "/lib/")) and '/..' not in self.path: + elif self.path.startswith(("/assets/", "/js/")) and '/..' not in self.path: try: with open(os.path.join(os.path.dirname(__file__), self.path.strip('/')), "rb") as f: ret = f.read() if url.path.endswith(".js"): content_type = "application/javascript"