lint sqtt parser with mypy (#13079)

* llvm address table errs

* mypy likes annotated dicts

* unwrap nullable
This commit is contained in:
qazal
2025-11-04 00:53:59 +08:00
committed by GitHub
parent 2d2040bc92
commit 6df34a5887

View File

@@ -1,13 +1,13 @@
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools
from extra.sqtt.rocprof import rocprof
from tinygrad.helpers import temp, DEBUG
from tinygrad.helpers import temp, unwrap, DEBUG
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
from tinygrad.runtime.autogen import llvm
from tinygrad.runtime.support.elf import elf_loader
# to pass NULL to callbacks
llvm.LLVMCreateDisasmCPUFeatures.argtypes = llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5] + [ctypes.c_void_p, ctypes.c_void_p]
llvm.LLVMCreateDisasmCPUFeatures.argtypes = tuple(llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5]) + (ctypes.c_void_p, ctypes.c_void_p)
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
llvm.LLVMInitializeAMDGPUTargetInfo()
llvm.LLVMInitializeAMDGPUTargetMC()
@@ -16,8 +16,8 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None)
image, sections, relocs = elf_loader(lib)
text = next((sh.header for sh in sections if sh.name == ".text"), -1)
off, sz = text.sh_addr, text.sh_size
text = next((sh.header for sh in sections if sh.name == ".text"), None)
off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
addr_table:dict[int, tuple[str, int]] = {}
out = ctypes.create_string_buffer(128)
@@ -44,12 +44,14 @@ class InstInfo:
class _ROCParseCtx:
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
self.wave_events, self.disasms, self.addr2prg = {}, {}, {}
self.wave_events:dict[tuple[str, int, int, int], dict[int, InstInfo]] = {}
self.disasms:dict[int, tuple[str, int]] = {}
self.addr2prg:dict[int, ProfileProgramEvent] = {}
for prog in prog_evs:
for addr, info in llvm_disasm(dev_evs[prog.device].arch, prog.lib).items():
self.disasms[prog.base + addr] = info
self.addr2prg[prog.base + addr] = prog
for addr, info in llvm_disasm(dev_evs[prog.device].arch, unwrap(prog.lib)).items():
self.disasms[unwrap(prog.base) + addr] = info
self.addr2prg[unwrap(prog.base) + addr] = prog
def next_sqtt(self):
x = next(self.sqtt_evs, None)
@@ -64,7 +66,7 @@ class _ROCParseCtx:
def on_wave_ev(self, ev):
if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time)
asm = {}
asm:dict[int, InstInfo] = {}
for j in range(ev.instructions_size):
inst_ev = ev.instructions_array[j]
inst_typ = rocprof.rocprofiler_thread_trace_decoder_inst_category_t__enumvalues[inst_ev.category]