Factor the PMC counter dump out of roc.py's __main__ into a reusable
print_pmc(ev) helper (the Inst field is now left-aligned to width 2), and
have viz/serve.py load_sqtt():
  * index ProfilePMCEvent entries by (kern, exec_tag),
  * return early only when there are neither SQTT nor PMC events,
  * append the captured print_pmc() output under "PMC:" to the src of the
    "/counters" step when an event matches (name.prg, name.tag).

NOTE(review): this patch was recovered from a whitespace-collapsed copy.
Leading indentation and blank context lines are reconstructed assuming the
2-space style -- run `git apply --check` against the tree before applying.

diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py
index 00957f2c28..62167beb93 100644
--- a/extra/sqtt/roc.py
+++ b/extra/sqtt/roc.py
@@ -161,6 +161,15 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
   t.join()
   return ROCParseCtx
 
+def print_pmc(ev:ProfilePMCEvent) -> None:
+  ptr = 0
+  for s in ev.sched:
+    view = memoryview(ev.blob).cast('Q')
+    print(f"\t{s.name}")
+    for xcc, inst, se_idx, sa_idx, wgp_idx in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
+      print(f"\t\tXCC {xcc} Inst {inst:<2} SE {se_idx} SA {sa_idx} WGP {wgp_idx}: {view[ptr]:#x}")
+      ptr += 1
+
 if __name__ == "__main__":
   parser = argparse.ArgumentParser()
   parser.add_argument('--profile', type=pathlib.Path, help='Path to profile', default=pathlib.Path(temp("profile.pkl", append_user=True)))
@@ -173,10 +182,4 @@ if __name__ == "__main__":
   for ev in profile:
     if not isinstance(ev, ProfilePMCEvent): continue
     print(f"PMC Event: dev={ev.device} kern={ev.kern}")
-    ptr = 0
-    for s in ev.sched:
-      view = memoryview(ev.blob).cast('Q')
-      print(f"\t{s.name}")
-      for xcc, inst, se_idx, sa_idx, wgp_idx in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)):
-        print(f"\t\tXCC {xcc} Inst {inst} SE {se_idx} SA {sa_idx} WGP {wgp_idx}: {view[ptr]:#x}")
-        ptr += 1
+    print_pmc(ev)
diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py
index a810ce97ff..ad9f84525c 100755
--- a/tinygrad/viz/serve.py
+++ b/tinygrad/viz/serve.py
@@ -216,9 +216,10 @@ def err(name:str, msg:str|None=None) -> None:
 def row_tuple(row:str) -> tuple[int, ...]: return tuple(int(x.split(":")[1]) for x in row.split())
 
 def load_sqtt(profile:list[ProfileEvent]) -> None:
-  from tinygrad.runtime.ops_amd import ProfileSQTTEvent
-  if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]): return None
-  try: from extra.sqtt.roc import decode
+  from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
+  pmc_events = {(e.kern, e.exec_tag):e for e in profile if isinstance(e, ProfilePMCEvent)}
+  if not (sqtt_events:=[e for e in profile if isinstance(e, ProfileSQTTEvent)]) and not pmc_events: return
+  try: from extra.sqtt.roc import decode, print_pmc
   except Exception: return err("DECODER IMPORT ISSUE")
   try: rctx = decode(profile)
   except Exception: return err("DECODER ERROR")
@@ -253,6 +254,8 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
     prg_cu = sorted(cu_events, key=row_tuple)
     kernel = trace.keys[r].ret if (r:=ref_map.get(name.prg)) else None
     src = f"Scheduled on {len(prg_cu)} CUs"+(f"\n\n{kernel.global_size=} {kernel.local_size=}" if kernel else "")
+    pmc = pmc_events.get((name.prg, name.tag))
+    if pmc is not None: src += "\n\nPMC:\n"+get_stdout(lambda: print_pmc(pmc))
     steps.append(create_step(kernel.name if kernel is not None else name.prg, ("/counters", len(ctxs), len(steps)), {"src":src}, depth=1))
     for cu in prg_cu:
       events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]