diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 06fde491fd..333cdb94b3 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -16,6 +16,7 @@ const darkenHex = (h, p = 0) => const ANSI_COLORS = ["#b3b3b3", "#ff6666", "#66b366", "#ffff66", "#6666ff", "#ff66ff", "#66ffff", "#ffffff"]; const ANSI_COLORS_LIGHT = ["#d9d9d9","#ff9999","#99cc99","#ffff99","#9999ff","#ff99ff","#ccffff","#ffffff"]; +const colorsCache = new Map(); const parseColors = (name, defaultColor="#ffffff") => Array.from(name.matchAll(/(?:\u001b\[(\d+)m([\s\S]*?)\u001b\[0m)|([^\u001b]+)/g), ([_, code, colored_st, st]) => ({ st: colored_st ?? st, color: code != null ? (code>=90 ? ANSI_COLORS_LIGHT : ANSI_COLORS)[(parseInt(code)-30+60)%60] : defaultColor })); @@ -191,22 +192,58 @@ function tabulate(rows) { return root; } -var data, focusedDevice, focusedShape, canvasZoom, zoomLevel = d3.zoomIdentity, shapeMetadata = new Map(); +var data, focusedDevice, focusedShape, canvasZoom, formatTime, zoomLevel = d3.zoomIdentity; + +function getMetadata(shape) { + if (shape == null) return; + const [t, idx] = shape.split("-"); + const track = data.tracks.get(t); + if (track == null) return; + const e = track.shapes[idx]; + const html = d3.create("div").classed("info", true); + if (track.eventType === EventTypes.EXEC) { + html.append(() => tabulate([["Name", d3.create("p").html(e.arg.tooltipText.split("\n")[0]).node()], + ["Duration", formatTime(e.width)], ["Start Time", formatTime(e.x)]]).node()); + if (e.arg.ctx != null) { + html.append("a").text("View codegen rewrite").on("click", () => switchCtx(e.arg.ctx, e.arg.step)); + html.append("a").text("View program").on("click", () => switchCtx(e.arg.ctx, ctxs[e.arg.ctx+1].steps.findIndex(s => s.name==="View Program"))); + } + } + if (track.eventType === EventTypes.BUF) { + const [dtype, sz, nbytes, dur] = e.arg.tooltipText.split("\n"); + const rows = [["DType", dtype], ["Len", sz], ["Size", nbytes], ["Lifetime", dur]]; + if (e.arg.users != null) rows.push(["Users", e.arg.users.length]); + html.append(() => tabulate(rows).node()); + const kernels = html.append("div").classed("args", true); + for (let u=0; u colored(`[${u}] ${repr} ${bufInfo}`)); + const shapeTxt = shape?.tooltipText?.split("\n").at(-1); + if (shapeTxt != null) p.append("span").text(" "+shapeTxt); + if (shape != null) { + p.style("cursor", "pointer").on("click", () => focusShape(shape)); + } + } + } + return html.node(); +} + function focusShape(shape) { saveToHistory({ shape:focusedShape }); focusedShape = shape?.key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel); - return metadata.replaceChildren(shapeMetadata.get(focusedShape) ?? ""); + return metadata.replaceChildren(getMetadata(focusedShape) ?? ""); } const EventTypes = { EXEC:0, BUF:1 }; async function renderProfiler(path, unit, opts) { displaySelection("#profiler"); - metadata.replaceChildren(shapeMetadata.get(focusedShape) ?? ""); + // support non realtime x axis units + formatTime = unit === "realtime" ? formatMicroseconds : (s) => formatUnit(s, " "+unit); + metadata.replaceChildren(getMetadata(focusedShape) ?? ""); // layout once! if (data != null && data.path === path) return updateProgress({ start:false }); - // support non realtime x axis units - const formatTime = unit === "realtime" ? formatMicroseconds : (s) => formatUnit(s, " "+unit); const profiler = d3.select("#profiler").html(""); const buf = cache[path] ?? await fetchValue(path); const view = new DataView(buf); @@ -274,18 +311,10 @@ async function renderProfiler(path, unit, opts) { const stepIdx = ctxs[ref.ctx+1].steps.findIndex((s, i) => i >= start && s.name == e.name); if (stepIdx !== -1) { ref.step = stepIdx; shapeRef = ref; } } - const html = d3.create("div").classed("info", true); - html.append(() => tabulate([["Name", colored(e.name)], ["Duration", formatTime(e.dur)], ["Start Time", formatTime(e.st)]]).node()); - html.append("div").classed("args", true); - if (e.info != null) html.append("p").style("white-space", "pre-wrap").text(e.info); - if (shapeRef != null) { - html.append("a").text("View codegen rewrite").on("click", () => switchCtx(shapeRef.ctx, shapeRef.step)); - html.append("a").text("View program").on("click", () => switchCtx(shapeRef.ctx, ctxs[shapeRef.ctx+1].steps.findIndex(s => s.name==="View Program"))); - } // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; - if (key != null) shapeMetadata.set(key, html.node()); - const arg = { tooltipText:colored(label).outerHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, + const labelHTML = label.map(l=>`${l.st}`).join(""); + const arg = { tooltipText:labelHTML+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), key, ctx:shapeRef?.ctx, step:shapeRef?.step }; if (e.key != null) shapeMap.set(e.key, arg); // offset y by depth @@ -326,33 +355,7 @@ async function renderProfiler(path, unit, opts) { for (const [num, {dtype, sz, nbytes, y, x:steps, users}] of buf_shapes) { const x = steps.map(s => timestamps[s]); const dur = x.at(-1)-x[0]; - const html = d3.create("div").classed("info", true); - const rows = [["DType", dtype], ["Len", formatUnit(sz)], ["Size", formatUnit(nbytes, "B")], ["Lifetime", formatTime(dur)]]; - if (users != null) rows.push(["Users", users.length]); - const info = html.append(() => tabulate(rows).node()); - const arg = {tooltipText:info.node().outerHTML, key:`${k}-${num}`}; - const kernels = html.append("div").classed("args", true); - for (let u=0; u colored(`[${u}] ${repr} ${bufInfo}`)); - const shapeTxt = shape?.tooltipText?.split("\n").at(-1); - if (shapeTxt != null) p.append("span").text(" "+shapeTxt); - if (shape != null) { - p.style("cursor", "pointer").on("click", () => focusShape(shape)) - const args = shapeMetadata.get(shape.key).querySelector(".args"); - const bufArg = d3.create("p").text(`${bufInfo} ${rows[2][1]}`).style("cursor", "pointer").on("click", () => { - const device = document.getElementById(k); - if (!isExpanded(device)) device.click(); - focusShape(arg); - }).node(); - bufArg.dataset.num = num; - let before = null; - for (const c of args.children) { if (+c.dataset.num > num) { before = c; break; } } - args.insertBefore(bufArg, before); - } - } - shapeMetadata.set(arg.key, html.node()) + const arg = { tooltipText:`${dtype}\n${sz}\n${formatUnit(nbytes, 'B')}\n${formatTime(dur)}`, users, key:`${k}-${shapes.length}` }; shapes.push({ x, y0:y.map(yscale), y1:y.map(y0 => yscale(y0+nbytes)), arg, fillColor:cycleColors(colorScheme.BUFFER, shapes.length) }); } // generic polygon merger