diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 713c43b4bd..f5e0f3e6f9 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -170,19 +170,18 @@ function formatMicroseconds(ts, showUs=true) { const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; const WAVE_COLORS = {VALU:"#ffffc0", SALU:"#cef263", LOAD:"#ffc0c0", STORE:"#4fa3cc", IMMEDIATE:"#f3b44a", BARRIER:"#d00000", JUMP:"#ffb703", - JUMP_NO:"#fb8500", MESSAGE:"#90dbf4"}; + JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9"}; const waveColor = (op) => { - const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" + const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" : op.includes("VMEM") ? "VMEM" : op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op; ret = WAVE_COLORS[cat] ?? "#ffffff"; - // TODO: OTHER packets need to go on the second row - if (op.includes("OTHER_")) { ret = darkenHex(ret, 75) } + if (op.includes("OTHER_") || op.includes("_ALT")) { ret = darkenHex(ret, 75) } return ret }; const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]), DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"], BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]), - WAVE:waveColor, VMEMEXEC:["#f4978e"], ALUEXEC:["#f72585"]} + WAVE:waveColor, VMEMEXEC:waveColor, ALUEXEC:waveColor} const cycleColors = (lst, i) => lst[i%lst.length]; const rescaleTrack = (source, tid, k) => { @@ -800,7 +799,7 @@ async function main() { } // timeline with cycles on the x axis if (ret instanceof ArrayBuffer) { - opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), stepColors:!step.name.includes("Packets")}; + opts = {heightScale:0.5, hideLabels:true, levelKey:(e) => parseInt(e.name.split(" ")[1].split(":")[1]), stepColors:!step.name.includes("PKTS")}; return renderProfiler(ckey, "clk", opts); } metadata.innerHTML = ""; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 2d9af7d89c..48d1012b00 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -283,21 +283,17 @@ def sqtt_timeline(e) -> list[ProfileEvent]: from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC ret:list[ProfileEvent] = [] rows:dict[str, None] = {} - def add(name:str, p:PacketType, idx=0, width=1) -> None: + def add(name:str, p:PacketType, idx=0, width=1, op_name=None) -> None: rows.setdefault(r:=(f"WAVE:{p.wave}" if hasattr(p, "wave") else f"{p.__class__.__name__}:0 {name}")) - ret.append(ProfileRangeEvent(r, f"{name} OP:{idx}", Decimal(p._time), Decimal(p._time+width))) + ret.append(ProfileRangeEvent(r, f"{op_name if op_name is not None else name} OP:{idx}", Decimal(p._time), Decimal(p._time+width))) for p in decode(e.blob): if len(ret) > 50_000: break if isinstance(p, INST): op_name = p.op.name if isinstance(p.op, InstOp) else f"0x{p.op:02x}" - # skip OTHER_* packets - if "OTHER" in op_name: continue name, width = (op_name, 10 if "BARRIER" in op_name else 1) - # Wave SALU op, not to be confused with the global SALU - if op_name == "SALU": name = "WAVE_SALU" - add(name, p, width=width) + add(name, p, width=width, idx=int("OTHER" in name)) if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p) - if isinstance(p, (VMEMEXEC, ALUEXEC)): add(str(p.src).split('.')[1], p) + if isinstance(p, (VMEMEXEC, ALUEXEC)): add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, idx=int("ALT" in name), op_name=name) return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret # ** SQTT OCC only unpacks wave start, end time and SIMD location