diff --git a/extra/assembly/amd/test/test_sqtt_examples.py b/extra/assembly/amd/test/test_sqtt_examples.py index 4d921971f3..c030517b6a 100644 --- a/extra/assembly/amd/test/test_sqtt_examples.py +++ b/extra/assembly/amd/test/test_sqtt_examples.py @@ -1,16 +1,10 @@ #!/usr/bin/env python3 """Tests for SQTT packet decoding using real captured examples.""" -import pickle, unittest, ctypes, threading +import pickle, unittest from pathlib import Path from tinygrad.helpers import DEBUG -from tinygrad.runtime.autogen import rocprof -from tinygrad.runtime.support.elf import elf_loader -from extra.assembly.amd.decode import decode_inst -from extra.assembly.amd.autogen.rdna3.ins import SOPP -from extra.assembly.amd.autogen.rdna3.enum import SOPPOp from extra.assembly.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVEEND, INST, VALUINST, IMMEDIATE, IMMEDIATE_MASK, ALUEXEC, VMEMEXEC, PACKET_TYPES, InstOp, print_packets) -from extra.assembly.amd.test.helpers import TARGET_TO_ARCH EXAMPLES_DIR = Path(__file__).parent.parent.parent.parent / "sqtt/examples" # INST ops for non-traced SIMDs (excluded from instruction count) @@ -24,71 +18,18 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD # ROCPROF DECODER # ═══════════════════════════════════════════════════════════════════════════════ -def run_rocprof_decoder(blobs: list[bytes], lib: bytes, base: int, target: str): +def run_rocprof_decoder(events: list, lib: bytes, base: int, target: int): """Run rocprof decoder on SQTT blobs, returning raw occupancy and instruction records.""" - image, sections, _ = elf_loader(lib) - text = next((sh for sh in sections if sh.name == ".text"), None) - assert text is not None, "no .text section found" - text_off, text_size = text.header.sh_addr, text.header.sh_size - - blob_iter, current_blob = iter(blobs), [None] + from tinygrad.viz.serve import llvm_disasm + from extra.sqtt.roc import decode as roc_decode occupancy_records: list[tuple[int, int, int, int, bool]] = [] # (wave_id, simd, cu, time, is_start) wave_insts: list[list[tuple[int, int]]] = [] # per-wave list of (time, stall) - - @rocprof.rocprof_trace_decoder_se_data_callback_t - def copy_cb(buf, buf_size, _): - blob = next(blob_iter, None) - if blob is None: return 0 - current_blob[0] = (ctypes.c_ubyte * len(blob)).from_buffer_copy(blob) - buf[0] = ctypes.cast(current_blob[0], ctypes.POINTER(ctypes.c_ubyte)) - buf_size[0] = len(current_blob[0]) - return len(current_blob[0]) - - @rocprof.rocprof_trace_decoder_trace_callback_t - def trace_cb(record_type, events_ptr, n, _): - if record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_OCCUPANCY: - for ev in (rocprof.rocprofiler_thread_trace_decoder_occupancy_t * n).from_address(events_ptr): - occupancy_records.append((ev.wave_id, ev.simd, ev.cu, ev.time, ev.start)) - elif record_type == rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_WAVE: - for ev in (rocprof.rocprofiler_thread_trace_decoder_wave_t * n).from_address(events_ptr): - if ev.instructions_size > 0: - sz = ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t) - insts_blob = bytearray(sz) - ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz) - insts = list((rocprof.rocprofiler_thread_trace_decoder_inst_t * ev.instructions_size).from_buffer(insts_blob)) - wave_insts.append([(inst.time, inst.stall) for inst in insts]) - return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS - - arch = TARGET_TO_ARCH[target] - @rocprof.rocprof_trace_decoder_isa_callback_t - def isa_cb(instr_ptr, mem_size_ptr, size_ptr, pc, _): - offset = pc.address - base - if offset < text_off or offset >= text_off + text_size: - mem_size_ptr[0] = 0 - return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS - try: - inst = decode_inst(image[offset:], arch=arch) - mem_size_ptr[0] = inst._size() - # this could be an error in our decode_inst - except (ValueError, AssertionError): - mem_size_ptr[0] = 0 - return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS - if isinstance(inst, SOPP) and inst.op == SOPPOp.S_ENDPGM: mem_size_ptr[0] = 0 - # rocprof parses instruction string to determine type; v_nop works for all - if (max_sz := size_ptr[0]) == 0: return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_ERROR_OUT_OF_RESOURCES - ctypes.memmove(instr_ptr, b"v_nop", min(5, max_sz - 1)) - size_ptr[0] = min(5, max_sz - 1) - return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS - - exc = None - def worker(): - nonlocal exc - try: rocprof.rocprof_trace_decoder_parse_data(copy_cb, trace_cb, isa_cb, None) - except Exception as e: exc = e - (t:=threading.Thread(target=worker, daemon=True)).start() - t.join(timeout=1) - if exc is not None: raise exc - if t.is_alive(): raise RuntimeError("rocprof decoder timeout") + disasm = {addr+base:inst_disasm for addr, inst_disasm in llvm_disasm(110000, lib).items()} + rctx = roc_decode(events, {(e:=events[0]).kern:disasm}) + occ_events = rctx.occ_events[(e.kern, e.exec_tag)] + wave_events = rctx.inst_execs.get((e.kern, e.exec_tag), []) + for e in occ_events: occupancy_records.append((e.wave_id, e.simd, e.cu, e.time, e.start)) + for e in wave_events: wave_insts.append([(i.time, i.stall) for i in e.unpack_insts()]) return occupancy_records, wave_insts class TestSQTTExamples(unittest.TestCase): @@ -100,10 +41,13 @@ class TestSQTTExamples(unittest.TestCase): for pkl_path in sorted((EXAMPLES_DIR/cls.target).glob("*.pkl")): with open(pkl_path, "rb") as f: data = pickle.load(f) - sqtt_events = [e for e in data if type(e).__name__ == "ProfileSQTTEvent"] - prg = next((e for e in data if type(e).__name__ == "ProfileProgramEvent"), None) - if sqtt_events and prg: - cls.examples[pkl_path.stem] = (sqtt_events, prg.lib, prg.base) + sqtt_events:dict[str, list] = {} + for e in data: + if type(e).__name__ == "ProfileDeviceEvent" and e.device.startswith("AMD"): cls.gfx_num = e.props["gfx_target_version"] + if type(e).__name__ == "ProfileSQTTEvent": sqtt_events.setdefault(e.kern, []).append(e) + prg = {e.name:e for e in data if type(e).__name__ == "ProfileProgramEvent"} + for name, events in sqtt_events.items(): + cls.examples[pkl_path.stem+"_"+name] = (events, prg[name].lib, prg[name].base) def test_examples_loaded(self): self.assertGreater(len(self.examples), 0, "no example files found") @@ -147,17 +91,20 @@ class TestSQTTExamples(unittest.TestCase): self.assertGreater(len([p for p in all_packets if isinstance(p, INST)]), 0, f"no INST packets in {name}") expected = { - "profile_empty_run_0": [1803, 1908, 1928, 1979, 2006, 1912], - "profile_empty_run_1": [1803, 1908, 1928, 1979, 2006, 1912], - "profile_gemm_run_0": [2531, 1844, 1864, 1915, 1942, 1848, 3074, 1919, 1939, 1990, 2017, 1923, 19026, 1919, 1939, 1990, 2017, 1929], - "profile_gemm_run_1": [2554, 1844, 1864, 1915, 1942, 1848, 3084, 1919, 1939, 1990, 2017, 1923, 19010, 1919, 1939, 1990, 2017, 1923], - "profile_plus_run_0": [1900, 1908, 1928, 1979, 2006, 1912], - "profile_plus_run_1": [1856, 1908, 1928, 1979, 2006, 1912], + "profile_empty_run_0_E": [1803, 1908, 1928, 1979, 2006, 1912], + "profile_empty_run_1_E": [1803, 1908, 1928, 1979, 2006, 1912], + "profile_gemm_run_0_E_32_32_4": [2531, 1844, 1864, 1915, 1942, 1848], + "profile_gemm_run_0_E_8_8_16_4": [3074, 1919, 1939, 1990, 2017, 1923], + "profile_gemm_run_0_r_2_8_16_4_4_16_4": [19026, 1919, 1939, 1990, 2017, 1929], + "profile_gemm_run_1_E_32_32_4": [2554, 1844, 1864, 1915, 1942, 1848], + "profile_gemm_run_1_E_8_8_16_4": [3084, 1919, 1939, 1990, 2017, 1923], + "profile_gemm_run_1_r_2_8_16_4_4_16_4": [19010, 1919, 1939, 1990, 2017, 1923], + "profile_plus_run_0_E_3": [1900, 1908, 1928, 1979, 2006, 1912], + "profile_plus_run_1_E_3": [1856, 1908, 1928, 1979, 2006, 1912], } def test_packet_counts(self): for name, (events, *_) in self.examples.items(): with self.subTest(example=name): - if not self.expected.get(name): continue counts = [len(list(decode(e.blob))) for e in events] self.assertEqual(counts, self.expected[name], f"packet count mismatch in {name}") @@ -165,7 +112,7 @@ class TestSQTTExamples(unittest.TestCase): """Wave start/end times must match rocprof exactly.""" for name, (events, lib, base) in self.examples.items(): with self.subTest(example=name): - occupancy, _ = run_rocprof_decoder([e.blob for e in events], lib, base, self.target) + occupancy, _ = run_rocprof_decoder(events, lib, base, self.gfx_num) # extract from rocprof occupancy records roc_starts: dict[tuple[int, int, int], int] = {} roc_waves: list[tuple[int, int]] = [] @@ -187,7 +134,7 @@ class TestSQTTExamples(unittest.TestCase): """Instruction times must match rocprof exactly (excluding s_endpgm).""" for name, (events, lib, base) in self.examples.items(): with self.subTest(example=name): - _, wave_insts = run_rocprof_decoder([e.blob for e in events], lib, base, self.target) + _, wave_insts = run_rocprof_decoder(events, lib, base, self.gfx_num) # skip last inst per wave (s_endpgm) - it needs special handling (time + duration instead of time + stall) roc_insts = [time + stall for insts in wave_insts for time, stall in insts[:-1]] # extract from our decoder