diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 2510ae17b6..14553c64fe 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -25,6 +25,13 @@ const colored = n => d3.create("span").call(s => s.selectAll("span").data(typeof const rect = (s) => (typeof s === "string" ? document.querySelector(s) : s).getBoundingClientRect(); +// dims of shapes on the canvas aren't tracked by the browser, we compute it +const canvasRect = (s, pixelScale) => { + const { e } = selectShape(s), t = data.tracks.get(s.split("-")[0]); + const x = pixelScale(e.x), w = pixelScale(e.x+e.width)-x, y = t.offsetY+e.y; + return {x0:x, x1:x+w, y0:y, y1:y+e.height}; +}; + let timeout = null; const Status = {STARTED:0, COMPLETE:1, ERR:2} const updateProgress = (st, msg) => { @@ -292,6 +299,8 @@ function setFocus(key) { const [st, et] = xscale.range().map(zoomLevel.invertX, zoomLevel).map(xscale.invert, xscale); if (x1 < st || x0 > et) zoomLevel = d3.zoomIdentity.translate(-xscale((x0+x1)/2-(et-st)/2)*zoomLevel.k, 0).scale(zoomLevel.k); } + const link = e?.arg.link ?? data.links.get(key); + data.link = link == null ? null : [key, link]; focusedShape = key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel); } const { eventType, e } = selectShape(key); @@ -352,8 +361,10 @@ function setFocus(key) { metadata.insertBefore(instList.parentElement, html.node()); } d3.select(instList).selectAll("span").classed("highlight", false); - const instLine = document.getElementById(`inst-${e?.arg.pc}`); instLine?.classList.add("highlight"); + let instLine = document.getElementById(`inst-${e?.arg.pc}`); + if (instLine == null && data.link != null) instLine = document.getElementById(`inst-${selectShape(data.link[1]).e.arg.pc}`); if (instLine != null) { + instLine.classList.add("highlight"); const r = rect(instLine), c = rect(instList); if (Math.max(c.top-r.bottom, r.top-c.bottom)>=-30) instList.scrollTop = instLine.offsetTop-instList.clientHeight/2+instLine.clientHeight/2; } @@ -366,7 +377,7 @@ async function renderProfiler(path, opts) { displaySelection("#profiler"); // support non realtime x axis units formatTime = opts.unit === "ms" ? formatMicroseconds : formatCycles; - if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; } + if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null, links:new Map()}; focusedDevice = null; focusedShape = null; } setFocus(focusedShape); // layout once! if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE); @@ -455,10 +466,11 @@ async function renderProfiler(path, opts) { } // tiny device events go straight to the rewrite rule const key = k.startsWith("TINY") ? null : `${k}-${j}`; - let info = e.info != null ? "\n"+e.info : "", trace = null, pc = null + let info = e.info != null ? "\n"+e.info : "", trace = null, pc = null, link = null if (info.startsWith("\nPC:")) { pc = parseInt(e.info.split(":")[1]); info = ""; } if (info.startsWith("\nTB:")) { trace = info; info = ""; } - const arg = { tooltipText:" N:"+shapes.length+"\n"+formatTime(e.dur)+info, label, pc, trace, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step }; + if (info.startsWith("\nLINK:")) { link = info.replace("\nLINK:", ""); info = ""; data.links.set(link, key); } + const arg = { tooltipText:" N:"+shapes.length+"\n"+formatTime(e.dur)+info, label, pc, trace, link, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step }; if (e.key != null) shapeMap.set(e.key, key); // offset y by depth shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label:opts.hideLabels ? null : label, fillColor }); @@ -625,7 +637,9 @@ async function renderProfiler(path, opts) { // add label drawText(ctx, e.label, x+2, y+e.height/2, width); } - if (focusedShape != null && e.arg?.key === focusedShape) { ctx.strokeStyle = pcolor; ctx.stroke(); } + if ((focusedShape != null && e.arg?.key === focusedShape) || (data.link != null && (e.arg?.key === data.link[0] || e.arg?.key === data.link[1]))) { + ctx.strokeStyle = pcolor; ctx.stroke(); + } } // draw row line if (rowBorderColor != null) { @@ -633,6 +647,15 @@ async function renderProfiler(path, opts) { drawLine(ctx, [0, canvasWidth], [y, y], { color:rowBorderColor }); } } + // draw the link + if (data.link != null) { + const [a, b] = [canvasRect(data.link[0], xscale), canvasRect(data.link[1], xscale)]; + const [left, right] = a.x0 <= b.x0 ? [a, b] : [b, a]; + const startX = left.x1, endX = right.x0; + const leftY = (left.y0+left.y1)/2, rightY = (right.y0+right.y1)/2; + const dx = endX-startX, bend = Math.max(12, Math.min(40, dx/2)); + ctx.beginPath(); ctx.moveTo(startX, leftY); ctx.bezierCurveTo(startX+bend, leftY, endX-bend, rightY, endX, rightY); ctx.strokeStyle = "#858b9d"; ctx.stroke(); + } // draw axes ctx.translate(0, baseOffset); const y = secondaryTick != null ? tickSize+padding : 0; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index d16d542812..9069993b90 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -343,10 +343,14 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: WAVEEND, CDNA_WAVEEND, WAVERDY) ret:list[ProfileEvent] = [] row_ends:dict[str, Decimal] = {} + row_counts:dict[str, itertools.count] = {} curr_barrier:dict[str, ProfileRangeEvent] = {} + exec_pending:dict[str, list[str]] = {} NS_PER_TICK = 10 # 100MHz prev_pair:tuple[int, int]|None = None # (shader, realtime) is_cdna = target.startswith("gfx9") + dispatch_to_exec = {"WMMA":"VALU", "VALU":"VALU", "VALUINST":"VALU", "VINTERP":"VALU", "GLOBAL":"VMEM", "FLAT":"VMEM", "LDS":"LDS", "SALU":"SALU", + "SMEM":"SALU", "VMEM":"VMEM"} def add(name:str, p:PacketType, op:str|None=None, wave:int|None=None, info:InstructionInfo|None=None) -> None: row = f"WAVE:{wave}" if (wave:=getattr(p, "wave", wave)) is not None else f"{p.__class__.__name__}:0 {name}" # barrier on this row extends to fill the time our wave was waiting @@ -355,7 +359,12 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]: # allow CDNA packets to overlap, NOT allowed on RDNA. if (et:=row_ends.get(row)) is not None and e.st < et and not is_cdna: raise RuntimeError(f"packet {p} overlaps another packet in {row}.") row_ends[row] = unwrap(e.en) + idx = next(row_counts.setdefault(row, itertools.count(0))) if name == "BARRIER": curr_barrier[row] = e + # queue for exec linking + if isinstance(p, (VALUINST, INST, INST_RDNA4)) and (exec_type:=dispatch_to_exec.get(name.split("_")[0])) is not None: + exec_pending.setdefault(exec_type, []).append(f"{row}-{idx}") + if isinstance(p, (ALUEXEC, VMEMEXEC)) and "ALT" not in str(p.src): e.name = TracingKey(op or name, ret=f"LINK:{exec_pending[name].pop(0)}") for p, info in map_insts(data, lib, target): if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break if isinstance(p, (TS_DELTA_OR_MARK, TS_DELTA_OR_MARK_RDNA4)) and p.is_marker: