From 423b76a8523e75470ef7579e14f84d4f097b3a9b Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:04:10 -0800 Subject: [PATCH] improve sqtt format parser (saturday coffee shop project) (#13419) * improve sqtt format parser * actually read the trash code ChatGPT wrote * cleanups * hand written parser * quality * more * was missing first packet * maybe * filt * fixups * label the waves * progress --- extra/sqtt/active_sqtt_parse.py | 84 +++- extra/sqtt/attempt_sqtt_parse.py | 660 ++++++++++++++----------------- 2 files changed, 366 insertions(+), 378 deletions(-) diff --git a/extra/sqtt/active_sqtt_parse.py b/extra/sqtt/active_sqtt_parse.py index a803b807d0..2c04a068cf 100644 --- a/extra/sqtt/active_sqtt_parse.py +++ b/extra/sqtt/active_sqtt_parse.py @@ -25,9 +25,9 @@ def save_sqtt(): yield sqtt events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())] - rctx = decode(events) - assert len(rctx.inst_execs) > 0, "empty sqtt output" - sqtt.update(rctx.inst_execs) + #rctx = decode(events) + #assert len(rctx.inst_execs) > 0, "empty sqtt output" + #sqtt.update(rctx.inst_execs) for e in events: if isinstance(e, ProfileSQTTEvent): @@ -41,7 +41,6 @@ template = """.text .type matmul,@function matmul: INSTRUCTION - s_endpgm .rodata .p2align 6 @@ -64,7 +63,7 @@ amdhsa.kernels: .private_segment_fixed_size: 0 .wavefront_size: 32 .sgpr_count: 8 - .vgpr_count: 32 + .vgpr_count: 8 .max_flat_workgroup_size: 1024 .kernarg_segment_align: 8 .kernarg_segment_size: 8 @@ -79,21 +78,86 @@ amdhsa.kernels: .end_amdgpu_metadata """ -def run_asm(src): - NUM_WORKGROUPS = 1 +def run_asm(src, num_workgroups=1, num_waves=1): WAVE_SIZE = 32 - NUM_WAVES = 1 t = Tensor.empty(0x1000).realize() buf = t.uop.buffer.ensure_allocated() lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src))) dev.compiler.disassemble(lib) fxn = AMDProgram(dev, "matmul", lib) - fxn(buf._buf, global_size=(NUM_WORKGROUPS,1,1), 
local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) + fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True) if __name__ == "__main__": + with save_sqtt() as sqtt: + run_asm([ + #"s_barrier", + #"s_nop 0", + #"s_nop 0", + #"s_nop 15", + #"s_nop 0", + #"s_nop 0", + "s_nop 0", + "s_nop 0", + "s_load_b64 s[0:1], s[0:1], null", + "s_waitcnt lgkmcnt(0)", + "s_nop 0", + "s_nop 0", + "s_nop 100", + "s_nop 100", + "s_nop 100", + "s_nop 0", + "s_nop 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", + "s_nop 0", + "s_nop 0", + "s_nop 100", + "s_nop 100", + "s_nop 100", + "s_nop 0", + "s_nop 0", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v2, v0, s[0:1]", + "global_load_b32 v3, v0, s[0:1]", + "global_load_b32 v4, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "global_load_b32 v1, v0, s[0:1]", + "s_nop 0", + "s_nop 0", + "s_nop 0", + "s_waitcnt vmcnt(0)", + "s_nop 100", + "s_nop 100", + "s_nop 100", + "s_nop 0", + "s_nop 0", + #"v_add_f32_e32 v1 v0 v0", + #"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", + "s_endpgm", + ], num_workgroups=1, num_waves=1) + exit(0) + with save_sqtt() as sqtt: #(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize() - Tensor.empty(1).elu().realize() + Tensor.empty(1, 64).sum(axis=1).realize() + #Tensor.empty(1).exp().realize() exit(0) with save_sqtt() as sqtt: diff --git a/extra/sqtt/attempt_sqtt_parse.py b/extra/sqtt/attempt_sqtt_parse.py index 640ffc1cc6..af8ad71411 100644 --- 
a/extra/sqtt/attempt_sqtt_parse.py +++ b/extra/sqtt/attempt_sqtt_parse.py @@ -1,34 +1,63 @@ -import pickle -from tinygrad.helpers import getenv +import pickle, sys +from tinygrad.helpers import getenv, Timing, colored from extra.sqtt.roc import decode, ProfileSQTTEvent +# do these enums match fields in the packets? +#from tinygrad.runtime.support.amd import import_soc +#soc = import_soc([11]) +#perf_sel = {getattr(soc, k):k for k in dir(soc) if k.startswith("SQ_PERF_")} + # Instruction packets (one per ISA op) # NOTE: these are bad guesses and may be wrong! feel free to update if you know better # some names were taken from SQ_TT_TOKEN_MASK_TOKEN_EXCLUDE_SHIFT -OPCODE_NAMES = { +# we see 18 opcodes +# opcodes(18): 1 2 3 4 5 6 8 9 F 10 11 12 14 15 16 17 18 19 +# if you exclude everything, you are left with 6 +# opcodes( 6): 10 11 14 15 16 17 +# sometimes we see a lot of B, but not repeatable + +# not seen +# 7 A C + +# NOTE: INST runs before EXEC + +GOOD_OPCODE_NAMES = { + # gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show) + 0x01: "VALUINST", # gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT 0x02: "VMEMEXEC", # gated by SQ_TT_TOKEN_EXCLUDE_ALUEXEC_SHIFT 0x03: "ALUEXEC", - # gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show) - 0x01: "VALUINST", + # gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT + 0x04: "IMMEDIATE", + 0x05: "IMMEDIATE_MULTIWAVE", # gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT 0x06: "WAVERDY", # gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT 0x08: "WAVEEND", 0x09: "WAVESTART", - # gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT - 0x04: "IMMEDIATE_4", - 0x05: "IMMEDIATE_5", - # some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there - 0x14: "REG", + # gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT + 0x0D: "PERF", + # pure time + 0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate + 0x10: "NOP", # gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT 0x12: "EVENT", + # some 
gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there + 0x14: "REG", + # marker + 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker + # this is the first packet + 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B # gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT 0x18: "INST", # gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT 0x19: "UTILCTR", +} + +OPCODE_NAMES = { + **GOOD_OPCODE_NAMES, # ------------------------------------------------------------------------ # 0x07–0x0F: pure timestamp-ish deltas @@ -37,30 +66,23 @@ OPCODE_NAMES = { 0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2 0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3 0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer) - 0x0D: "TS_DELTA_S5_W3_C", # shift=5, width=3 - 0x0E: "TS_DELTA_S7_W2", # shift=7, width=2 - 0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate # ------------------------------------------------------------------------ # 0x10–0x19: timestamps, layout headers, events, perf # ------------------------------------------------------------------------ - 0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint 0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10) 0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19 - 0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot - 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker - 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B } # these tables are from rocprof trace decoder # rocprof_trace_decoder_parse_data-0x11c6a0 # parse_sqtt_180 = b *rocprof_trace_decoder_parse_data-0x11c6a0+0x110040 -# ---------- 1. local_138: 256-byte state->token table ---------- +# ---------- 1. 
local_138: 256-byte state->opcode table ---------- -STATE_TO_TOKEN: bytes = bytes([ +STATE_TO_OPCODE: bytes = bytes([ 0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, @@ -79,17 +101,47 @@ STATE_TO_TOKEN: bytes = bytes([ 0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, ]) +# opcode mask (the bits used to determine the opcode, worked out by looking at the repeats in STATE_TO_OPCODE) + +opcode_mask = { + 0x10: 0b1111, + + 0x16: 0b1111111, + 0x17: 0b1111111, + 0x07: 0b1111111, + 0x19: 0b1111111, + 0x11: 0b1111111, + 0x12: 0b11111111, + 0x13: 0b11111111, + 0x15: 0b1111111, + + 0x18: 0b111, + 0x1: 0b111, + + 0x5: 0b11111, + 0x6: 0b11111, + 0xb: 0b11111, + 0x8: 0b11111, + 0xc: 0b11111, + 0xd: 0b11111, + + 0xf: 0b1111, + 0x14: 0b1111, + + 0x9: 0b11111, + 0xa: 0b11111, + + 0x4: 0b1111, + 0x3: 0b1111, + 0x2: 0b1111, +} # ---------- 2. DAT_0012e280: nibble budget per opcode&0x1F ---------- NIBBLE_BUDGET = [ - 0x08, 0x0C, 0x08, 0x08, 0x0C, 0x18, 0x18, 0x40, - 0x14, 0x20, 0x30, 0x14, 0x34, 0x1C, 0x30, 0x08, - 0x04, 0x18, 0x18, 0x20, 0x40, 0x40, 0x30, 0x40, - 0x14, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x08, 0x0C, 0x08, 0x08, 0x0C, 0x18, 0x18, 0x40, 0x14, 0x20, 0x30, 0x14, 0x34, 0x1C, 0x30, 0x08, + 0x04, 0x18, 0x18, 0x20, 0x40, 0x40, 0x30, 0x40, 0x14, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ] -assert len(NIBBLE_BUDGET) == 32 - # ---------- 3. 
delta_map from your hash nodes ---------- @@ -108,7 +160,8 @@ DELTA_MAP_DEFAULT = { 0x0B: (5, 3), # shift=5, end=8 0x0C: (5, 3), # shift=5, end=8 0x0D: (5, 3), # shift=5, end=8 - 0x0E: (7, 2), # shift=7, end=9 + # NOTE: 0x0e can never be decoded, it's not in the STATE_TO_OPCODE table + #0x0E: (7, 2), # shift=7, end=9 0x0F: (4, 4), # shift=4, end=8 0x10: (0, 0), # shift=0, end=0 (no delta) 0x11: (7, 9), # shift=7, end=16 @@ -124,307 +177,197 @@ DELTA_MAP_DEFAULT = { # ---------- 4. One-line-per-packet parser ---------- -def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: +def reg_mask(opcode): + nb_bits = NIBBLE_BUDGET[opcode & 0x1F] + shift, width = DELTA_MAP_DEFAULT[opcode] + delta_mask = ((1 << width) - 1) << shift + assert delta_mask & opcode_mask[opcode] == 0, "masks shouldn't overlap" + return ((1 << nb_bits) - 1) & ~(delta_mask | opcode_mask[opcode]) + +def decode_packet_fields(opcode: int, reg: int) -> str: """ Decode packet payloads conservatively, using: - NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width. - DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta). - Per-opcode layouts derived from rocprof's decompiled consumers. """ - # --- 0. Restrict to real packet bits --------------------------------- - nb_bits = NIBBLE_BUDGET[opcode & 0x1F] - if nb_bits <= 0 or nb_bits >= 64: - pkt = reg & ((1 << 64) - 1) - else: - pkt = reg & ((1 << nb_bits) - 1) - + # --- 0. 
Restrict to real packet bits not used in delta --------------------------------- + pkt = reg & reg_mask(opcode) fields: list[str] = [] - shift, width = DELTA_MAP_DEFAULT.get(opcode, (0, 0)) - if width: - field_mask = (1 << width) - 1 - shaped_field = (pkt >> shift) & field_mask - else: - field_mask = 0 - shaped_field = 0 + match opcode: + case 0x01: # VALUINST + # 6 bit field + flag = (pkt >> 6) & 1 + wave = pkt >> 7 + fields.append(f"wave={wave:x}") + fields.append(f"flag={flag:X}") + case 0x02: # VMEMEXEC + # 2 bit field (pipe is a guess) + fields.append(f"pipe={pkt>>6:X}") + case 0x03: # ALUEXEC + # 2 bit field (pipe is a guess) + fields.append(f"pipe={pkt>>6:X}") + case 0x04: # IMMEDIATE_4 + # 5 bit field (actually 4) + wave = pkt >> 7 + fields.append(f"wave={wave:x}") + case 0x05: # IMMEDIATE_5 + # 16 bit field + # 1 bit per wave + fields.append(f"mask={pkt>>8:16b}") + case 0x0d: + # 20 bit field + fields.append(f"arg = {pkt>>8:X}") + case 0x12: + fields.append(f"event = {pkt>>11:X}") + case 0x15: + fields.append(f"snap = {pkt>>10:X}") + case 0x19: + # wave end + fields.append(f"ctr = {pkt>>9:X}") + case 0x11: + # DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta. 
+ coarse = pkt >> 16 + fields.append(f"coarse=0x{coarse:02x}") + # From decomp: + # - when layout<3 and coarse&1, it sets a "has interesting wave" flag + # - when coarse&8, it marks all live waves as "terminated" + if coarse & 0x01: + fields.append("flag_wave_interest=1") + if coarse & 0x08: + fields.append("flag_terminate_all=1") + case 0x6: + # wave ready + fields.append(f"wave = {pkt>>8:X}") + case 0x8: + # wave end, this is 20 bits (FFF00) + flag7 = (pkt >> 8) & 0x3 + wgp = (pkt >> 10) & 1 + slot4 = (pkt >> 11) & 0xF + wave = (pkt >> 15) & 0x1f + fields.append(f"wave={wave:x}") + fields.append(f"wgp={wgp}") + fields.append(f"flag7={flag7}") + fields.append(f"slot4={slot4:x}") + case 0x9: + # From case 9 (WAVESTART) in multiple consumers: + # flag7 = (w >> 7) & 1 (low bit of uVar41) + # cls2 = (w >> 8) & 3 (class / group) + # slot4 = (w >> 10) & 0xf (slot / group index) + # idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path) + # idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path) + # id7 = (w >> 0x19) & 0x7f (7-bit id) + flag7 = (pkt >> 7) & 3 + wgp = (pkt >> 9) & 1 + slot3 = (pkt >> 10) & 0x7 # NOTE: this isn't 4! 
+ wave = (pkt >> 13) & 0x1F + id7 = (pkt >> 17) + fields.append(f"wave={wave:x}") + fields.append(f"flag7={flag7}") + fields.append(f"wgp={wgp}") + fields.append(f"slot3={slot3:x}") + fields.append(f"id7=0x{id7:x}") + case 0x18: + # FFF88 is the mask + # From case 0x18: + # low3 = w & 7 + # grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent) + # flags = bits 6 (B6) and 7 (B7) + # hi8 = (w >> 0xc) & 0xff (layout 4 path) + # hi7 = (w >> 0xd) & 0x7f (other layouts) + # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index + flag = (pkt >> 3) & 1 + flag2 = (pkt >> 7) & 1 + wave = (pkt >> 8) & 0x1F + hi8 = (pkt >> 13) + fields.append(f"wave={wave:x}") + fields.append(f"flag={flag:x}") + fields.append(f"flag2={flag2:x}") + fields.append(f"hi8=0x{hi8:x}") + case 0x14: + subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10) + val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20) + slot = (pkt >> 7) & 0x7 # index in local_168[...] tables + hi_byte = (pkt >> 8) & 0xFF # determines config vs marker - # ===================================================================== - # 1. Timestamp-centric opcodes (actually drive 'time') - # ===================================================================== + fields.append(f"subop=0x{subop:04x}") + fields.append(f"slot={slot}") + fields.append(f"val32=0x{val32:08x}") - if opcode == 0x0F: # TS_DELTA_SHORT_PLUS4 - # In the caller, delta already has +4 applied. - raw_delta = shaped_field - fields.append(f"raw_delta={raw_delta}") - fields.append(f"ts_short_plus4={delta}") - return ", ".join(fields) + if hi_byte & 0x80: + # Config flavour: writes config words into per-slot state arrays. + fields.append("kind=config") + if subop == 0x000C: + fields.append("slot=lo") + elif subop == 0x000D: + fields.append("slot=hi") + else: + # COR marker: subop 0xC342, payload "COR\0" → start of a COR region. 
+ if subop == 0xC342: + fields.append("kind=cor_stream") + if val32 == 0x434F5200: + fields.append("cor_magic='COR\\0'") + case 0x16: + # Bits: + # bit8 -> 0x100 + # bit9 -> 0x200 + # bits 12..47 -> 36-bit field used as delta or marker + bit8 = bool(pkt & 0x100) + bit9 = bool(pkt & 0x200) + if not bit9: + mode = "delta" + elif not bit8: + mode = "marker" + else: + mode = "other" + # need to use reg here + val36 = (reg >> 12) & ((1 << 36) - 1) + fields.append(f"mode={mode}") + if mode != "delta": + fields.append(f"val36=0x{val36:x}") + case 0x17: + # From decomp (two sites with identical logic): + # layout = (w >> 7) & 0x3f + # mode = (w >> 0xd) & 3 + # group = (w >> 0xf) & 7 + # sel_a = (w >> 0x1c) & 0xf + # sel_b = (w >> 0x21) & 7 + # flag4 = (w >> 0x3b) & 1 (only meaningful when layout == 4) + layout = (pkt >> 7) & 0x3F + simd = (pkt >> 13) & 0x3 # you can change this by changing traced simd + group = (pkt >> 15) & 0x7 + sel_a = (pkt >> 0x1C) & 0xF + sel_b = (pkt >> 0x21) & 0x7 + flag4 = (pkt >> 0x3B) & 0x1 - if opcode == 0x11: # TS_WAVE_STATE_SAMPLE - # DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta. 
- raw_delta = shaped_field - coarse = (pkt >> (shift + width)) & 0xFF # matches byte at +10 in C - fields.append(f"raw_delta={raw_delta}") - if coarse: - fields.append(f"coarse_state=0x{coarse:02x}") - # From decomp: - # - when layout<3 and coarse&1, it sets a "has interesting wave" flag - # - when coarse&8, it marks all live waves as "terminated" - if coarse & 0x01: - fields.append("flag_wave_interest=1") - if coarse & 0x08: - fields.append("flag_terminate_all=1") - return ", ".join(fields) + fields.append(f"layout={layout}") + fields.append(f"group={group}") + fields.append(f"simd={simd}") + fields.append(f"sel_a={sel_a}") + fields.append(f"sel_b={sel_b}") + if layout == 4: + fields.append(f"layout4_flag={flag4}") + case _: + fields.append(f"& {reg_mask(opcode):X}") + return ",".join(fields) - if opcode == 0x16: # TS_DELTA36_OR_MARK - # Bits: - # bit8 -> 0x100 - # bit9 -> 0x200 - # bits 12..47 -> 36-bit field used as delta or marker - bit8 = bool(pkt & 0x100) - bit9 = bool(pkt & 0x200) - if not bit9: - mode = "delta" - elif not bit8: - mode = "marker" - else: - mode = "other" - val36 = (pkt >> 12) & ((1 << 36) - 1) - fields.append(f"mode={mode}") - if mode != "delta": - fields.append(f"val36=0x{val36:x}") - return ", ".join(fields) +FILTER_LEVEL = getenv("FILTER", 2) - # For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT), - # but we don't see any other fields used in the decomp. 
- if opcode in (0x07, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E): - if width: - raw_delta = shaped_field - leftover = pkt & ~(field_mask << shift) - fields.append(f"raw_delta={raw_delta}") - if leftover: - fields.append(f"payload=0x{leftover:x}") - return ", ".join(fields) +DEFAULT_FILTER = tuple() +# NOP + pure time +if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf) +# reg + event + sample + marker +# TODO: events are probably good +if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x11, 0x14, 0x16, 0x12) +# instruction runs +if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03) +# instructions dispatch (inst, valuinst, immed) +if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x01, 0x4, 0x5, 0x18) +# waves +if FILTER_LEVEL >= 4: DEFAULT_FILTER += (0x6, 0x8, 0x9) - # ===================================================================== - # 2. Small "meta + tiny delta" packets (0x01–0x06) - # ===================================================================== - - if opcode == 0x01: # META_ID12_TS_SMALL - id12 = pkt & 0xFFF - fields.append(f"id12=0x{id12:03x}") - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - return ", ".join(fields) - - if opcode == 0x02: # META_FLAG8_TS_SMALL - flag8 = pkt & 0xFF - fields.append(f"flag8=0x{flag8:02x}") - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - return ", ".join(fields) - - if opcode == 0x03: # META_SUBEVENT8_TS_SMALL - sub8 = pkt & 0xFF - fields.append(f"subevent8=0x{sub8:02x}") - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - return ", ".join(fields) - - if opcode == 0x04: # META_BASE_INDEX12_TS - idx12 = pkt & 0xFFF - fields.append(f"base_index12=0x{idx12:03x}") - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - return ", ".join(fields) - - if opcode in (0x05, 0x06): # META_DESC24_TS_A/B - desc24 = pkt & 0xFFFFFF - fields.append(f"desc24=0x{desc24:06x}") - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - return ", ".join(fields) - - # 
===================================================================== - # 3. Opcode 0x14: exec/config record (+ COR marker) - # ===================================================================== - - if opcode == 0x14: # INST_EXEC_OR_CFG - subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10) - val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20) - slot = (pkt >> 7) & 0x7 # index in local_168[...] tables - hi_byte = (pkt >> 8) & 0xFF # determines config vs marker - - fields.append(f"subop=0x{subop:04x}") - fields.append(f"slot={slot}") - fields.append(f"val32=0x{val32:08x}") - - if hi_byte & 0x80: - # Config flavour: writes config words into per-slot state arrays. - fields.append("kind=config") - if subop == 0x000C: - fields.append("cfg_target=local_168[slot].lo") - elif subop == 0x000D: - fields.append("cfg_target=local_168[slot].hi") - else: - # COR marker: subop 0xC342, payload "COR\0" → start of a COR region. - if subop == 0xC342: - fields.append("kind=cor_stream") - if val32 == 0x434F5200: - fields.append("cor_magic='COR\\0'") - return ", ".join(fields) - - # ===================================================================== - # 4. 
Opcode 0x17: layout / mode header - # ===================================================================== - - if opcode == 0x17: # LAYOUT_MODE_HEADER - # From decomp (two sites with identical logic): - # layout = (w >> 7) & 0x3f - # mode = (w >> 0xd) & 3 - # group = (w >> 0xf) & 7 - # sel_a = (w >> 0x1c) & 0xf - # sel_b = (w >> 0x21) & 7 - # flag4 = (w >> 0x3b) & 1 (only meaningful when layout == 4) - layout = (pkt >> 7) & 0x3F - mode = (pkt >> 13) & 0x3 - group = (pkt >> 15) & 0x7 - sel_a = (pkt >> 0x1C) & 0xF - sel_b = (pkt >> 0x21) & 0x7 - flag4 = (pkt >> 0x3B) & 0x1 - - fields.append(f"layout={layout}") - fields.append(f"group={group}") - fields.append(f"mode={mode}") - fields.append(f"sel_a={sel_a}") - fields.append(f"sel_b={sel_b}") - if layout == 4: - fields.append(f"layout4_flag={flag4}") - return ", ".join(fields) - - # ===================================================================== - # 5. Opcode 0x09: state / route config record - # ===================================================================== - - if opcode == 0x09: # PERF_ROUTE_CONFIG - # From case 9 in multiple consumers: - # flag7 = (w >> 7) & 1 (low bit of uVar41) - # cls2 = (w >> 8) & 3 (class / group) - # slot4 = (w >> 10) & 0xf (slot / group index) - # idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path) - # idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path) - # id7 = (w >> 0x19) & 0x7f (7-bit id) - flag7 = (pkt >> 7) & 0x1 - cls2 = (pkt >> 8) & 0x3 - slot4 = (pkt >> 10) & 0xF - idx_lo = (pkt >> 13) & 0x1F - idx_hi = (pkt >> 15) & 0x1F - id7 = (pkt >> 0x19) & 0x7F - - fields.append(f"flag7={flag7}") - fields.append(f"cls2={cls2}") - fields.append(f"slot4=0x{slot4:x}") - fields.append(f"idx_lo5=0x{idx_lo:x}") - fields.append(f"idx_hi5=0x{idx_hi:x}") - fields.append(f"id7=0x{id7:x}") - return ", ".join(fields) - - # ===================================================================== - # 6. 
Opcode 0x18: perf/event selector (FUN_0010aba0) - # ===================================================================== - - if opcode == 0x18: # PERF_EVENT_SELECT - # From case 0x18: - # low3 = w & 7 - # grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent) - # flags = bits 6 (B6) and 7 (B7) - # hi8 = (w >> 0xc) & 0xff (layout 4 path) - # hi7 = (w >> 0xd) & 0x7f (other layouts) - # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index - low3 = pkt & 0x7 - grp3_a = (pkt >> 3) & 0x7 - grp3_b = (pkt >> 4) & 0x7 - flag_b6 = (pkt >> 6) & 0x1 - flag_b7 = (pkt >> 7) & 0x1 - idx5_a = (pkt >> 7) & 0x1F - idx5_b = (pkt >> 8) & 0x1F - hi8 = (pkt >> 12) & 0xFF - hi7 = (pkt >> 13) & 0x7F - - fields.append(f"low3=0x{low3:x}") - fields.append(f"grp3_a=0x{grp3_a:x}") - fields.append(f"grp3_b=0x{grp3_b:x}") - fields.append(f"flag_b6={flag_b6}") - fields.append(f"flag_b7={flag_b7}") - fields.append(f"idx5_a=0x{idx5_a:x}") - fields.append(f"idx5_b=0x{idx5_b:x}") - fields.append(f"hi8=0x{hi8:02x}") - fields.append(f"hi7=0x{hi7:02x}") - return ", ".join(fields) - - # ===================================================================== - # 7. Opcode 0x15: perfcounter snapshot - # ===================================================================== - - if opcode == 0x15: # PERFCOUNTER_SNAPSHOT - # NIBBLE_BUDGET gives full 64 bits here. - # DELTA_MAP_DEFAULT: shift=7, width=3 → tiny delta field. - raw_delta = shaped_field if width else 0 - # low bits below the delta field - snap_low = pkt & ((1 << shift) - 1) if shift else 0 - # everything above delta field - snap_hi = pkt >> (shift + width) if width else (pkt >> shift) - - fields.append(f"raw_delta={raw_delta}") - fields.append(f"snap_low_s{shift}=0x{snap_low:x}") - fields.append(f"snap_hi=0x{snap_hi:x}") - return ", ".join(fields) - - # ===================================================================== - # 8. 
Small event-ish packets (0x08 / 0x12 / 0x13 / 0x19) - # ===================================================================== - - if opcode in (0x08, 0x12, 0x13, 0x19): - # These are all "small event / metric" style tokens. The exact semantics - # depend on layout (0x17) and accumulated state (local_500 etc), so we - # expose: - # - low 8 bits as kind byte - # - rest as opaque payload. - kind = pkt & 0xFF - payload = pkt >> 8 - fields.append(f"kind_byte=0x{kind:02x}") - if payload: - fields.append(f"payload=0x{payload:x}") - return ", ".join(fields) - - # ===================================================================== - # 9. Pseudo opcode 0x10: never a "real" packet - # ===================================================================== - - if opcode == 0x10: # PSEUDO_NEED_MORE_BITS - # The main loop never prints these; they're just a control token. - return "" - - # ===================================================================== - # 10. Generic fallback: expose the DELTA_MAP_DEFAULT field + leftover - # ===================================================================== - - if width: - fields.append(f"field_s{shift}_w{width}={shaped_field}") - leftover = pkt & ~(field_mask << shift) - if leftover: - fields.append(f"payload=0x{leftover:x}") - - return ", ".join(fields) - -# 0xb is time something -# 0xd is time something -# 0xf is small time advance -# 0x11 is time advance -# 0x16 is big time advance + markers -# 0x14 is REG -DEFAULT_FILTER = (0xb, 0xd, 0xf, 0x11, 0x16, 0x14) if getenv("FILTER", 1) else None - -def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAULT_FILTER) -> None: +def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -> None: """ Minimal debug: print ONE LINE per decoded token (packet). 
@@ -433,49 +376,32 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAU """ n = len(data) time = 0 + last_printed_time = 0 reg = 0 # shift register offset = 0 # bit offset, in steps of 4 (one nibble) nib_budget = 0x40 flags = 0 token_index = 0 + opcodes_seen = set() - while (offset >> 3) < n and token_index < max_tokens: - # Remember where we started refilling for this step (bit offset), - # but the *logical* start of the current packet is last_real_offset. - refill_start = offset - + while (offset >> 3) < n: # 1) Fill register with nibbles according to nib_budget if nib_budget != 0: - target = refill_start + 4 + ((nib_budget - 1) & ~3) - cur = refill_start - while cur != target and (cur >> 3) < n: - byte_index = cur >> 3 - byte = data[byte_index] - shift = 4 if (cur & 4) else 0 # low then high nibble - nib = (byte >> shift) & 0xF + target = offset + 4 + ((nib_budget - 1) & ~3) + while offset != target and (offset >> 3) < n: + byte = data[offset >> 3] + nib = (byte >> (offset & 4)) & 0xF reg = ((reg >> 4) | (nib << 60)) & ((1 << 64) - 1) - cur += 4 - offset = cur + offset += 4 # 2) Decode token from low 8 bits - state = reg & 0xFF - opcode = STATE_TO_TOKEN[state] + opcode = STATE_TO_OPCODE[reg & 0xFF] + opcodes_seen.add(opcode) - # 3) Handle pseudo-token 0x10: need more bits, don't print. Looks like a NOP. - if opcode == 0x10: - # "need more bits" pseudo-token: adjust nibble budget and continue - nib_budget = 4 - if (offset >> 3) >= n: - break - # Do NOT count this as a real packet; do not update last_real_offset. 
- continue + # 4) Set next nibble budget based on opcode + nib_budget = NIBBLE_BUDGET[opcode & 0x1F] - # 4) Set next nibble budget - nb_index = opcode & 0x1F - nib_budget = NIBBLE_BUDGET[nb_index] - time_before = time - note = "" - # 5) Special opcode 0x16 (timestamp / marker) + # 5) Update time and handle special opcodes 0xF/0x16 if opcode == 0x16: two_bits = (reg >> 8) & 0x3 if two_bits == 1: @@ -486,7 +412,6 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAU if (reg & 0x200) == 0: # delta mode: add 36-bit delta to time delta = (reg >> 12) & ((1 << 36) - 1) - time += delta else: # marker / other modes: no time advance if (reg & 0x100) == 0: @@ -496,48 +421,47 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAU else: # 6) Generic opcode (including 0x0F) shift, width = DELTA_MAP_DEFAULT[opcode] - mask = (1 << width) - 1 - delta = (reg >> shift) & mask + delta = (reg >> shift) & ((1 << width) - 1) - # TODO: add more opcode parsers here that add notes to other opcodes + # opcode 0x0F has an offset of 4 to the delta if opcode == 0x0F: - delta_with_fix = delta + 4 - time += delta_with_fix - delta = delta_with_fix - else: - time += delta + delta = delta + 4 # Append extra decoded fields into the note string - note = decode_packet_fields(opcode, reg, delta) + note = decode_packet_fields(opcode, reg) - if filter is None or opcode not in filter: - my_reg = reg - my_reg &= (1 << nib_budget) - 1 + if verbose and (filter is None or opcode not in filter): print( f"{token_index:4d} " - f"off={offset//4:5d} " - f"op=0x{opcode:02x} " + f"time={time:8d}+{delta+(time-last_printed_time):8d} " + f"op={opcode:2x} " f"{OPCODE_NAMES[opcode]:24s} " - f" time={time_before:8d}+{delta:8d} " - f"{my_reg:16X} " + f"{reg&reg_mask(opcode):16X} " f"{note}" ) + #f"off={offset//4:5d} " + last_printed_time = time+delta + time += delta token_index += 1 # Optional summary at the end - print(f"# done: tokens={token_index}, 
final_time={time}, flags=0x{flags:02x}") + print(f"# done: tokens={token_index:_}, final_time={time}, flags=0x{flags:02x}") + if verbose: + print(f"opcodes({len(opcodes_seen):2d}):", ' '.join([colored(f"{op:2X}", "white" if op in GOOD_OPCODE_NAMES else "red") for op in opcodes_seen])) + def parse(fn:str): - dat = pickle.load(open(fn, "rb")) - ctx = decode(dat) + with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb")) + if getenv("ROCM", 0): + with Timing(f"decode {fn}: "): ctx = decode(dat) dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)] print(f"got {len(dat_sqtt)} SQTT events in {fn}") return dat_sqtt if __name__ == "__main__": - #dat_sqtt = parse("extra/sqtt/examples/profile_empty_run_0.pkl") - #dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl") - dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl") - blob_0 = dat_sqtt[0].blob - parse_sqtt_print_packets(blob_0[8:]) + fn = "extra/sqtt/examples/profile_gemm_run_0.pkl" + dat_sqtt = parse(sys.argv[1] if len(sys.argv) > 1 else fn) + for i,dat in enumerate(dat_sqtt): + with Timing(f"decode pkt {i} with len {len(dat.blob):_}: "): + parse_sqtt_print_packets(dat.blob, verbose=getenv("V", 1))