From 32e1c267eee969cb34b4bab6de8f73d2c1434201 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Thu, 15 Jan 2026 06:45:16 -0500 Subject: [PATCH] viz: SQTT timeline with our decoder (#14139) * viz: sqtt OCC/INST timeline in our decoder * todo * lint * work * cleaner * profiling * better timing * keep the generic api * more generic * 80x -> 20x off the C decoder * unusably slow * rm filters * work * work * other way to sort ops * work * first 10k * 100K actually tells a story * barrier INST packets get their own red color and row * minor detail * 50K * soft_err --- tinygrad/viz/js/index.js | 4 +++- tinygrad/viz/serve.py | 31 +++++++++++++++++++++++-------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index bc0243804e..770adfb116 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -171,7 +171,9 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit; const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]), DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"], - BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),} + BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]), + WAVE:new Map([["INST", "#e76f51"], ["VALUINST", "#415a77"], ["IMMEDIATE", "#f3b44a"], ["BARRIER", "#d00000"]]), + SHARED:new Map([["VMEMEXEC", "#f4978e"], ["ALUEXEC", "#f72585"]]),} const cycleColors = (lst, i) => lst[i%lst.length]; const rescaleTrack = (source, tid, k) => { diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 6a64bfe321..a5d2d8efc2 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -218,7 +218,7 @@ def soft_err(fn:Callable): try: yield except Exception: fn({"src":traceback.format_exc()}) -def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(x.split(":")[1]) for x in row.split()) +def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(ss[1]) if len(ss:=x.split(":"))>1 else 999 for x in row.split()) # *** Performance counters @@ -271,15 +271,28 @@ def load_counters(profile:list[ProfileEvent]) -> None: if (pmc:=v.get(ProfilePMCEvent)): steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc)) all_counters[(name, run_number[k], k)] = pmc[0] - if (sqtt:=v.get(ProfileSQTTEvent)): - # to decode a SQTT trace, we need the raw stream, program binary and device properties - steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k]))) - if getenv("SQTT_PARSE"): - # run our decoder on startup, we don't use this since it only works on gfx11 - from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets - for e in sqtt: parse_sqtt_print_packets(e.blob) + # to decode a SQTT trace, we need the raw stream, program binary and device properties + if (sqtt:=v.get(ProfileSQTTEvent)): steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k]))) ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps}) +def sqtt_timeline(e) -> list[ProfileEvent]: + from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC + ret:list[ProfileEvent] = [] + rows:dict[str, None] = {} + def add(name:str, p:PacketType, op="OP", idx=0, width=5) -> None: + rows.setdefault(r:=(f"WAVE:{p.wave} {name}:1" if hasattr(p, "wave") else f"SHARED:0 {name}:0")) + ret.append(ProfileRangeEvent(r, f"{name} {op}:{idx}", Decimal(p._time), Decimal(p._time+width))) + op_idx:dict = {} + for p in decode(e.blob): + if len(ret) > 50_000: break + if isinstance(p, INST): + if p.op not in op_idx: op_idx[p.op] = len(op_idx) + op_name, idx = (p.op.name, op_idx[p.op]) if isinstance(p.op, InstOp) else (f"0x{p.op:02x}", len(op_idx)) + if "BARRIER" in op_name: add("BARRIER", p, op_name, width=100) + else: add(p.__class__.__name__, p, op_name, idx) + if isinstance(p, (VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC)): add(p.__class__.__name__, p) + return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret + # ** SQTT OCC only unpacks wave start, end time and SIMD location def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]: @@ -308,6 +321,8 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[ else: if (events:=cu_events.get(occ.cu_loc)) is None: cu_events[occ.cu_loc] = events = [] events.append(ProfileRangeEvent(f"SIMD:{occ.simd}", f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)),Decimal(occ.time))) + # * INST timeline + with soft_err(lambda _:None): cu_events |= {f"SE:{e.se} Packets": timeline for e in data if (timeline := sqtt_timeline(e))} return cu_events, list(units), wave_insts def device_sort_fn(k:str) -> tuple[int, str, int]: