From 55be95da15e89dc5f0445866c4c9d818300c7fce Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 16 Nov 2025 13:11:51 -0800 Subject: [PATCH] cleanup sqtt raw parser (#13309) * cleanup sqtt raw parser * better names (don't merge yet) * clean up amd * a few more names * one more filter --- extra/sqtt/active_sqtt_parse.py | 56 +++++++++++++++++------ extra/sqtt/attempt_sqtt_parse.py | 76 +++++++++++++++++--------------- 2 files changed, 83 insertions(+), 49 deletions(-) diff --git a/extra/sqtt/active_sqtt_parse.py b/extra/sqtt/active_sqtt_parse.py index 0d70be863d..23eab51986 100644 --- a/extra/sqtt/active_sqtt_parse.py +++ b/extra/sqtt/active_sqtt_parse.py @@ -7,6 +7,7 @@ os.environ["AMD_LLVM"] = "0" from dataclasses import replace import atexit, contextlib +from tinygrad import Tensor from tinygrad.helpers import system, getenv from tinygrad.runtime.ops_amd import AMDProgram from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent @@ -29,16 +30,15 @@ def save_sqtt(): yield sqtt events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())] - rctx = decode(events) - assert len(rctx.inst_execs) > 0, "empty sqtt output" - sqtt.update(rctx.inst_execs) + #rctx = decode(events) + #assert len(rctx.inst_execs) > 0, "empty sqtt output" + #sqtt.update(rctx.inst_execs) for e in events: if isinstance(e, ProfileSQTTEvent): print(replace(e, blob=b'')) if e.se == 0: - parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None) - + parse_sqtt_print_packets(e.blob) template = """.text .globl matmul @@ -51,6 +51,7 @@ matmul: .rodata .p2align 6 .amdhsa_kernel matmul + .amdhsa_user_sgpr_kernarg_segment_ptr 1 .amdhsa_next_free_vgpr .amdgcn.next_free_vgpr .amdhsa_next_free_sgpr .amdgcn.next_free_sgpr .amdhsa_wavefront_size32 1 @@ -64,14 +65,21 @@ amdhsa.version: amdhsa.kernels: - .name: matmul .symbol: matmul.kd - .kernarg_segment_size: 0 .group_segment_fixed_size: 0 .private_segment_fixed_size: 0 - .kernarg_segment_align: 4 .wavefront_size: 32 .sgpr_count: 8 .vgpr_count: 32 .max_flat_workgroup_size: 1024 + .kernarg_segment_align: 8 + .kernarg_segment_size: 8 + .args: + - .address_space: global + .name: a + .offset: 0 + .size: 8 + .type_name: 'float*' + .value_kind: global_buffer ... .end_amdgpu_metadata """ @@ -80,20 +88,42 @@ def run_asm(src): NUM_WORKGROUPS = 1 WAVE_SIZE = 32 NUM_WAVES = 1 + t = Tensor.empty(0x1000).realize() + buf = t.uop.buffer.ensure_allocated() lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src))) dev.compiler.disassemble(lib) fxn = AMDProgram(dev, "matmul", lib) - fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) + fxn(buf._buf, global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True) if __name__ == "__main__": with save_sqtt() as sqtt: + #(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize() + Tensor.empty(1).elu().realize() + exit(0) + + with save_sqtt() as sqtt: + # what's in v0? run_asm([ + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v1, 0", + "s_clause 0x1", + "s_load_b64 s[0:1], s[0:1], null", + "s_waitcnt lgkmcnt(0)", + ]+[ + "global_load_b32 v1, v0, s[0:1]", + ]*10+[ + "global_load_b32 v10, v1, s[0:1]", + "s_waitcnt vmcnt(0)", + #"v_rcp_f32 v1, v0" - "v_add_f32_e32 v1 v0 v0", - "v_add_f32_e32 v3 v2 v2", - "v_add_f32_e32 v5 v4 v4", - "v_add_f32_e32 v7 v6 v6", + #"v_add_f32_e32 v1 v0 v0", + #"v_add_f32_e32 v5 v4 v4", + #"v_add_f32_e32 v7 v6 v6", #"v_add_f32_e32 v1 v0 v0", #"v_add_f32_e32 v2 v1 v1", #"s_nop 1" - ]*1) + ]*5+[ + "v_add_f32_e32 v3 v2 v2", + ]*5+[ + "v_mul_f32_e32 v3 v2 v2", + ]*7) diff --git a/extra/sqtt/attempt_sqtt_parse.py b/extra/sqtt/attempt_sqtt_parse.py index 996aeed964..640ffc1cc6 100644 --- a/extra/sqtt/attempt_sqtt_parse.py +++ b/extra/sqtt/attempt_sqtt_parse.py @@ -1,26 +1,39 @@ import pickle +from tinygrad.helpers import getenv from extra.sqtt.roc import decode, ProfileSQTTEvent # Instruction packets (one per ISA op) # NOTE: these are bad guesses and may be wrong! feel free to update if you know better +# some names were taken from SQ_TT_TOKEN_MASK_TOKEN_EXCLUDE_SHIFT OPCODE_NAMES = { - # ------------------------------------------------------------------------ - # 0x01–0x06: small “meta + maybe tiny delta” packets - # ------------------------------------------------------------------------ - 0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field - 0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta - 0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta - 0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta - 0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field - 0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field + # gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT + 0x02: "VMEMEXEC", + # gated by SQ_TT_TOKEN_EXCLUDE_ALUEXEC_SHIFT + 0x03: "ALUEXEC", + # gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show) + 0x01: "VALUINST", + # gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT + 0x06: "WAVERDY", + # gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT + 0x08: "WAVEEND", + 0x09: "WAVESTART", + # gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT + 0x04: "IMMEDIATE_4", + 0x05: "IMMEDIATE_5", + # some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there + 0x14: "REG", + # gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT + 0x12: "EVENT", + # gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT + 0x18: "INST", + # gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT + 0x19: "UTILCTR", # ------------------------------------------------------------------------ # 0x07–0x0F: pure timestamp-ish deltas # ------------------------------------------------------------------------ 0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta) - 0x08: "EVT_MATCH_SMALL", # event-ish, see fields below - 0x09: "PERF_ROUTE_CONFIG", # routing/indirection config 0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2 0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3 0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer) @@ -34,15 +47,11 @@ OPCODE_NAMES = { 0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint 0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10) - 0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric 0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19 - 0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker 0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B - 0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0 - 0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric } # these tables are from rocprof trace decoder @@ -181,9 +190,8 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: mode = "other" val36 = (pkt >> 12) & ((1 << 36) - 1) fields.append(f"mode={mode}") - fields.append(f"val36=0x{val36:x}") - if mode == "delta": - fields.append(f"delta36={delta}") + if mode != "delta": + fields.append(f"val36=0x{val36:x}") return ", ".join(fields) # For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT), @@ -408,7 +416,15 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str: return ", ".join(fields) -def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None: +# 0xb is time something +# 0xd is time something +# 0xf is small time advance +# 0x11 is time advance +# 0x16 is big time advance + markers +# 0x14 is REG +DEFAULT_FILTER = (0xb, 0xd, 0xf, 0x11, 0x16, 0x14) if getenv("FILTER", 1) else None + +def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAULT_FILTER) -> None: """ Minimal debug: print ONE LINE per decoded token (packet). @@ -466,23 +482,17 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) flags |= 0x01 # Common 36-bit field at bits [12..47] - val36 = (reg >> 12) & ((1 << 36) - 1) if (reg & 0x200) == 0: # delta mode: add 36-bit delta to time - delta = val36 + delta = (reg >> 12) & ((1 << 36) - 1) time += delta - note = "0x16-delta" else: # marker / other modes: no time advance - if (reg & 0x100) == 0 and val36 != 0: + if (reg & 0x100) == 0: # real marker: bit9=1, bit8=0, non-zero payload - delta = 0 - note = f"0x16-marker val=0x{val36:x}" - else: # "other" 0x16 variants, ignored for timing delta = 0 - note = "0x16-other" else: # 6) Generic opcode (including 0x0F) shift, width = DELTA_MAP_DEFAULT[opcode] @@ -492,19 +502,13 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) # TODO: add more opcode parsers here that add notes to other opcodes if opcode == 0x0F: delta_with_fix = delta + 4 - note = f"0x0f (+4) raw_delta={delta}" time += delta_with_fix delta = delta_with_fix else: time += delta - # ONE-LINE PRINT PER PACKET - #assert last_real_offset%8 == 0 - #assert (offset)%8 == 0, f"misalign offset {offset}" - # Append extra decoded fields into the note string - extra = decode_packet_fields(opcode, reg, delta) - if extra: note = (note + " ; " + extra) if note else extra + note = decode_packet_fields(opcode, reg, delta) if filter is None or opcode not in filter: my_reg = reg @@ -533,7 +537,7 @@ def parse(fn:str): if __name__ == "__main__": #dat_sqtt = parse("extra/sqtt/examples/profile_empty_run_0.pkl") - dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl") - #dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl") + #dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl") + dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl") blob_0 = dat_sqtt[0].blob parse_sqtt_print_packets(blob_0[8:])