simplify sqtt decoder infra (#13849)

* more work

* simpler
This commit is contained in:
qazal
2025-12-28 00:31:16 +09:00
committed by GitHub
parent ae013beab8
commit f6c660f7fa
2 changed files with 11 additions and 26 deletions

View File

@@ -48,17 +48,11 @@ class OccEvent(WaveSlot):
RunKey = tuple[str, int]
class _ROCParseCtx:
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
self.disasms:dict[str, dict[int, tuple[str, int]]] = {}
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]):
self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms
self.inst_execs:dict[RunKey, list[WaveExec]] = {}
self.occ_events:dict[RunKey, list[OccEvent]] = {}
for prog in prog_evs:
base = unwrap(prog.base)
target = unwrap(dev_evs[prog.device].props)['gfx_target_version']
self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(target, unwrap(prog.lib)).items()}
def next_sqtt(self):
x = next(self.sqtt_evs, None)
self.active_run = (x.kern, x.exec_tag) if x is not None else None
@@ -81,16 +75,8 @@ class _ROCParseCtx:
self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
ev.end_time, insts_blob))
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
dev_events:dict[str, ProfileDeviceEvent] = {}
sqtt_events:list[ProfileSQTTEvent] = []
prog_events:list[ProfileProgramEvent] = []
for e in profile:
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
if isinstance(e, ProfileSQTTEvent): sqtt_events.append(e)
if isinstance(e, ProfileProgramEvent) and e.device.startswith("AMD"): prog_events.append(e)
ROCParseCtx = _ROCParseCtx(dev_events, sqtt_events, prog_events)
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx:
ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms)
@rocprof.rocprof_trace_decoder_se_data_callback_t
def copy_cb(buf, buf_size, _):
@@ -150,7 +136,7 @@ if __name__ == "__main__":
args = parser.parse_args()
with args.profile.open("rb") as f: profile = pickle.load(f)
rctx = decode(profile)
print('SQTT:', rctx.inst_execs.keys())
#rctx = decode(profile, disasm)
#print('SQTT:', rctx.inst_execs.keys())
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)])

View File

@@ -243,13 +243,11 @@ def load_counters(profile:list[ProfileEvent]) -> None:
counter_events:dict[tuple[str, int], dict] = {}
durations:dict[str, list[float]] = {}
prg_events:dict[str, ProfileProgramEvent] = {}
dev_events:dict[str, ProfileDeviceEvent] = {}
for e in profile:
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
if len(counter_events) == 0: return None
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
run_number = {n:0 for n,_ in counter_events}
@@ -263,7 +261,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
all_counters[(name, run_number[k], k)] = pmc[0]
if (sqtt:=v.get(ProfileSQTTEvent)):
# to decode a SQTT trace, we need the raw stream, program binary and device properties
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
if getenv("SQTT_PARSE"):
# run our decoder on startup, we don't use this since it only works on gfx11
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
@@ -272,11 +270,12 @@ def load_counters(profile:list[ProfileEvent]) -> None:
# ** SQTT OCC only unpacks wave start, end time and SIMD location
def unpack_sqtt(key:tuple[str, int], profile:list[ProfileEvent]) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
# * init decoder
from extra.sqtt.roc import decode
rctx = decode(profile)
disasm = rctx.disasms[key[0]]
base = unwrap(p.base)
disasm = {addr+base:inst_disasm for addr,inst_disasm in llvm_disasm(device_props[p.device]["gfx_target_version"], unwrap(p.lib)).items()}
rctx = decode(data, {p.name:disasm})
cu_events:dict[str, list[ProfileEvent]] = {}
# * INST waves
wave_insts:dict[str, dict[str, dict]] = {}