mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
sqtt loader cleanups from the occupancy branch (#13431)
* cleanup err handling * from disasms * s/wave_execs/wave_insts
This commit is contained in:
@@ -81,7 +81,7 @@ class _ROCParseCtx:
|
||||
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
|
||||
|
||||
self.inst_execs.setdefault(unwrap(self.active_kern), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
|
||||
ev.end_time, insts_blob))
|
||||
ev.end_time, insts_blob))
|
||||
|
||||
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
||||
dev_events:dict[str, ProfileDeviceEvent] = {}
|
||||
|
||||
@@ -210,14 +210,14 @@ def mem_layout(dev_events:list[tuple[int, int, float, DevEvent]], start_ts:int,
|
||||
peaks.append(peak)
|
||||
return struct.pack("<BIQ", 1, len(events), peak)+b"".join(events) if events else None
|
||||
|
||||
def err(name:str, msg:str|None=None) -> None:
|
||||
ctxs.append({"name":"ERR", "steps":[create_step(name, ("render",len(ctxs),0), {"src":msg or traceback.format_exc()})]})
|
||||
|
||||
def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(x.split(":")[1]) for x in row.split())
|
||||
|
||||
def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
from tinygrad.runtime.ops_amd import ProfileSQTTEvent
|
||||
if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
|
||||
def err(name:str, msg:str|None=None) -> None:
|
||||
step = {"name":name, "data":{"src":msg or traceback.format_exc()}, "depth":0, "query":f"/render?ctx={len(ctxs)}&step=0&fmt=counters"}
|
||||
return ctxs.append({"name":"Counters", "steps":[step]})
|
||||
try: from extra.sqtt.roc import decode
|
||||
except Exception: return err("DECODER IMPORT ISSUE")
|
||||
try: rctx = decode(profile)
|
||||
@@ -227,22 +227,22 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
for e in sqtt_events: parse_sqtt_print_packets(e.blob)
|
||||
if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
steps:list[dict] = []
|
||||
for name,waves in rctx.inst_execs.items():
|
||||
disasm = rctx.disasms[name]
|
||||
for name,disasm in rctx.disasms.items():
|
||||
units:dict[str, int] = {}
|
||||
events:list[ProfileEvent] = []
|
||||
wave_execs:dict[str, dict] = {}
|
||||
for w in waves:
|
||||
wave_insts:dict[str, dict] = {}
|
||||
for w in rctx.inst_execs.get(name, []):
|
||||
if (row:=f"SE:{w.se} CU:{w.cu} SIMD:{w.simd} WAVE:{w.wave_id}") not in units: units[row] = 0
|
||||
units[row] += 1
|
||||
events.append(ProfileRangeEvent(row, f"N:{units[row]}", Decimal(w.begin_time), Decimal(w.end_time)))
|
||||
wave_execs[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]}
|
||||
wave_insts[f"{row} N:{units[row]}"] = {"wave":w, "disasm":disasm, "run_number":units[row]}
|
||||
# gather and sort all wave execs of this kernel
|
||||
if not events: continue
|
||||
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events
|
||||
kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None
|
||||
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)),
|
||||
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1))
|
||||
for k in sorted(wave_execs, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_execs[k], depth=2))
|
||||
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
|
||||
ctxs.append({"name":"Counters", "steps":steps})
|
||||
|
||||
def get_profile(profile:list[ProfileEvent], sort_fn:Callable[[str], Any]|None=None) -> bytes|None:
|
||||
|
||||
Reference in New Issue
Block a user