From f6c660f7fac4b503968517111448e3ea84f8fbc5 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 28 Dec 2025 00:31:16 +0900 Subject: [PATCH] simplify sqtt decoder infra (#13849) * more work * simpler --- extra/sqtt/roc.py | 26 ++++++-------------------- tinygrad/viz/serve.py | 11 +++++------ 2 files changed, 11 insertions(+), 26 deletions(-) diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 4780a3036e..225456b073 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -48,17 +48,11 @@ class OccEvent(WaveSlot): RunKey = tuple[str, int] class _ROCParseCtx: - def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]): - self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs - self.disasms:dict[str, dict[int, tuple[str, int]]] = {} + def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]): + self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms self.inst_execs:dict[RunKey, list[WaveExec]] = {} self.occ_events:dict[RunKey, list[OccEvent]] = {} - for prog in prog_evs: - base = unwrap(prog.base) - target = unwrap(dev_evs[prog.device].props)['gfx_target_version'] - self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(target, unwrap(prog.lib)).items()} - def next_sqtt(self): x = next(self.sqtt_evs, None) self.active_run = (x.kern, x.exec_tag) if x is not None else None @@ -81,16 +75,8 @@ class _ROCParseCtx: self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time, ev.end_time, insts_blob)) -def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: - dev_events:dict[str, ProfileDeviceEvent] = {} - sqtt_events:list[ProfileSQTTEvent] = [] - prog_events:list[ProfileProgramEvent] = [] - for e in profile: - if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e - if isinstance(e, ProfileSQTTEvent): sqtt_events.append(e) - if isinstance(e, ProfileProgramEvent) and e.device.startswith("AMD"): prog_events.append(e) - - ROCParseCtx = _ROCParseCtx(dev_events, sqtt_events, prog_events) +def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx: + ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms) @rocprof.rocprof_trace_decoder_se_data_callback_t def copy_cb(buf, buf_size, _): @@ -150,7 +136,7 @@ if __name__ == "__main__": args = parser.parse_args() with args.profile.open("rb") as f: profile = pickle.load(f) - rctx = decode(profile) - print('SQTT:', rctx.inst_execs.keys()) + #rctx = decode(profile, disasm) + #print('SQTT:', rctx.inst_execs.keys()) print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)]) diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 8ba36ce661..f0c937dbf2 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -243,13 +243,11 @@ def load_counters(profile:list[ProfileEvent]) -> None: counter_events:dict[tuple[str, int], dict] = {} durations:dict[str, list[float]] = {} prg_events:dict[str, ProfileProgramEvent] = {} - dev_events:dict[str, ProfileDeviceEvent] = {} for e in profile: if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e) if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None: durations.setdefault(str(e.name), []).append(float(e.en-e.st)) if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e - if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e if len(counter_events) == 0: return None ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]}) run_number = {n:0 for n,_ in counter_events} @@ -263,7 +261,7 @@ def load_counters(profile:list[ProfileEvent]) -> None: all_counters[(name, run_number[k], k)] = pmc[0] if (sqtt:=v.get(ProfileSQTTEvent)): # to decode a SQTT trace, we need the raw stream, program binary and device properties - steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]]))) + steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k]))) if getenv("SQTT_PARSE"): # run our decoder on startup, we don't use this since it only works on gfx11 from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets @@ -272,11 +270,12 @@ def load_counters(profile:list[ProfileEvent]) -> None: # ** SQTT OCC only unpacks wave start, end time and SIMD location -def unpack_sqtt(key:tuple[str, int], profile:list[ProfileEvent]) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]: +def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]: # * init decoder from extra.sqtt.roc import decode - rctx = decode(profile) - disasm = rctx.disasms[key[0]] + base = unwrap(p.base) + disasm = {addr+base:inst_disasm for addr,inst_disasm in llvm_disasm(device_props[p.device]["gfx_target_version"], unwrap(p.lib)).items()} + rctx = decode(data, {p.name:disasm}) cu_events:dict[str, list[ProfileEvent]] = {} # * INST waves wave_insts:dict[str, dict[str, dict]] = {}