viz: SQTT timeline with our decoder (#14139)

* viz: sqtt OCC/INST timeline in our decoder

* todo

* lint

* work

* cleaner

* profiling

* better timing

* keep the generic api

* more generic

* 80x -> 20x off the C decoder

* unusably slow

* rm filters

* work

* work

* other way to sort ops

* work

* first 10k

* 100K actually tells a story

* barrier INST packets get their own red color and row

* minor detail

* 50K

* soft_err
This commit is contained in:
qazal
2026-01-15 06:45:16 -05:00
committed by GitHub
parent 0cb024a5bb
commit 32e1c267ee
2 changed files with 26 additions and 9 deletions

View File

@@ -171,7 +171,9 @@ const formatUnit = (d, unit="") => d3.format(".3~s")(d)+unit;
const colorScheme = {TINY:new Map([["Schedule","#1b5745"],["get_program","#1d2e62"],["compile","#63b0cd"],["DEFAULT","#354f52"]]),
DEFAULT:["#2b2e39", "#2c2f3a", "#31343f", "#323544", "#2d303a", "#2e313c", "#343746", "#353847", "#3c4050", "#404459", "#444862", "#4a4e65"],
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),}
BUFFER:["#342483", "#3E2E94", "#4938A4", "#5442B4", "#5E4CC2", "#674FCA"], SIMD:new Map([["OCC", "#101725"], ["INST", "#0A2042"]]),
WAVE:new Map([["INST", "#e76f51"], ["VALUINST", "#415a77"], ["IMMEDIATE", "#f3b44a"], ["BARRIER", "#d00000"]]),
SHARED:new Map([["VMEMEXEC", "#f4978e"], ["ALUEXEC", "#f72585"]]),}
const cycleColors = (lst, i) => lst[i%lst.length];
const rescaleTrack = (source, tid, k) => {

View File

@@ -218,7 +218,7 @@ def soft_err(fn:Callable):
try: yield
except Exception: fn({"src":traceback.format_exc()})
def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(x.split(":")[1]) for x in row.split())
def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(ss[1]) if len(ss:=x.split(":"))>1 else 999 for x in row.split())
# *** Performance counters
@@ -271,15 +271,28 @@ def load_counters(profile:list[ProfileEvent]) -> None:
if (pmc:=v.get(ProfilePMCEvent)):
steps.append(create_step("PMC", ("/prg-pmc", len(ctxs), len(steps)), pmc))
all_counters[(name, run_number[k], k)] = pmc[0]
if (sqtt:=v.get(ProfileSQTTEvent)):
# to decode a SQTT trace, we need the raw stream, program binary and device properties
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
if getenv("SQTT_PARSE"):
# run our decoder on startup, we don't use this since it only works on gfx11
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
for e in sqtt: parse_sqtt_print_packets(e.blob)
# to decode a SQTT trace, we need the raw stream, program binary and device properties
if (sqtt:=v.get(ProfileSQTTEvent)): steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
ctxs.append({"name":f"Exec {name}"+(f" n{run_number[k]}" if run_number[k] > 1 else ""), "steps":steps})
def sqtt_timeline(e) -> list[ProfileEvent]:
from extra.assembly.amd.sqtt import decode, PacketType, INST, InstOp, VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC
ret:list[ProfileEvent] = []
rows:dict[str, None] = {}
def add(name:str, p:PacketType, op="OP", idx=0, width=5) -> None:
rows.setdefault(r:=(f"WAVE:{p.wave} {name}:1" if hasattr(p, "wave") else f"SHARED:0 {name}:0"))
ret.append(ProfileRangeEvent(r, f"{name} {op}:{idx}", Decimal(p._time), Decimal(p._time+width)))
op_idx:dict = {}
for p in decode(e.blob):
if len(ret) > 50_000: break
if isinstance(p, INST):
if p.op not in op_idx: op_idx[p.op] = len(op_idx)
op_name, idx = (p.op.name, op_idx[p.op]) if isinstance(p.op, InstOp) else (f"0x{p.op:02x}", len(op_idx))
if "BARRIER" in op_name: add("BARRIER", p, op_name, width=100)
else: add(p.__class__.__name__, p, op_name, idx)
if isinstance(p, (VALUINST, IMMEDIATE, VMEMEXEC, ALUEXEC)): add(p.__class__.__name__, p)
return [ProfilePointEvent(r, "start", r, ts=Decimal(0)) for r in rows]+ret
# ** SQTT OCC only unpacks wave start, end time and SIMD location
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
@@ -308,6 +321,8 @@ def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[
else:
if (events:=cu_events.get(occ.cu_loc)) is None: cu_events[occ.cu_loc] = events = []
events.append(ProfileRangeEvent(f"SIMD:{occ.simd}", f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)),Decimal(occ.time)))
# * INST timeline
with soft_err(lambda _:None): cu_events |= {f"SE:{e.se} Packets": timeline for e in data if (timeline := sqtt_timeline(e))}
return cu_events, list(units), wave_insts
def device_sort_fn(k:str) -> tuple[int, str, int]: