diff --git a/tinygrad/viz/index.html b/tinygrad/viz/index.html
index d678b4559e..d94eced17c 100644
--- a/tinygrad/viz/index.html
+++ b/tinygrad/viz/index.html
@@ -143,6 +143,30 @@
g.label rect.bg.highlight {
fill: #5f0059;
}
+ #insts span.highlight { /* tint for the focused instruction; setFocus toggles the "highlight" class on spans inside #insts */
+ background-color: rgba(0, 199, 47, 0.2);
+ }
+ #insts .line { /* clickable row: the #insts click handler resolves the target via closest(".line") */
+ display: flex;
+ flex-direction: column;
+ cursor: pointer;
+ margin-bottom: 8px;
+ }
+ #insts .left { /* horizontal group inside a row, children separated by an 8px gap */
+ display: flex;
+ gap: 8px;
+ }
+ #insts .n { /* presumably the running row index column; the markup applying this class is not visible here, confirm */
+ color: #787fa1;
+ min-width: 5ch;
+ }
+ #insts .wave { /* presumably the wave id column; confirm against the row template */
+ color: #7aa2f7;
+ min-width: 2ch;
+ }
+ #insts .pc { /* presumably the hex program counter column; confirm against the row template */
+ color: #73daca;
+ }
g.node.highlight rect.node, .edgePath.highlight, g.port circle {
stroke: #89C9A2;
}
diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js
index 77e0fb96fb..3fe2996466 100644
--- a/tinygrad/viz/js/index.js
+++ b/tinygrad/viz/js/index.js
@@ -234,7 +234,7 @@ const drawLine = (ctx, x, y, opts) => {
}
function tabulate(rows) {
- const root = d3.create("div").style("display", "grid").style("grid-template-columns", `${Math.max(...rows.map(x => x[0].length), 0)}ch 1fr`).style("gap", "0.2em");
+ const root = d3.create("div").style("display", "grid").style("grid-template-columns", `${Math.max(...rows.map(x => x[0].length), 0)}ch 1fr`).style("gap", "0.2em").style("white-space", "nowrap"); // two-column grid: first column sized to the longest key (in ch), second takes the rest; nowrap keeps each cell on a single line
for (const [k,v] of rows) { root.append("div").text(k); root.append("div").node().append(v); }
return root;
}
@@ -264,7 +264,8 @@ function setFocus(key) {
focusedShape = key; d3.select("#timeline").call(canvasZoom.transform, zoomLevel);
}
const { eventType, e } = selectShape(key);
- const html = d3.create("div").classed("info", true);
+ if (metadata.querySelector(".info") == null) d3.select(metadata).html("").append("div").classed("info", true);
+ const html = d3.select(".info").html("");
if (eventType === EventTypes.EXEC) {
const [n, _, ...rest] = e.arg.tooltipText.split("\n");
html.append(() => tabulate([["Name", d3.create("p").html(n).node()], ["Duration", formatTime(e.width)], ["Start Time", formatTime(e.x)]]).node());
@@ -298,7 +299,24 @@ function setFocus(key) {
if (shape != null) p.style("cursor", "pointer").on("click", () => setFocus(shape));
}
}
- return metadata.replaceChildren(html.node());
+ // instructions list renderer
+ let instList = document.getElementById("insts");
+ if (data.pcToShape.size > 0 && instList == null) {
+ let contents = "", i = 0;
+ for (const [k, v] of data.pcToShape) {
+ contents += `
${i++}${v.wave}
+ ${"0x"+v.pc.toString(16).padStart(12, "0")}${data.pcMap[v.pc]}
`;
+ }
+ instList = d3.create("pre").append("code").classed("hljs", true).style("margin-top", "20px").attr("id", "insts").html(contents)
+ .on("click", e => { const line = e.target.closest(".line"); line && setFocus(line.dataset.k); }).node();
+ metadata.insertBefore(instList.parentElement, html.node());
+ }
+ d3.select(instList).selectAll("span").classed("highlight", false);
+ const instLine = document.getElementById(`inst-${key}`); instLine?.classList.add("highlight");
+ if (instLine != null && instList != null) {
+ const r = rect(instLine), c = rect(instList);
+ if (Math.max(c.top-r.bottom, r.top-c.bottom)>=-30) instLine.scrollIntoView({ block:"center" });
+ }
}
const EventTypes = { EXEC:0, BUF:1 };
@@ -307,7 +325,7 @@ async function renderProfiler(path, unit, opts) {
displaySelection("#profiler");
// support non realtime x axis units
formatTime = unit === "realtime" ? formatMicroseconds : formatCycles;
- if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null}; focusedDevice = null; focusedShape = null; }
+ if (data?.path !== path) { data = {tracks:new Map(), axes:{}, path, first:null, pcToShape:new Map()}; focusedDevice = null; focusedShape = null; }
setFocus(focusedShape);
// layout once!
if (data.tracks.size !== 0) return updateProgress(Status.COMPLETE);
@@ -322,7 +340,7 @@ async function renderProfiler(path, unit, opts) {
const optional = (i) => i === 0 ? null : i-1;
const dur = u32(), tracePeak = u64(), indexLen = u32(), layoutsLen = u32(); data.dur = dur;
const textDecoder = new TextDecoder("utf-8");
- const { strings, dtypeSize, markers } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
+ const { strings, dtypeSize, markers, ...extData } = JSON.parse(textDecoder.decode(new Uint8Array(buf, offset, indexLen))); offset += indexLen;
// place devices on the y axis and set vertical positions
const [tickSize, padding, baseOffset] = [10, 8, markers.length ? 14 : 0];
const deviceList = profiler.append("div").attr("id", "device-list").style("padding-top", tickSize+padding+baseOffset+"px");
@@ -341,7 +359,8 @@ async function renderProfiler(path, unit, opts) {
const k = textDecoder.decode(new Uint8Array(buf, offset, nameLen)); offset += nameLen;
const div = deviceList.append("div").attr("id", k).text(k).style("padding", padding+"px").style("width", opts.width);
const { y:baseY, height:baseHeight } = rect(div.node());
- const colors = colorScheme[k.split(":")[0]] ?? colorScheme.DEFAULT;
+ const [dname, dnum] = k.split(":", 2);
+ const colors = colorScheme[dname] ?? colorScheme.DEFAULT;
const offsetY = baseY-canvasTop+padding/2;
const shapes = [], visible = [];
const eventType = u8(), eventsLen = u32();
@@ -398,8 +417,9 @@ async function renderProfiler(path, unit, opts) {
// tiny device events go straight to the rewrite rule
const key = k.startsWith("TINY") ? null : `${k}-${j}`;
const labelHTML = label.map(l=>`${l.st}`).join("");
- const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+(e.info != null ? "\n"+e.info : ""), bufs:[], key,
- ctx:shapeRef?.ctx, step:shapeRef?.step };
+ let info = e.info != null ? "\n"+e.info : "";
+ if (info.startsWith("\nPC:")) data.pcToShape.set(key, {wave:dnum, pc:parseInt(e.info.split(":")[1]), st:e.st}); info = "";
+ const arg = { tooltipText:labelHTML+" N:"+shapes.length+"\n"+formatTime(e.dur)+info, bufs:[], key, ctx:shapeRef?.ctx, step:shapeRef?.step };
if (e.key != null) shapeMap.set(e.key, key);
// offset y by depth
shapes.push({x:e.st, y:levelHeight*depth, width:e.dur, height:levelHeight, arg, label:opts.hideLabels ? null : label, fillColor });
@@ -492,6 +512,8 @@ async function renderProfiler(path, unit, opts) {
}
}
for (const m of markers) m.label = m.name.split(/(\s+)/).map(st => ({ st, color:m.color, width:ctx.measureText(st).width }));
+ data.pcToShape = new Map([...data.pcToShape].sort((a, b) => a[1].st - b[1].st)); // rebuild the Map with entries ordered by ascending st, so later iteration follows start order
+ if (extData.pcMap != null) data.pcMap = extData.pcMap; setFocus(focusedShape); // NOTE: only the pcMap assignment is guarded by the if; setFocus runs unconditionally
updateProgress(Status.COMPLETE);
// draw events on a timeline
const dpr = window.devicePixelRatio || 1;
@@ -837,6 +859,8 @@ async function main() {
// ** center graph
const { currentCtx, currentStep, currentRewrite, expandSteps } = state;
if (currentCtx == -1) return;
+ // always have a new sidebar when view changes
+ metadata.innerHTML = "";
const ctx = ctxs[currentCtx];
const step = ctx.steps[currentStep];
const ckey = step?.query;
@@ -868,7 +892,6 @@ async function main() {
opts = {heightScale:0.5, hideLabels:true, levelKey:step.name.includes("PKTS") ? (e) => parseInt(e.name.split(" ")[1].split(":")[1]) : null, colorByName:ckey.includes("pkts")};
return renderProfiler(ckey, "clk", opts);
}
- metadata.innerHTML = "";
ret.metadata?.forEach(m => {
if (Array.isArray(m)) return metadata.appendChild(tabulate(m.map(({ label, value }) => {
return [label.trim(), typeof value === "string" ? value : formatUnit(value)];
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index 4aac563d3b..4e44faa735 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -342,7 +342,7 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
def add(name:str, p:PacketType, idx=0, width=1, op_name=None, wave=None, info:InstructionInfo|None=None) -> None:
if hasattr(p, "wave"): wave = p.wave
rows.setdefault(r:=(f"WAVE:{wave}" if wave is not None else f"{p.__class__.__name__}:0 {name}"))
- key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=str(info.inst) if info is not None else None)
+ key = TracingKey(f"{op_name if op_name is not None else name} OP:{idx}", ret=f"PC:{info.pc}" if info is not None else None)  # ret is now "PC:<info.pc>" instead of the instruction text; presumably this is the "PC:" prefix the viz frontend checks for (confirm)
ret.append(ProfileRangeEvent(r, key, Decimal(p._time), Decimal(p._time+width)))
for p, info in map_insts(data, lib, target):
if len(ret) > getenv("MAX_SQTT_PKTS", 50_000): break
@@ -361,7 +361,8 @@ def sqtt_timeline(data:bytes, lib:bytes, target:str) -> list[ProfileEvent]:
add(name.replace("_ALT", ""), p, op_name=name)
if p._time in trace.setdefault(name, set()): raise AssertionError(f"packets overlap in shared resource! {name}")
trace[name].add(p._time)
- return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
+ pc_map = {addr:str(inst) for addr,inst in amd_decode(lib, target).items()}  # each key of amd_decode(lib, target) mapped to str() of its value
+ return [ProfilePointEvent(r, "JSON", "pcMap", pc_map, ts=Decimal(0)) for r in rows]+ret  # one "JSON" point event per row, each carrying the same pc_map under key "pcMap"; get_profile collects these into ext_data
# ** SQTT OCC only unpacks wave start, end time and SIMD location
@@ -415,6 +416,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
# map events per device
dev_events:dict[str, list[tuple[int, int, float, DevEvent]]] = {}
markers:list[ProfilePointEvent] = []
+ ext_data:dict[str, Any] = {}
start_ts:int|None = None
end_ts:int|None = None
for ts,en,e in flatten_events(profile):
@@ -422,6 +424,7 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
if start_ts is None or st < start_ts: start_ts = st
if end_ts is None or et > end_ts: end_ts = et
if isinstance(e, ProfilePointEvent) and e.name == "marker": markers.append(e)
+ if isinstance(e, ProfilePointEvent) and e.name == "JSON": ext_data[e.key] = e.arg
if start_ts is None: return None
# return layout of per device events
layout:dict[str, bytes|None] = {}
@@ -434,7 +437,8 @@ def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]=device_
layout[f"{k} Memory"] = mem_layout(v, start_ts, unwrap(end_ts), peaks, dtype_size, scache)
sorted_layout = sorted([k for k,v in layout.items() if v is not None], key=sort_fn)
ret = [b"".join([struct.pack("