From 712c7a64482ec47c0c352b51cb1e05a66a2cfbcb Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:50:34 +0800 Subject: [PATCH] sqtt loader cleanups from the occupancy branch (#13431) * cleanup err handling * from disasms * s/wave_execs/wave_insts --- extra/sqtt/roc.py | 2 +- tinygrad/viz/serve.py | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 8157b4fa3b..e708d9b1ed 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -81,7 +81,7 @@ class _ROCParseCtx: ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz) self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time, - ev.end_time, insts_blob)) + ev.end_time, insts_blob)) def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: dev_events:dict[str, ProfileDeviceEvent] = {} diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 8a3995b839..c263990fe7 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -210,14 +210,14 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int, peaks.append(peak) return struct.pack(" None: + ctxs.append({"name":"ERR", "steps":[create_step(name, ("render",len(ctxs),0), {"src":msg or traceback.format_exc()})]}) + def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(x.split(":")[1]) for x in row.split()) def load_sqtt(profile:list[ProfileEvent]) -> None: from tinygrad.runtime.ops_amd import ProfileSQTTEvent if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None - def err(name:str, msg:str|None=None) -> None: - step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"} - return ctxs.append({"name":"Counters", "steps":[step]}) try: from extra.sqtt.roc import decode except Exception: return err("DECODER IMPORT ISSUE") try: rctx = decode(profile) @@ -227,22 +227,22 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: for e in sqtt_events: parse_sqtt_print_packets(e.blob) if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded") steps:list[dict] = [] - for name,waves in rctx.inst_execs.items(): - disasm = rctx.disasms[name] + for name,disasm in rctx.disasms.items(): units:dict[str, int] = {} events:list[ProfileEvent] = [] - wave_execs:dict[str, dict] = {} - for w in waves: + wave_insts:dict[str, dict] = {} + for w in rctx.inst_execs.get(name, []): if (row:=f"SE:{w.se} CU:{w.cu} SIMD:{w.simd} WAVE:{w.wave_id}") not in units: units[row] = 0 units[row] += 1 events.append(ProfileRangeEvent(row, f"N:{units[row]}", Decimal(w.begin_time), Decimal(w.end_time))) - wave_execs[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]} + wave_insts[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]} # gather and sort all wave execs of this kernel + if not events: continue events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)), {"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1)) - for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_execs[k], depth=2)) + for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2)) ctxs.append({"name":"Counters", "steps":steps}) def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None: