mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-07 22:23:55 -05:00
@@ -48,17 +48,11 @@ class OccEvent(WaveSlot):
|
||||
RunKey = tuple[str, int]
|
||||
|
||||
class _ROCParseCtx:
|
||||
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
|
||||
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
|
||||
self.disasms:dict[str, dict[int, tuple[str, int]]] = {}
|
||||
def __init__(self, sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]):
|
||||
self.sqtt_evs, self.disasms = iter(sqtt_evs), disasms
|
||||
self.inst_execs:dict[RunKey, list[WaveExec]] = {}
|
||||
self.occ_events:dict[RunKey, list[OccEvent]] = {}
|
||||
|
||||
for prog in prog_evs:
|
||||
base = unwrap(prog.base)
|
||||
target = unwrap(dev_evs[prog.device].props)['gfx_target_version']
|
||||
self.disasms[prog.name] = asm = {base+addr:info for addr,info in llvm_disasm(target, unwrap(prog.lib)).items()}
|
||||
|
||||
def next_sqtt(self):
|
||||
x = next(self.sqtt_evs, None)
|
||||
self.active_run = (x.kern, x.exec_tag) if x is not None else None
|
||||
@@ -81,16 +75,8 @@ class _ROCParseCtx:
|
||||
self.inst_execs.setdefault(unwrap(self.active_run), []).append(WaveExec(ev.wave_id, ev.cu, ev.simd, unwrap(self.active_se), ev.begin_time,
|
||||
ev.end_time, insts_blob))
|
||||
|
||||
def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
||||
dev_events:dict[str, ProfileDeviceEvent] = {}
|
||||
sqtt_events:list[ProfileSQTTEvent] = []
|
||||
prog_events:list[ProfileProgramEvent] = []
|
||||
for e in profile:
|
||||
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
|
||||
if isinstance(e, ProfileSQTTEvent): sqtt_events.append(e)
|
||||
if isinstance(e, ProfileProgramEvent) and e.device.startswith("AMD"): prog_events.append(e)
|
||||
|
||||
ROCParseCtx = _ROCParseCtx(dev_events, sqtt_events, prog_events)
|
||||
def decode(sqtt_evs:list[ProfileSQTTEvent], disasms:dict[str, dict[int, tuple[str, int]]]) -> _ROCParseCtx:
|
||||
ROCParseCtx = _ROCParseCtx(sqtt_evs, disasms)
|
||||
|
||||
@rocprof.rocprof_trace_decoder_se_data_callback_t
|
||||
def copy_cb(buf, buf_size, _):
|
||||
@@ -150,7 +136,7 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
with args.profile.open("rb") as f: profile = pickle.load(f)
|
||||
rctx = decode(profile)
|
||||
print('SQTT:', rctx.inst_execs.keys())
|
||||
#rctx = decode(profile, disasm)
|
||||
#print('SQTT:', rctx.inst_execs.keys())
|
||||
|
||||
print_pmc([ev for ev in profile if isinstance(ev, ProfilePMCEvent)])
|
||||
|
||||
@@ -243,13 +243,11 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
counter_events:dict[tuple[str, int], dict] = {}
|
||||
durations:dict[str, list[float]] = {}
|
||||
prg_events:dict[str, ProfileProgramEvent] = {}
|
||||
dev_events:dict[str, ProfileDeviceEvent] = {}
|
||||
for e in profile:
|
||||
if isinstance(e, (ProfilePMCEvent, ProfileSQTTEvent)): counter_events.setdefault((e.kern, e.exec_tag), {}).setdefault(type(e), []).append(e)
|
||||
if isinstance(e, ProfileRangeEvent) and e.device.startswith("AMD") and e.en is not None:
|
||||
durations.setdefault(str(e.name), []).append(float(e.en-e.st))
|
||||
if isinstance(e, ProfileProgramEvent): prg_events[str(e.name)] = e
|
||||
if isinstance(e, ProfileDeviceEvent): dev_events[e.device] = e
|
||||
if len(counter_events) == 0: return None
|
||||
ctxs.append({"name":"All Counters", "steps":[create_step("PMC", ("/all-pmc", len(ctxs), 0), (durations, all_counters:={}))]})
|
||||
run_number = {n:0 for n,_ in counter_events}
|
||||
@@ -263,7 +261,7 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
all_counters[(name, run_number[k], k)] = pmc[0]
|
||||
if (sqtt:=v.get(ProfileSQTTEvent)):
|
||||
# to decode a SQTT trace, we need the raw stream, program binary and device properties
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), [*sqtt, prg_events[k], dev_events[sqtt[0].device]])))
|
||||
steps.append(create_step("SQTT", ("/prg-sqtt", len(ctxs), len(steps)), ((k, tag), sqtt, prg_events[k])))
|
||||
if getenv("SQTT_PARSE"):
|
||||
# run our decoder on startup, we don't use this since it only works on gfx11
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
@@ -272,11 +270,12 @@ def load_counters(profile:list[ProfileEvent]) -> None:
|
||||
|
||||
# ** SQTT OCC only unpacks wave start, end time and SIMD location
|
||||
|
||||
def unpack_sqtt(key:tuple[str, int], profile:list[ProfileEvent]) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
|
||||
def unpack_sqtt(key:tuple[str, int], data:list, p:ProfileProgramEvent) -> tuple[dict[str, list[ProfileEvent]], list[str], dict[str, dict[str, dict]]]:
|
||||
# * init decoder
|
||||
from extra.sqtt.roc import decode
|
||||
rctx = decode(profile)
|
||||
disasm = rctx.disasms[key[0]]
|
||||
base = unwrap(p.base)
|
||||
disasm = {addr+base:inst_disasm for addr,inst_disasm in llvm_disasm(device_props[p.device]["gfx_target_version"], unwrap(p.lib)).items()}
|
||||
rctx = decode(data, {p.name:disasm})
|
||||
cu_events:dict[str, list[ProfileEvent]] = {}
|
||||
# * INST waves
|
||||
wave_insts:dict[str, dict[str, dict]] = {}
|
||||
|
||||
Reference in New Issue
Block a user