From efaee756562547689eae370faaed9f8f2e9097c1 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Mon, 24 Mar 2025 19:05:35 +0800 Subject: [PATCH] start viz of memory usage (#9561) * start viz of memory usage * polygons/bars + use d3 --- tinygrad/engine/schedule.py | 1 + tinygrad/viz/index.html | 3 +- tinygrad/viz/lib/graph.js | 110 +++++++++++++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 8e5a278bd1..ada1aa6d8f 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -453,6 +453,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va # display the final graph if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph") + if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph") # final toposort (bfs) children: dict[UOp, list[UOp]] = {} diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index 94566990de..d4ebe19ccf 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -228,6 +228,7 @@ + @@ -377,7 +378,7 @@ }; } if (ret.length === 0) return; - renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || []); + renderGraph(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], kernel.name); // ***** RHS metadata const metadata = document.querySelector(".container.metadata"); metadata.innerHTML = ""; diff --git a/tinygrad/viz/lib/graph.js b/tinygrad/viz/lib/graph.js index 743244d2b1..0494e09c96 100644 --- a/tinygrad/viz/lib/graph.js +++ b/tinygrad/viz/lib/graph.js @@ -9,13 +9,18 @@ function intersectRect(r1, r2) { } const allWorkers = []; -window.renderGraph = function(graph, additions) { +window.renderGraph = function(graph, additions, name) { while (allWorkers.length) { const { worker, timeout } = allWorkers.pop(); worker.terminate(); clearTimeout(timeout); } + if (name === "View Memory Graph") { + return renderMemoryGraph(graph); + } + d3.select("#bars").html(""); + // ** start calculating the new layout (non-blocking) worker = new Worker("/lib/worker.js"); const progressMessage = document.querySelector(".progress-message"); @@ -58,3 +63,106 @@ window.renderGraph = function(graph, additions) { .attr("markerWidth", 6).attr("markerHeight", 6).attr("orient", "auto").append("path").attr("d", "M0,-5L10,0L0,5").attr("fill", "#4a4b57"); }; } + + +DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4, + "long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8} +function getBuffer(e) { + const [_, size, dtype, device, num] = e.label.split("\n"); + return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])}; +} + +function renderMemoryGraph(graph) { + // ** construct alloc/free traces + // we can map reads/writes from the kernel graph + const actions = []; + for (const [k,v] of Object.entries(graph)) { + if (!(v.label.startsWith("ASSIGN"))) continue; + actions.push({ op: "write", buffer: v.src[0] }); + for (const s of graph[v.src[1]].src) { + const snode = graph[s]; + const srcBuf = snode.label.startsWith("ASSIGN") ? snode.src[0] : s; + if (srcBuf !== v.src[0]) actions.push({ op: "read", buffer: srcBuf }); + } + } + const prealloc = new Set(); + const traces = []; + for (const a of actions) { + // a buffer is allocated immediately before the first write + // TODO: we don't know the buffer is preallocated if there's only an assign in the graph + if (a.op === "write") { + traces.push({ type: "alloc", buffer: a.buffer }); + } + else { + if (traces.find(t => t.buffer === a.buffer && t.type === "alloc") == null) { + prealloc.add(a.buffer); + } + else if (a === actions.findLast(({ buffer }) => buffer === a.buffer)) { + traces.push({type: "free", buffer: a.buffer }); + } + } + } + // ** get coordinates and layout for each buffer + const ret = {}; + let timestep = 0; // x + let memUsed = 0; // y + for (const id of prealloc) { + const buf = getBuffer(graph[id]); + ret[id] = { x: [timestep], y: [memUsed], buf }; + memUsed += buf.nbytes; + } + let peak = memUsed; + const liveBufs = [...prealloc]; + for (const t of traces) { + const buf = getBuffer(graph[t.buffer]); + const idx = liveBufs.findLastIndex(b => t.buffer === b); + // alloc + if (idx === -1) { + liveBufs.push(t.buffer); + ret[t.buffer] = { x: [timestep], y: [memUsed], buf }; + memUsed += buf.nbytes; + peak = Math.max(memUsed, peak); + timestep += 1; + } // free + else { + memUsed -= buf.nbytes; + timestep += 1; + const removed = ret[liveBufs.splice(idx, 1)[0]]; + removed.x.push(timestep); + removed.y.push(removed.y.at(-1)); + if (idx < liveBufs.length) { + for (let j=idx; j "")); + const polygonGroup = render.append("g").attr("id", "polygons"); + const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"]; + const polygons = polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => { + const xs = d.x.map(t => xscale(t)); + const y1 = d.y.map(t => yscale(t)); + const y2 = d.y.map(t => yscale(t+d.buf.nbytes)); + const p0 = xs.map((x, i) => `${x},${y1[i]}`); + const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse(); + return `${p0.join(' ')} ${p1.join(' ')}`; + }).attr("fill", d => `#${colors[d.buf.num % colors.length]}`); + // TODO: add the toposort graph here + d3.select("#nodes").html(""); + d3.select("#edges").html(""); +}