From d7caae5f612f0daeb12f5b1107d620999dad13f2 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Fri, 5 Dec 2025 03:08:39 +0800 Subject: [PATCH] viz: tabulate pmc (#13574) * viz: tabulate pmc * linter * enable nesting * pmc comes before waves --- extra/sqtt/roc.py | 2 +- tinygrad/viz/js/index.js | 2 +- tinygrad/viz/serve.py | 19 +++++++++++++++---- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 0058ccab48..40ca8c04be 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -160,8 +160,8 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx: def print_pmc(ev:ProfilePMCEvent) -> None: ptr = 0 + view = memoryview(ev.blob).cast('Q') for s in ev.sched: - view = memoryview(ev.blob).cast('Q') print(f"\t{s.name}") for xcc, inst, se_idx, sa_idx, wgp_idx in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)): print(f"\t\tXCC {xcc} Inst {inst:<2} SE {se_idx} SA {sa_idx} WGP {wgp_idx}: {view[ptr]:#x}") diff --git a/tinygrad/viz/js/index.js b/tinygrad/viz/js/index.js index 1007abe5b6..8d387761af 100644 --- a/tinygrad/viz/js/index.js +++ b/tinygrad/viz/js/index.js @@ -784,7 +784,7 @@ async function main() { } const td = tr.append("td").classed(ret.cols[i], true); // string format scalar values - if (!Array.isArray(value)) { td.text(value); continue; } + if (!Array.isArray(value)) { td.text(typeof value === "string" ? value : formatUnit(value)); continue; } // display arrays in a bar graph td.classed("pct-row", true); const bar = td.append("div"); diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index 473b421e73..df614190ce 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -222,7 +222,7 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), []).append(e) if not counter_events: return # ** init decoder - try: from extra.sqtt.roc import decode, print_pmc + try: from extra.sqtt.roc import decode except Exception: return err("DECODER IMPORT ISSUE") try: rctx = decode(profile) except Exception: return err("DECODER ERROR") @@ -237,9 +237,6 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: # ** Run summary program = trace.keys[r].ret if (r:=ref_map.get(key[0])) else None summary = [f"{program.global_size=} {program.local_size=}"] if program else [repr(key)] - # ** PMC events - pmc_events = [e for e in counters if isinstance(e, ProfilePMCEvent)] - if pmc_events: summary.append("PMC:\n"+"\n".join([get_stdout(lambda: print_pmc(e)) for e in pmc_events])) # ** SQTT events disasm = rctx.disasms[key[0]] cu_events:dict[str, list[ProfileEvent]] = {} @@ -265,6 +262,20 @@ def load_sqtt(profile:list[ProfileEvent]) -> None: prg_cu = sorted(cu_events, key=row_tuple) if cu_events: summary.append(f"Scheduled on {len(prg_cu)} CUs") steps.append(create_step(program.name if program else key[0], ("/counters", len(ctxs), len(steps)), {"src":"\n\n".join(summary)}, depth=1)) + # ** PMC events + if (pmc_event:=next((e for e in counters if isinstance(e, ProfilePMCEvent)), None)) is not None: + agg_cols = ["Name", "Sum"] + sample_cols = ["XCC", "INST", "SE", "SA", "WGP", "Value"] + rows:list[list] = [] + view, ptr = memoryview(pmc_event.blob).cast('Q'), 0 + for s in pmc_event.sched: + row:list = [s.name, 0, {"cols":sample_cols, "rows":[]}] + for sample in itertools.product(range(s.xcc), range(s.inst), range(s.se), range(s.sa), range(s.wgp)): + row[1] += (val:=int(view[ptr])) + row[2]["rows"].append(sample+(val,)) + ptr += 1 + rows.append(row) + steps.append(create_step("PMC", ("/pmc", len(ctxs), len(steps)), {"rows":rows, "cols":agg_cols, "summary":[]}, depth=2)) for cu in prg_cu: events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu] steps.append(create_step(f"{cu} {len(cu_events[cu])}", ("/counters", len(ctxs), len(steps)),