sqtt loader cleanups from the occupancy branch (#13431)

* cleanup err handling

* from disasms

* s/wave_execs/wave_insts
This commit is contained in:
qazal
2025-11-23 21:50:34 +08:00
committed by GitHub
parent 9d7a17ee39
commit 712c7a6448
2 changed files with 10 additions and 10 deletions

View File

@@ -81,7 +81,7 @@ class _ROCParseCtx:
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
ev.end_time, insts_blob))
ev.end_time, insts_blob))
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
dev_events:dict[str, ProfileDeviceEvent] = {}

View File

@@ -210,14 +210,14 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
peaks.append(peak)
return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
def err(name:str, msg:str|None=None) -> None:
  """Append an "ERR" context to `ctxs` containing a single step called `name`.

  The step's "src" payload is `msg` when it is truthy, otherwise the formatted traceback of the exception being handled.
  """
  src = msg or traceback.format_exc()
  step = create_step(name, ("render",len(ctxs),0), {"src":src})
  ctxs.append({"name":"ERR", "steps":[step]})
def row_tuple(row:str) -> tuple[int, ...]:
  """Convert a row label made of whitespace-separated `KEY:VALUE` fields into a tuple of the integer values, in order.

  For example "SE:1 CU:2 SIMD:0 WAVE:3" becomes (1, 2, 0, 3). For each field only the text after the first ":" (and
  before any second ":") is parsed as an int.
  """
  values:list[int] = []
  for field in row.split():
    values.append(int(field.split(":")[1]))
  return tuple(values)
def load_sqtt(profile:list[ProfileEvent]) -> None:
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
def err(name:str, msg:str|None=None) -> None:
  # report a loader failure as a one-step "Counters" context; when no message is given, show the active traceback
  src = msg or traceback.format_exc()
  query = f"/render?ctx={len(ctxs)}&step=0&fmt=counters"
  return ctxs.append({"name":"Counters", "steps":[{"name":name, "data":{"src":src}, "depth":0, "query":query}]})
try: from extra.sqtt.roc import decode
except Exception: return err("DECODER IMPORT ISSUE")
try: rctx = decode(profile)
@@ -227,22 +227,22 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
for e in sqtt_events: parse_sqtt_print_packets(e.blob)
if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
steps:list[dict] = []
for name,waves in rctx.inst_execs.items():
disasm = rctx.disasms[name]
for name,disasm in rctx.disasms.items():
units:dict[str, int] = {}
events:list[ProfileEvent] = []
wave_execs:dict[str, dict] = {}
for w in waves:
wave_insts:dict[str, dict] = {}
for w in rctx.inst_execs.get(name, []):
if (row:=f"SE:{w.se} CU:{w.cu} SIMD:{w.simd} WAVE:{w.wave_id}") not in units: units[row] = 0
units[row] += 1
events.append(ProfileRangeEvent(row, f"N:{units[row]}", Decimal(w.begin_time), Decimal(w.end_time)))
wave_execs[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]}
wave_insts[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]}
# gather and sort all wave execs of this kernel
if not events: continue
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events
kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)),
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1))
for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_execs[k], depth=2))
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
ctxs.append({"name":"Counters", "steps":steps})
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None: