simplify sqtt decoder infra (#13849)

* more work

* simpler
This commit is contained in:
qazal
2025-12-28 00:31:16 +09:00
committed by GitHub
parent ae013beab8
commit f6c660f7fa
2 changed files with 11 additions and 26 deletions

View File

@@ -48,17 +48,11 @@ class OccEvent(WaveSlot):
RunKey = tuple[str, int]
class _ROCParseCtx:
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
self.disasms:dict[str, dict[int, tuple[str, int]]] = {}
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]):
self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms
self.inst_execs:dict[RunKey, list[WaveExec]] = {}
self.occ_events:dict[RunKey, list[OccEvent]] = {}
for prog in prog_evs:
base = unwrap(prog.base)
target = unwrap(dev_evs[prog.device].props)['gfx_target_version']
self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(target, unwrap(prog.lib)).items()}
def next_sqtt(self):
x = next(self.sqtt_evs, None)
self.active_run = (x.kern, x.exec_tag) if x is not None else None
@@ -81,16 +75,8 @@ class _ROCParseCtx:
self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
ev.end_time, insts_blob))
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
dev_events:dict[str, ProfileDeviceEvent] = {}
sqtt_events:list[ProfileSQTTEvent] = []
prog_events:list[ProfileProgramEvent] = []
for e in profile:
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
if isinstance(e, ProfileSQTTEvent): sqtt_events.append(e)
if isinstance(e, ProfileProgramEvent) and e.device.startswith("AMD"): prog_events.append(e)
ROCParseCtx = _ROCParseCtx(dev_events, sqtt_events, prog_events)
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx:
ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms)
@rocprof.rocprof_trace_decoder_se_data_callback_t
def copy_cb(buf, buf_size, _):
@@ -150,7 +136,7 @@ if __name__ == "__main__":
args = parser.parse_args()
with args.profile.open("rb") as f: profile = pickle.load(f)
rctx = decode(profile)
print('SQTT:', rctx.inst_execs.keys())
#rctx = decode(profile, disasm)
#print('SQTT:', rctx.inst_execs.keys())
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)])