lint sqtt parser with mypy (#13079)

* llvm address table errs

* mypy likes annotated dicts

* unwrap nullable
This commit is contained in:
qazal
2025-11-04 00:53:59 +08:00
committed by GitHub
parent 2d2040bc92
commit 6df34a5887

View File

@@ -1,13 +1,13 @@
import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools import ctypes, pathlib, argparse, pickle, re, functools, dataclasses, itertools
from extra.sqtt.rocprof import rocprof from extra.sqtt.rocprof import rocprof
from tinygrad.helpers import temp, DEBUG from tinygrad.helpers import temp, unwrap, DEBUG
from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent from tinygrad.device import ProfileEvent, ProfileDeviceEvent, ProfileProgramEvent
from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent from tinygrad.runtime.ops_amd import ProfileSQTTEvent, ProfilePMCEvent
from tinygrad.runtime.autogen import llvm from tinygrad.runtime.autogen import llvm
from tinygrad.runtime.support.elf import elf_loader from tinygrad.runtime.support.elf import elf_loader
# to pass NULL to callbacks # to pass NULL to callbacks
llvm.LLVMCreateDisasmCPUFeatures.argtypes = llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5] + [ctypes.c_void_p, ctypes.c_void_p] llvm.LLVMCreateDisasmCPUFeatures.argtypes = tuple(llvm.LLVMCreateDisasmCPUFeatures.argtypes[:5]) + (ctypes.c_void_p, ctypes.c_void_p)
def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]: def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
llvm.LLVMInitializeAMDGPUTargetInfo() llvm.LLVMInitializeAMDGPUTargetInfo()
llvm.LLVMInitializeAMDGPUTargetMC() llvm.LLVMInitializeAMDGPUTargetMC()
@@ -16,8 +16,8 @@ def llvm_disasm(arch:str, lib:bytes) -> dict[int, tuple[str, int]]:
ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None) ctx = llvm.LLVMCreateDisasmCPUFeatures("amdgcn-amd-amdhsa".encode(), arch.encode(), "".encode(), None, 0, None, None)
image, sections, relocs = elf_loader(lib) image, sections, relocs = elf_loader(lib)
text = next((sh.header for sh in sections if sh.name == ".text"), -1) text = next((sh.header for sh in sections if sh.name == ".text"), None)
off, sz = text.sh_addr, text.sh_size off, sz = unwrap(text).sh_addr, unwrap(text).sh_size
addr_table:dict[int, tuple[str, int]] = {} addr_table:dict[int, tuple[str, int]] = {}
out = ctypes.create_string_buffer(128) out = ctypes.create_string_buffer(128)
@@ -44,12 +44,14 @@ class InstInfo:
class _ROCParseCtx: class _ROCParseCtx:
def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]): def __init__(self, dev_evs:dict[str, ProfileDeviceEvent], sqtt_evs:list[ProfileSQTTEvent], prog_evs:list[ProfileProgramEvent]):
self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs self.dev_evs, self.sqtt_evs, self.prog_evs = dev_evs, iter(sqtt_evs), prog_evs
self.wave_events, self.disasms, self.addr2prg = {}, {}, {} self.wave_events:dict[tuple[str, int, int, int], dict[int, InstInfo]] = {}
self.disasms:dict[int, tuple[str, int]] = {}
self.addr2prg:dict[int, ProfileProgramEvent] = {}
for prog in prog_evs: for prog in prog_evs:
for addr, info in llvm_disasm(dev_evs[prog.device].arch, prog.lib).items(): for addr, info in llvm_disasm(dev_evs[prog.device].arch, unwrap(prog.lib)).items():
self.disasms[prog.base + addr] = info self.disasms[unwrap(prog.base) + addr] = info
self.addr2prg[prog.base + addr] = prog self.addr2prg[unwrap(prog.base) + addr] = prog
def next_sqtt(self): def next_sqtt(self):
x = next(self.sqtt_evs, None) x = next(self.sqtt_evs, None)
@@ -64,7 +66,7 @@ class _ROCParseCtx:
def on_wave_ev(self, ev): def on_wave_ev(self, ev):
if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time) if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time)
asm = {} asm:dict[int, InstInfo] = {}
for j in range(ev.instructions_size): for j in range(ev.instructions_size):
inst_ev = ev.instructions_array[j] inst_ev = ev.instructions_array[j]
inst_typ = rocprof.rocprofiler_thread_trace_decoder_inst_category_t__enumvalues[inst_ev.category] inst_typ = rocprof.rocprofiler_thread_trace_decoder_inst_category_t__enumvalues[inst_ev.category]