mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
use existing roc.py infra for sqtt tests (#14297)
* add pc, per kernel tracing
* work
* remove those imports
* min diff
This commit is contained in:
@@ -1,16 +1,10 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for SQTT packet decoding using real captured examples."""
|
||||
import pickle, unittest, ctypes, threading
|
||||
import pickle, unittest
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.runtime.autogen import rocprof
|
||||
from tinygrad.runtime.support.elf import elf_loader
|
||||
from extra.assembly.amd.decode import decode_inst
|
||||
from extra.assembly.amd.autogen.rdna3.ins import SOPP
|
||||
from extra.assembly.amd.autogen.rdna3.enum import SOPPOp
|
||||
from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK,
|
||||
ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets)
|
||||
from extra.assembly.amd.test.helpers import TARGET_TO_ARCH
|
||||
|
||||
EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples"
|
||||
# INST ops for non-traced SIMDs (excluded from instruction count)
|
||||
@@ -24,71 +18,18 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
||||
# ROCPROF DECODER
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str):
|
||||
def run_rocprof_decoder(events: list, lib: bytes, base: int, target: int):
|
||||
"""Run rocprof decoder on SQTT blobs, returning raw occupancy and instruction records."""
|
||||
image, sections, _ = elf_loader(lib)
|
||||
text = next((sh for sh in sections if sh.name == ".text"), None)
|
||||
assert text is not None, "no .text section found"
|
||||
text_off, text_size = text.header.sh_addr, text.header.sh_size
|
||||
|
||||
blob_iter, current_blob = iter(blobs), [None]
|
||||
from tinygrad.viz.serve import llvm_disasm
|
||||
from extra.sqtt.roc import decode as roc_decode
|
||||
occupancy_records: list[tuple[int, int, int, int, bool]] = [] # (wave_id, simd, cu, time, is_start)
|
||||
wave_insts: list[list[tuple[int, int]]] = [] # per-wave list of (time, stall)
|
||||
|
||||
@rocprof.rocprof_trace_decoder_se_data_callback_t
|
||||
def copy_cb(buf, buf_size, _):
|
||||
blob = next(blob_iter, None)
|
||||
if blob is None: return 0
|
||||
current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob)
|
||||
buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte))
|
||||
buf_size[0] = len(current_blob[0])
|
||||
return len(current_blob[0])
|
||||
|
||||
@rocprof.rocprof_trace_decoder_trace_callback_t
|
||||
def trace_cb(record_type, events_ptr, n, _):
|
||||
if record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_OCCUPANCY:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_occupancy_t * n).from_address(events_ptr):
|
||||
occupancy_records.append((ev.wave_id, ev.simd, ev.cu, ev.time, ev.start))
|
||||
elif record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_WAVE:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_wave_t * n).from_address(events_ptr):
|
||||
if ev.instructions_size > 0:
|
||||
sz = ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t)
|
||||
insts_blob = bytearray(sz)
|
||||
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
|
||||
insts = list((rocprof.rocprofiler_thread_trace_decoder_inst_t * ev.instructions_size).from_buffer(insts_blob))
|
||||
wave_insts.append([(inst.time, inst.stall) for inst in insts])
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
arch = TARGET_TO_ARCH[target]
|
||||
@rocprof.rocprof_trace_decoder_isa_callback_t
|
||||
def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, _):
|
||||
offset = pc.address - base
|
||||
if offset < text_off or offset >= text_off + text_size:
|
||||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
try:
|
||||
inst = decode_inst(image[offset:], arch=arch)
|
||||
mem_size_ptr[0] = inst._size()
|
||||
# this could be an error in our decode_inst
|
||||
except (ValueError, AssertionError):
|
||||
mem_size_ptr[0] = 0
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: mem_size_ptr[0] = 0
|
||||
# rocprof parses instruction string to determine type; v_nop works for all
|
||||
if (max_sz := size_ptr[0]) == 0: return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_ERROR_OUT_OF_RESOURCES
|
||||
ctypes.memmove(instr_ptr, b"v_nop", min(5, max_sz - 1))
|
||||
size_ptr[0] = min(5, max_sz - 1)
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
exc = None
|
||||
def worker():
|
||||
nonlocal exc
|
||||
try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None)
|
||||
except Exception as e: exc = e
|
||||
(t:=threading.Thread(target=worker, daemon=True)).start()
|
||||
t.join(timeout=1)
|
||||
if exc is not None: raise exc
|
||||
if t.is_alive(): raise RuntimeError("rocprof decoder timeout")
|
||||
disasm = {addr+base:inst_disasm for addr, inst_disasm in llvm_disasm(110000, lib).items()}
|
||||
rctx = roc_decode(events, {(e:=events[0]).kern:disasm})
|
||||
occ_events = rctx.occ_events[(e.kern, e.exec_tag)]
|
||||
wave_events = rctx.inst_execs.get((e.kern, e.exec_tag), [])
|
||||
for e in occ_events: occupancy_records.append((e.wave_id, e.simd, e.cu, e.time, e.start))
|
||||
for e in wave_events: wave_insts.append([(i.time, i.stall) for i in e.unpack_insts()])
|
||||
return occupancy_records, wave_insts
|
||||
|
||||
class TestSQTTExamples(unittest.TestCase):
|
||||
@@ -100,10 +41,13 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")):
|
||||
with open(pkl_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"]
|
||||
prg = next((e for e in data if type(e).__name__ == "ProfileProgramEvent"), None)
|
||||
if sqtt_events and prg:
|
||||
cls.examples[pkl_path.stem] = (sqtt_events, prg.lib, prg.base)
|
||||
sqtt_events:dict[str, list] = {}
|
||||
for e in data:
|
||||
if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"): cls.gfx_num = e.props["gfx_target_version"]
|
||||
if type(e).__name__ == "ProfileSQTTEvent": sqtt_events.setdefault(e.kern, []).append(e)
|
||||
prg = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"}
|
||||
for name, events in sqtt_events.items():
|
||||
cls.examples[pkl_path.stem+"_"+name] = (events, prg[name].lib, prg[name].base)
|
||||
|
||||
def test_examples_loaded(self):
|
||||
self.assertGreater(len(self.examples), 0, "no example files found")
|
||||
@@ -147,17 +91,20 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}")
|
||||
|
||||
expected = {
|
||||
"profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929],
|
||||
"profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_0_E": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_empty_run_1_E": [1803, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_gemm_run_0_E_32_32_4": [2531, 1844, 1864, 1915, 1942, 1848],
|
||||
"profile_gemm_run_0_E_8_8_16_4": [3074, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_gemm_run_0_r_2_8_16_4_4_16_4": [19026, 1919, 1939, 1990, 2017, 1929],
|
||||
"profile_gemm_run_1_E_32_32_4": [2554, 1844, 1864, 1915, 1942, 1848],
|
||||
"profile_gemm_run_1_E_8_8_16_4": [3084, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_gemm_run_1_r_2_8_16_4_4_16_4": [19010, 1919, 1939, 1990, 2017, 1923],
|
||||
"profile_plus_run_0_E_3": [1900, 1908, 1928, 1979, 2006, 1912],
|
||||
"profile_plus_run_1_E_3": [1856, 1908, 1928, 1979, 2006, 1912],
|
||||
}
|
||||
def test_packet_counts(self):
|
||||
for name, (events, *_) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
if not self.expected.get(name): continue
|
||||
counts = [len(list(decode(e.blob))) for e in events]
|
||||
self.assertEqual(counts, self.expected[name], f"packet count mismatch in {name}")
|
||||
|
||||
@@ -165,7 +112,7 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
"""Wave start/end times must match rocprof exactly."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
|
||||
occupancy, _ = run_rocprof_decoder(events, lib, base, self.gfx_num)
|
||||
# extract from rocprof occupancy records
|
||||
roc_starts: dict[tuple[int, int, int], int] = {}
|
||||
roc_waves: list[tuple[int, int]] = []
|
||||
@@ -187,7 +134,7 @@ class TestSQTTExamples(unittest.TestCase):
|
||||
"""Instruction times must match rocprof exactly (excluding s_endpgm)."""
|
||||
for name, (events, lib, base) in self.examples.items():
|
||||
with self.subTest(example=name):
|
||||
_, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base, self.target)
|
||||
_, wave_insts = run_rocprof_decoder(events, lib, base, self.gfx_num)
|
||||
# skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall)
|
||||
roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]]
|
||||
# extract from our decoder
|
||||
|
||||
Reference in New Issue
Block a user