diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html index d678b4559e..d94eced17c 100644 --- a/tinygrad/viz/index.html +++ b/tinygrad/viz/index.html @@ -143,6 +143,30 @@ g.label rect.bg.highlight { fill: #5f0059; } + #insts span.highlight { + background-color: rgba(0, 199, 47, 0.2); + } + #insts .line { + display: flex; + flex-direction: column; + cursor: pointer; + margin-bottom: 8px; + } + #insts .left { + display: flex; + gap: 8px; + } + #insts .n { + color: #787fa1; + min-width: 5ch; + } + #insts .wave { + color: #7aa2f7; + min-width: 2ch; + } + #insts .pc { + color: #73daca; + } g.node.highlight rect.node, .edgePath.highlight, g.port circle { stroke: #89C9A2; } diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 77e0fb96fb..3fe2996466 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -234,7 +234,7 @@ const drawLine = (ctx, x, y, opts) => { } function tabulate(rows) { - const root = d3.create("div").style("display", "grid").style("grid-template-columns", `${Math.max(...rows.map(x => x[0].length), 0)}ch 1fr`).style("gap", "0.2em"); + const root = d3.create("div").style("display", "grid").style("grid-template-columns", `${Math.max(...rows.map(x => x[0].length), 0)}ch 1fr`).style("gap", "0.2em").style("white-space", "nowrap"); for (const [k,v] of rows) { root.append("div").text(k); root.append("div").node().append(v); } return root; } @@ -264,7 +264,8 @@ function setFocus(key) { focusedShape = key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel); } const { eventType, e } = selectShape(key); - const html = d3.create("div").classed("info", true); + if (metadata.querySelector(".info") == null) d3.select(metadata).html("").append("div").classed("info", true); + const html = d3.select(".info").html(""); if (eventType === EventTypes.EXEC) { const [n, _, ...rest] = e.arg.tooltipText.split("\n"); html.append(() => tabulate([["Name", d3.create("p").html(n).node()], ["Duration", formatTime(e.width)], ["Start Time", formatTime(e.x)]]).node()); @@ -298,7 +299,24 @@ function setFocus(key) { if (shape != null) p.style("cursor", "pointer").on("click", () => setFocus(shape)); } } - return metadata.replaceChildren(html.node()); + // instructions list renderer + let instList = document.getElementById("insts"); + if (data.pcToShape.size > 0 && instList == null) { + let contents = "", i = 0; + for (const [k, v] of data.pcToShape) { + contents += `
${i++}${v.wave} + ${"0x"+v.pc.toString(16).padStart(12, "0")}${data.pcMap[v.pc]}
`; + } + instList = d3.create("pre").append("code").classed("hljs", true).style("margin-top", "20px").attr("id", "insts").html(contents) + .on("click", e => { const line = e.target.closest(".line"); line && setFocus(line.dataset.k); }).node(); + metadata.insertBefore(instList.parentElement, html.node()); + } + d3.select(instList).selectAll("span").classed("highlight", false); + const instLine = document.getElementById(`inst-${key}`); instLine?.classList.add("highlight"); + if (instLine != null && instList != null) { + const r = rect(instLine), c = rect(instList); + if (Math.max(c.top-r.bottom, r.top-c.bottom)>=-30) instLine.scrollIntoView({ block:"center" }); + } } const EventTypes = { EXEC:0, BUF:1 }; @@ -307,7 +325,7 @@ async function renderProfiler(path, unit, opts) { displaySelection("#profiler"); // support non realtime x axis units formatTime = unit === "realtime" ? formatMicroseconds : formatCycles; - if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; } + if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null, pcToShape:new Map()}; focusedDevice = null; focusedShape = null; } setFocus(focusedShape); // layout once! if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE); @@ -322,7 +340,7 @@ async function renderProfiler(path, unit, opts) { const optional = (i) => i === 0 ? null : i-1; const dur = u32(), tracePeak = u64(), indexLen = u32(), layoutsLen = u32(); data.dur = dur; const textDecoder = new TextDecoder("utf-8"); - const { strings, dtypeSize, markers } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen; + const { strings, dtypeSize, markers, ...extData } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen; // place devices on the y axis and set vertical positions const [tickSize, padding, baseOffset] = [10, 8, markers.length ? 14 : 0]; const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+baseOffset+"px"); @@ -341,7 +359,8 @@ async function renderProfiler(path, unit, opts) { const k = textDecoder.decode(new Uint8Array(buf, offset, nameLen)); offset += nameLen; const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px").style("width", opts.width); const { y:baseY, height:baseHeight } = rect(div.node()); - const colors = colorScheme[k.split(":")[0]] ?? colorScheme.DEFAULT; + const [dname, dnum] = k.split(":", 2); + const colors = colorScheme[dname] ?? colorScheme.DEFAULT; const offsetY = baseY-canvasTop+padding/2; const shapes = [], visible = []; const eventType = u8(), eventsLen = u32(); @@ -398,8 +417,9 @@ async function renderProfiler(path, unit, opts) { // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; const labelHTML = label.map(l=>`${l.st}`).join(""); - const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key, - ctx:shapeRef?.ctx, step:shapeRef?.step }; + let info = e.info != null ? "\n"+e.info : ""; + if (info.startsWith("\nPC:")) data.pcToShape.set(key, {wave:dnum, pc:parseInt(e.info.split(":")[1]), st:e.st}); info = ""; + const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+info, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step }; if (e.key != null) shapeMap.set(e.key, key); // offset y by depth shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label:opts.hideLabels ? null : label, fillColor }); @@ -492,6 +512,8 @@ async function renderProfiler(path, unit, opts) { } } for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width })); + data.pcToShape = new Map([...data.pcToShape].sort((a, b) => a[1].st - b[1].st)); + if (extData.pcMap != null) data.pcMap = extData.pcMap; setFocus(focusedShape); updateProgress(Status.COMPLETE); // draw events on a timeline const dpr = window.devicePixelRatio || 1; @@ -837,6 +859,8 @@ async function main() { // ** center graph const { currentCtx, currentStep, currentRewrite, expandSteps } = state; if (currentCtx == -1) return; + // always have a new sidebar when view changes + metadata.innerHTML = ""; const ctx = ctxs[currentCtx]; const step = ctx.steps[currentStep]; const ckey = step?.query; @@ -868,7 +892,6 @@ async function main() { opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")}; return renderProfiler(ckey, "clk", opts); } - metadata.innerHTML = ""; ret.metadata?.forEach(m => { if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => { return [label.trim(), typeof value === "string" ? value : formatUnit(value)]; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 4aac563d3b..4e44faa735 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -342,7 +342,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None: if hasattr(p, "wave"): wave = p.wave rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}")) - key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=str(info.inst) if info is not None else None) + key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=f"PC:{info.pc}" if info is not None else None) ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width))) for p, info in map_insts(data, lib, target): if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break @@ -361,7 +361,8 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: add(name.replace("_ALT", ""), p, op_name=name) if p._time in trace.setdefault(name, set()): raise AssertionError(f"packets overlap in shared resource! {name}") trace[name].add(p._time) - return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret + pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()} + return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in rows]+ret # ** SQTT OCC only unpacks wave start, end time and SIMD location @@ -415,6 +416,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ # map events per device dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {} markers:list[ProfilePointEvent] = [] + ext_data:dict[str, Any] = {} start_ts:int|None = None end_ts:int|None = None for ts,en,e in flatten_events(profile): @@ -422,6 +424,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ if start_ts is None or st < start_ts: start_ts = st if end_ts is None or et > end_ts: end_ts = et if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e) + if isinstance(e, ProfilePointEvent) and e.name == "JSON": ext_data[e.key] = e.arg if start_ts is None: return None # return layout of per device events layout:dict[str, bytes|None] = {} @@ -434,7 +437,8 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_ layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache) sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn) ret = [b"".join([struct.pack("