use existing roc.py infra for sqtt tests (#14297)

* add pc, per kernel tracing

* work

* remove those imports

* min diff
This commit is contained in:
qazal
2026-01-23 00:07:11 -05:00
committed by GitHub
parent 5f32f7a06b
commit 3b8a7bb8c9

View File

@@ -1,16 +1,10 @@
#!/usr/bin/env python3
"""Tests for SQTT packet decoding using real captured examples."""
import pickle, unittest, ctypes, threading
import pickle, unittest
from pathlib import Path
from tinygrad.helpers import DEBUG
from tinygrad.runtime.autogen import rocprof
from tinygrad.runtime.support.elf import elf_loader
from extra.assembly.amd.decode import decode_inst
from extra.assembly.amd.autogen.rdna3.ins import SOPP
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
# INST ops for non-traced SIMDs (excluded from instruction count)
@@ -24,71 +18,18 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
# ROCPROF DECODER
# ═══════════════════════════════════════════════════════════════════════════════
def run_rocprof_decoder(events: list, lib: bytes, base: int, target: int):
  """Decode SQTT events through extra.sqtt.roc and return raw occupancy and instruction records.

  Args:
    events: SQTT profile events to decode. Only the first event's `kern` and `exec_tag` are used to
      select the decoded results, so all events are expected to belong to that kernel execution.
    lib: program binary handed to `llvm_disasm` to build the address -> disassembly map.
    base: address added to every disassembled offset before it is handed to the decoder.
    target: forwarded to `llvm_disasm` as its first argument (callers pass the device's
      `gfx_target_version`). Previously this parameter was ignored and 110000 was hard-coded.

  Returns:
    A pair `(occupancy_records, wave_insts)`: `occupancy_records` is a list of
    `(wave_id, simd, cu, time, start)` tuples and `wave_insts` is one list of `(time, stall)`
    tuples per decoded wave.

  Raises:
    IndexError: if `events` is empty.
    KeyError: if the decoder produced no occupancy events for the first event's (kern, exec_tag).
  """
  from tinygrad.viz.serve import llvm_disasm
  from extra.sqtt.roc import decode as roc_decode
  # disassemble for the caller-supplied target and rebase every offset by `base`
  disasm = {addr+base:inst_disasm for addr, inst_disasm in llvm_disasm(target, lib).items()}
  first = events[0]
  key = (first.kern, first.exec_tag)
  rctx = roc_decode(events, {first.kern:disasm})
  # occupancy is required (KeyError when missing); instruction records are optional and default to none
  occupancy_records: list[tuple[int, int, int, int, bool]] = [(o.wave_id, o.simd, o.cu, o.time, o.start) for o in rctx.occ_events[key]]
  wave_insts: list[list[tuple[int, int]]] = [[(i.time, i.stall) for i in w.unpack_insts()] for w in rctx.inst_execs.get(key, [])]
  return occupancy_records, wave_insts
class TestSQTTExamples(unittest.TestCase):
@@ -100,10 +41,13 @@ class TestSQTTExamples(unittest.TestCase):
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
with open(pkl_path, "rb") as f:
data = pickle.load(f)
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
prg = next((e for e in data if type(e).__name__ == "ProfileProgramEvent"), None)
if sqtt_events and prg:
cls.examples[pkl_path.stem] = (sqtt_events, prg.lib, prg.base)
sqtt_events:dict[str, list] = {}
for e in data:
if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"): cls.gfx_num = e.props["gfx_target_version"]
if type(e).__name__ == "ProfileSQTTEvent": sqtt_events.setdefault(e.kern, []).append(e)
prg = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
for name, events in sqtt_events.items():
cls.examples[pkl_path.stem+"_"+name] = (events, prg[name].lib, prg[name].base)
def test_examples_loaded(self):
self.assertGreater(len(self.examples), 0, "no example files found")
@@ -147,17 +91,20 @@ class TestSQTTExamples(unittest.TestCase):
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
expected = {
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_0_E": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_empty_run_1_E": [1803, 1908, 1928, 1979, 2006, 1912],
"profile_gemm_run_0_E_32_32_4": [2531, 1844, 1864, 1915, 1942, 1848],
"profile_gemm_run_0_E_8_8_16_4": [3074, 1919, 1939, 1990, 2017, 1923],
"profile_gemm_run_0_r_2_8_16_4_4_16_4": [19026, 1919, 1939, 1990, 2017, 1929],
"profile_gemm_run_1_E_32_32_4": [2554, 1844, 1864, 1915, 1942, 1848],
"profile_gemm_run_1_E_8_8_16_4": [3084, 1919, 1939, 1990, 2017, 1923],
"profile_gemm_run_1_r_2_8_16_4_4_16_4": [19010, 1919, 1939, 1990, 2017, 1923],
"profile_plus_run_0_E_3": [1900, 1908, 1928, 1979, 2006, 1912],
"profile_plus_run_1_E_3": [1856, 1908, 1928, 1979, 2006, 1912],
}
def test_packet_counts(self):
  """Per-event decoded packet counts must equal the reference counts recorded in `expected`."""
  for name, (events, *_) in self.examples.items():
    with self.subTest(example=name):
      reference = self.expected.get(name)
      # examples with a missing or empty reference entry are not checked
      if not reference: continue
      counts = [sum(1 for _pkt in decode(ev.blob)) for ev in events]
      self.assertEqual(counts, reference, f"packet count mismatch in {name}")
@@ -165,7 +112,7 @@ class TestSQTTExamples(unittest.TestCase):
"""Wave start/end times must match rocprof exactly."""
for name, (events, lib, base) in self.examples.items():
with self.subTest(example=name):
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
occupancy, _ = run_rocprof_decoder(events, lib, base, self.gfx_num)
# extract from rocprof occupancy records
roc_starts: dict[tuple[int, int, int], int] = {}
roc_waves: list[tuple[int, int]] = []
@@ -187,7 +134,7 @@ class TestSQTTExamples(unittest.TestCase):
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
for name, (events, lib, base) in self.examples.items():
with self.subTest(example=name):
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
_, wave_insts = run_rocprof_decoder(events, lib, base, self.gfx_num)
# skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall)
roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]]
# extract from our decoder