viz: remove per schedule renderMemoryGraph (#11019)

replaced with per device Buffer viz https://github.com/tinygrad/tinygrad/pull/10960
This commit is contained in:
qazal
2025-06-28 22:09:38 +03:00
committed by GitHub
parent 4c8d2a0383
commit cb6a66ea84
3 changed files with 1 additions and 138 deletions

View File

@@ -470,7 +470,6 @@ def get_kernelize_map(sink:UOp) -> dict[UOp, UOp]:
# display the final graph
sched_sink = tensor_map[sink]
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Kernel Graph")
if getenv("VIZ"): graph_rewrite(sched_sink, PatternMatcher([]), name="View Memory Graph")
# verify Kernels match the spec
if __debug__: type_verify(list(sched_sink.toposort()), tensor_uop_spec)

View File

@@ -248,7 +248,6 @@
<g id="edges"></g>
<g id="nodes"></g>
<g id="edge-labels"></g> <!-- NOTE: this ensures edge labels are always on top -->
<g id="bars"></g>
</g>
<defs>
<marker id="arrowhead" viewBox="0 -5 10 10" refX="10" refY="0" markerWidth="6" markerHeight="6" orient="auto">

View File

@@ -37,7 +37,6 @@ async function renderDag(graph, additions, recenter=false) {
displayGraph("graph");
progressMessage.style.display = "none";
clearTimeout(timeout);
d3.select("#bars").html("");
const g = dagre.graphlib.json.read(e.data);
// draw nodes
const STROKE_WIDTH = 1.4;
@@ -96,136 +95,6 @@ async function renderDag(graph, additions, recenter=false) {
}
// ** Memory graph (WIP)
DTYPE_SIZE = {"bool": 1, "char": 1, "uchar": 1, "short": 2, "ushort": 2, "int": 4, "uint": 4,
"long": 8, "ulong": 8, "half": 2, "bfloat": 2, "float": 4, "double": 8}
function getBuffer(e) {
const [_, size, dtype, num, device] = e.label.split("\n");
return {nbytes:size*DTYPE_SIZE[dtype.split("dtypes.")[1]], dtype, device:device.split(" ")[1], num:parseInt(num.split(" ")[1])};
}
function pluralize(num, name, alt=null) {
return num === 1 ? `${num} ${name}` : `${num} ${alt ?? name+'s'}`
}
function renderMemoryGraph(graph) {
displayGraph("graph");
// ** construct alloc/free traces
// we can map reads/writes from the kernel graph
const actions = [];
const children = new Map(); // {buffer: [...assign]}
for (const [k,v] of Object.entries(graph)) {
if (!v.label.startsWith("ASSIGN")) continue;
actions.push({ op: "write", buffer: v.src[0] });
for (const ks of graph[v.src[1]].src) {
const node = graph[ks];
const s = node.label.startsWith("ASSIGN") ? node.src[0] : ks;
if (!children.has(s)) children.set(s, []);
children.get(s).push(v);
if (s !== v.src[0]) actions.push({ op: "read", buffer: s });
}
}
const prealloc = new Set();
const traces = [];
for (const a of actions) {
// a buffer is allocated immediately before the first write
// TODO: we don't know the buffer is preallocated if there's only an assign in the graph
if (a.op === "write") {
traces.push({ type: "alloc", buffer: a.buffer });
}
else {
if (traces.find(t => t.buffer === a.buffer && t.type === "alloc") == null) {
prealloc.add(a.buffer);
}
else if (a === actions.findLast(({ buffer }) => buffer === a.buffer)) {
traces.push({type: "free", buffer: a.buffer });
}
}
}
// ** get coordinates and layout for each buffer
const ret = {};
let timestep = 0; // x
let memUsed = 0; // y
for (const id of prealloc) {
const buf = getBuffer(graph[id]);
ret[id] = { x: [timestep], y: [memUsed], buf, id };
memUsed += buf.nbytes;
}
let peak = memUsed;
const liveBufs = [...prealloc];
for (const t of traces) {
const buf = getBuffer(graph[t.buffer]);
const idx = liveBufs.findLastIndex(b => t.buffer === b);
// alloc
if (idx === -1) {
liveBufs.push(t.buffer);
ret[t.buffer] = { x: [timestep], y: [memUsed], buf, id: t.buffer };
memUsed += buf.nbytes;
peak = Math.max(memUsed, peak);
timestep += 1;
} // free
else {
memUsed -= buf.nbytes;
timestep += 1;
const removed = ret[liveBufs.splice(idx, 1)[0]];
removed.x.push(timestep);
removed.y.push(removed.y.at(-1));
if (idx < liveBufs.length) {
for (let j=idx; j<liveBufs.length; j++) {
const b = ret[liveBufs[j]];
b.x.push(timestep, timestep);
b.y.push(b.y.at(-1), b.y.at(-1)-buf.nbytes);
}
}
}
}
for (const id of liveBufs) {
const b = ret[id];
b.x.push(timestep);
b.y.push(b.y.at(-1));
}
// ** render traces
// clear existing groups
document.querySelector(".progress-message").style.display = "none";
for (c of document.getElementById("render").children) c.innerHTML = "";
const render = d3.select("#bars");
const yscale = d3.scaleLinear().domain([0, peak]).range([576, 0]);
const xscale = d3.scaleLinear().domain([0, timestep]).range([0, 1024]);
const axesGroup = render.append("g").attr("id", "axes");
const nbytes_format = (d) => d3.format(".3~s")(d)+"B";
axesGroup.append("g").call(d3.axisLeft(yscale).tickFormat(nbytes_format));
axesGroup.append("g").attr("transform", `translate(0, ${yscale.range()[0]})`).call(d3.axisBottom(xscale).tickFormat(() => ""));
const polygonGroup = render.append("g").attr("id", "polygons");
const colors = ["7aa2f7", "ff9e64", "f7768e", "2ac3de", "7dcfff", "1abc9c", "9ece6a", "e0af68", "bb9af7", "9d7cd8", "ff007c"];
const polygons = polygonGroup.selectAll("polygon").data(Object.values(ret)).join("polygon").attr("points", (d) => {
const xs = d.x.map(t => xscale(t));
const y1 = d.y.map(t => yscale(t));
const y2 = d.y.map(t => yscale(t+d.buf.nbytes));
const p0 = xs.map((x, i) => `${x},${y1[i]}`);
const p1 = xs.map((x, i) => `${x},${y2[i]}`).reverse();
return `${p0.join(' ')} ${p1.join(' ')}`;
}).attr("fill", d => `#${colors[d.buf.num % colors.length]}`).on("mouseover", (e, { id, buf, x }) => {
d3.select(e.currentTarget).attr("stroke", "rgba(26, 27, 38, 0.8)").attr("stroke-width", 0.8);
const metadata = document.querySelector(".metadata");
document.getElementById("current-buf")?.remove();
const { num, dtype, nbytes, ...rest } = buf;
let label = `<BUFFER n${num} ${dtype} ${nbytes_format(nbytes)}>\nalive for ${pluralize(x[x.length-1]-x[0], 'timestep')}`;
label += '\n'+Object.entries(rest).map(([k, v]) => `${k}=${v}`).join('\n');
const buf_children = children.get(id);
if (buf_children) {
label += `\n${pluralize(buf_children.length, 'child', 'children')}\n`;
label += buf_children.map((c,i) => `[${i+1}] `+graph[c.src[1]].label.split("\n")[1]).join("\n");
}
metadata.appendChild(Object.assign(document.createElement("pre"), { innerText: label, id: "current-buf", className: "wrap" }));
}).on("mouseout", (e, _) => {
d3.select(e.currentTarget).attr("stroke", null).attr("stroke-width", null);
document.getElementById("current-buf")?.remove()
});
// TODO: add the kernel line here
document.getElementById("zoom-to-fit-btn").click();
}
const ANSI_COLORS = ["#b3b3b3", "#ff6666", "#66b366", "#ffff66", "#6666ff", "#ff66ff", "#66ffff", "#ffffff"];
const parseColors = (name, defaultColor="#ffffff") => [...name.matchAll(/(?:\u001b\[(\d+)m([\s\S]*?)\u001b\[0m)|([^\u001b]+)/g)]
.map(([_, code, colored_st, st]) => ({ st: colored_st ?? st, color: code != null ? ANSI_COLORS[(parseInt(code)-30+60)%60] : defaultColor }));
@@ -606,11 +475,7 @@ async function main() {
};
}
if (ret.length === 0) return;
if (step.name == "View Memory Graph") {
renderMemoryGraph(ret[currentRewrite].graph);
} else {
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
}
renderDag(ret[currentRewrite].graph, ret[currentRewrite].changed_nodes || [], recenter=currentRewrite === 0);
// ** right sidebar code blocks
const metadata = document.querySelector(".metadata");
const [code, lang] = ctx.kernel_code != null ? [ctx.kernel_code, "cpp"] : [ret[currentRewrite].uop, "python"];