diff --git a/extra/sqtt/roc.py b/extra/sqtt/roc.py index 36e078ce51..8021c816cf 100644 --- a/extra/sqtt/roc.py +++ b/extra/sqtt/roc.py @@ -1,13 +1,13 @@ import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools from extra.sqtt.rocprof import rocprof -from tinygrad.helpers import temp, DEBUG +from tinygrad.helpers import temp, unwrap, DEBUG from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent from tinygrad.runtime.autogen import llvm from tinygrad.runtime.support.elf import elf_loader # to pass NULL to callbacks -llvm.LLVMCreateDisasmCPUFeatures.argtypes = llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5] + [ctypes.c_void_p, ctypes.c_void_p] +llvm.LLVMCreateDisasmCPUFeatures.argtypes = tuple(llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5]) + (ctypes.c_void_p, ctypes.c_void_p) def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]: llvm.LLVMInitializeAMDGPUTargetInfo() llvm.LLVMInitializeAMDGPUTargetMC() @@ -16,8 +16,8 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]: ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None) image, sections, relocs = elf_loader(lib) - text = next((sh.header for sh in sections if sh.name == ".text"), -1) - off, sz = text.sh_addr, text.sh_size + text = next((sh.header for sh in sections if sh.name == ".text"), None) + off, sz = unwrap(text).sh_addr, unwrap(text).sh_size addr_table:dict[int, tuple[str, int]] = {} out = ctypes.create_string_buffer(128) @@ -44,12 +44,14 @@ class InstInfo: class _ROCParseCtx: def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]): self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs - self.wave_events, self.disasms, self.addr2prg = {}, {}, {} + self.wave_events:dict[tuple[str, int, int, int], dict[int, InstInfo]] = {} + self.disasms:dict[int, tuple[str, int]] = {} + self.addr2prg:dict[int, ProfileProgramEvent] = {} for prog in prog_evs: - for addr, info in llvm_disasm(dev_evs[prog.device].arch, prog.lib).items(): - self.disasms[prog.base + addr] = info - self.addr2prg[prog.base + addr] = prog + for addr, info in llvm_disasm(dev_evs[prog.device].arch, unwrap(prog.lib)).items(): + self.disasms[unwrap(prog.base) + addr] = info + self.addr2prg[unwrap(prog.base) + addr] = prog def next_sqtt(self): x = next(self.sqtt_evs, None) @@ -64,7 +66,7 @@ class _ROCParseCtx: def on_wave_ev(self, ev): if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time) - asm = {} + asm:dict[int, InstInfo] = {} for j in range(ev.instructions_size): inst_ev = ev.instructions_array[j] inst_typ = rocprof.rocprofiler_thread_trace_decoder_inst_category_t__enumvalues[inst_ev.category]