viz: per cu timeline (#13451)

* add cu_loc

* work

* WAVE -> W
This commit is contained in:
qazal
2025-11-26 00:05:20 +08:00
committed by GitHub
parent 4a9562e353
commit 5520f1fb0b
2 changed files with 21 additions and 12 deletions

View File

@@ -44,9 +44,11 @@ class WaveSlot:
simd:int
se:int
@property
def simd_loc(self) -> str: return f"SE:{self.se} CU:{self.cu} SIMD:{self.simd}"
def cu_loc(self) -> str: return f"SE:{self.se} CU:{self.cu}"
@property
def wave_loc(self) -> str: return f"{self.simd_loc} WAVE:{self.wave_id}"
def simd_loc(self) -> str: return f"{self.cu_loc} SIMD:{self.simd}"
@property
def wave_loc(self) -> str: return f"{self.simd_loc} W:{self.wave_id}"
@dataclasses.dataclass(frozen=True)
class WaveExec(WaveSlot):

View File

@@ -228,15 +228,16 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
if not any([rctx.inst_execs, rctx.occ_events]): return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
steps:list[dict] = []
for name,disasm in rctx.disasms.items():
events:list[ProfileEvent] = []
cu_events:dict[str, list[ProfileEvent]] = {}
# wave instruction events
wave_insts:dict[str, dict] = {}
wave_insts:dict[str, dict[str, dict]] = {}
inst_units:dict[str, itertools.count] = {}
for w in rctx.inst_execs.get(name, []):
if (u:=w.wave_loc) not in inst_units: inst_units[u] = itertools.count(0)
n = next(inst_units[u])
if (events:=cu_events.get(w.cu_loc)) is None: cu_events[w.cu_loc] = events = []
events.append(ProfileRangeEvent(w.simd_loc, f"INST WAVE:{w.wave_id} N:{n}", Decimal(w.begin_time), Decimal(w.end_time)))
wave_insts[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n}
wave_insts.setdefault(w.cu_loc, {})[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n}
# occupancy events
units:dict[str, itertools.count] = {}
wave_start:dict[str, int] = {}
@@ -244,14 +245,20 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
if (u:=occ.wave_loc) not in units: units[u] = itertools.count(0)
if u in inst_units: continue
if occ.start: wave_start[u] = occ.time
else: events.append(ProfileRangeEvent(occ.simd_loc, f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)),Decimal(occ.time)))
if not events: continue
# gather and sort all sqtt events for this kernel
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events
else:
if (events:=cu_events.get(occ.cu_loc)) is None: cu_events[occ.cu_loc] = events = []
events.append(ProfileRangeEvent(occ.simd_loc, f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)), Decimal(occ.time)))
if not cu_events: continue
prg_cu = sorted(cu_events, key=row_tuple)
kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)),
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1))
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
src = f"Scheduled on {len(prg_cu)} CUs"+(f"\n\n{kernel.global_size=} {kernel.local_size=}" if kernel else "")
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)), {"src":src}, depth=1))
for cu in prg_cu:
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]
steps.append(create_step(cu, ("/counters", len(ctxs), len(steps)),
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=2))
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[cu][k], depth=3))
ctxs.append({"name":"Counters", "steps":steps})
def device_sort_fn(k):