mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
@@ -44,9 +44,11 @@ class WaveSlot:
|
||||
simd:int
|
||||
se:int
|
||||
@property
|
||||
def simd_loc(self) -> str: return f"SE:{self.se} CU:{self.cu} SIMD:{self.simd}"
|
||||
def cu_loc(self) -> str: return f"SE:{self.se} CU:{self.cu}"
|
||||
@property
|
||||
def wave_loc(self) -> str: return f"{self.simd_loc} WAVE:{self.wave_id}"
|
||||
def simd_loc(self) -> str: return f"{self.cu_loc} SIMD:{self.simd}"
|
||||
@property
|
||||
def wave_loc(self) -> str: return f"{self.simd_loc} W:{self.wave_id}"
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class WaveExec(WaveSlot):
|
||||
|
||||
@@ -228,15 +228,16 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
if not any([rctx.inst_execs, rctx.occ_events]): return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
steps:list[dict] = []
|
||||
for name,disasm in rctx.disasms.items():
|
||||
events:list[ProfileEvent] = []
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
# wave instruction events
|
||||
wave_insts:dict[str, dict] = {}
|
||||
wave_insts:dict[str, dict[str, dict]] = {}
|
||||
inst_units:dict[str, itertools.count] = {}
|
||||
for w in rctx.inst_execs.get(name, []):
|
||||
if (u:=w.wave_loc) not in inst_units: inst_units[u] = itertools.count(0)
|
||||
n = next(inst_units[u])
|
||||
if (events:=cu_events.get(w.cu_loc)) is None: cu_events[w.cu_loc] = events = []
|
||||
events.append(ProfileRangeEvent(w.simd_loc, f"INST WAVE:{w.wave_id} N:{n}", Decimal(w.begin_time), Decimal(w.end_time)))
|
||||
wave_insts[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n}
|
||||
wave_insts.setdefault(w.cu_loc, {})[f"{u} N:{n}"] = {"wave":w, "disasm":disasm, "run_number":n}
|
||||
# occupancy events
|
||||
units:dict[str, itertools.count] = {}
|
||||
wave_start:dict[str, int] = {}
|
||||
@@ -244,14 +245,20 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
if (u:=occ.wave_loc) not in units: units[u] = itertools.count(0)
|
||||
if u in inst_units: continue
|
||||
if occ.start: wave_start[u] = occ.time
|
||||
else: events.append(ProfileRangeEvent(occ.simd_loc, f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)),Decimal(occ.time)))
|
||||
if not events: continue
|
||||
# gather and sort all sqtt events for this kernel
|
||||
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+events
|
||||
else:
|
||||
if (events:=cu_events.get(occ.cu_loc)) is None: cu_events[occ.cu_loc] = events = []
|
||||
events.append(ProfileRangeEvent(occ.simd_loc, f"OCC WAVE:{occ.wave_id} N:{next(units[u])}", Decimal(wave_start.pop(u)), Decimal(occ.time)))
|
||||
if not cu_events: continue
|
||||
prg_cu = sorted(cu_events, key=row_tuple)
|
||||
kernel = trace.keys[r].ret if (r:=ref_map.get(name)) else None
|
||||
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)),
|
||||
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=1))
|
||||
for k in sorted(wave_insts, key=row_tuple): steps.append(create_step(k, ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[k], depth=2))
|
||||
src = f"Scheduled on {len(prg_cu)} CUs"+(f"\n\n{kernel.global_size=} {kernel.local_size=}" if kernel else "")
|
||||
steps.append(create_step(kernel.name if kernel is not None else name, ("/counters", len(ctxs), len(steps)), {"src":src}, depth=1))
|
||||
for cu in prg_cu:
|
||||
events = [ProfilePointEvent(unit, "start", unit, ts=Decimal(0)) for unit in units]+cu_events[cu]
|
||||
steps.append(create_step(cu, ("/counters", len(ctxs), len(steps)),
|
||||
{"value":get_profile(events, sort_fn=row_tuple), "content_type":"application/octet-stream"}, depth=2))
|
||||
for k in sorted(wave_insts.get(cu, []), key=row_tuple):
|
||||
steps.append(create_step(k.replace(cu, ""), ("/sqtt-insts", len(ctxs), len(steps)), wave_insts[cu][k], depth=3))
|
||||
ctxs.append({"name":"Counters", "steps":steps})
|
||||
|
||||
def device_sort_fn(k):
|
||||
|
||||
Reference in New Issue
Block a user