From a8ae9757dd6bfc67bb1c64dec90c0ed285f6ff1d Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 16 Jan 2026 19:36:14 -0500 Subject: [PATCH] viz: put alts in the same row, LDS color (#14194) * viz: put alts in the same row, coloring work * assert if packets overlap * lds color --- tinygrad/viz/js/index.js | 2 +- tinygrad/viz/serve.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index fa7e208795..06b9fd4605 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -180,7 +180,7 @@ function formatCycles(cycles) { const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; const WAVE_COLORS = {VALU:"#ffffc0", SALU:"#cef263", LOAD:"#ffc0c0", STORE:"#4fa3cc", IMMEDIATE:"#f3b44a", BARRIER:"#d00000", JUMP:"#ffb703", - JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9"}; + JUMP_NO:"#fb8500", MESSAGE:"#90dbf4", VMEM:"#b2b7c9", LDS:"#9fb4a6"}; const waveColor = (op) => { const cat = op.includes("VALU") || op === "VINTERP" ? "VALU" : op.includes("SALU") ? "SALU" : op.includes("VMEM") ? "VMEM" : op.includes("LOAD") || op === "SMEM" ? "LOAD" : op.includes("STORE") ? "STORE" : op; diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 48d1012b00..9e1f283ba9 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -283,6 +283,7 @@ def sqtt_timeline(e) -> list[ProfileEvent]: from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC ret:list[ProfileEvent] = [] rows:dict[str, None] = {} + trace:dict[str, set[int]] = {} def add(name:str, p:PacketType, idx=0, width=1, op_name=None) -> None: rows.setdefault(r:=(f"WAVE:{p.wave}" if hasattr(p, "wave") else f"{p.__class__.__name__}:0 {name}")) ret.append(ProfileRangeEvent(r, f"{op_name if op_name is not None else name} OP:{idx}", Decimal(p._time), Decimal(p._time+width))) @@ -293,7 +294,10 @@ def sqtt_timeline(e) -> list[ProfileEvent]: name, width = (op_name, 10 if "BARRIER" in op_name else 1) add(name, p, width=width, idx=int("OTHER" in name)) if isinstance(p, (VALUINST, IMMEDIATE)): add(p.__class__.__name__, p) - if isinstance(p, (VMEMEXEC, ALUEXEC)): add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, idx=int("ALT" in name), op_name=name) + if isinstance(p, (VMEMEXEC, ALUEXEC)): + add((name:=str(p.src).split('.')[1]).replace("_ALT", ""), p, op_name=name) + if p._time in trace.setdefault(name, set()): raise AssertionError(f"packets overlap in shared resource! {name}") + trace[name].add(p._time) return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret # ** SQTT OCC only unpacks wave start, end time and SIMD location