import pickle, sys from tinygrad.helpers import getenv, Timing, colored from extra.sqtt.roc import decode, ProfileSQTTEvent # do these enums match fields in the packets? #from tinygrad.runtime.support.amd import import_soc #soc = import_soc([11]) #perf_sel = {getattr(soc, k):k for k in dir(soc) if k.startswith("SQ_PERF_")} # Instruction packets (one per ISA op) # NOTE: these are bad guesses and may be wrong! feel free to update if you know better # some names were taken from SQ_TT_TOKEN_MASK_TOKEN_EXCLUDE_SHIFT # we see 18 opcodes # opcodes(18): 1 2 3 4 5 6 8 9 F 10 11 12 14 15 16 17 18 19 # if you exclude everything, you are left with 6 # opcodes( 6): 10 11 14 15 16 17 # sometimes we see a lot of B, but not repeatable # not seen # 7 A C # NOTE: INST runs before EXEC OPCODE_COLORS = { # dispatches are BLACK 0x1: "BLACK", 0x18: "BLACK", # execs are yellow 0x2: "yellow", 0x3: "yellow", 0x4: "YELLOW", 0x5: "YELLOW", # waves are blue 0x8: "blue", 0x9: "blue", 0x6: "cyan", 0xb: "cyan", } OPCODE_NAMES = { # gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show) 0x01: "VALUINST", # gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT 0x02: "VMEMEXEC", # gated by SQ_TT_TOKEN_EXCLUDE_ALUEXEC_SHIFT 0x03: "ALUEXEC", # gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT 0x04: "IMMEDIATE", 0x05: "IMMEDIATE_MASK", # gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT 0x06: "WAVERDY", # gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT 0x08: "WAVEEND", 0x09: "WAVESTART", # gated by SQ_TT_TOKEN_EXCLUDE_WAVEALLOC_SHIFT 0x0B: "WAVEALLOC", # FFF00 # gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT 0x0D: "PERF", # gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT 0x12: "EVENT", 0x13: "EVENT_BIG", # FFFFF800 # some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. something is broken with the timing on this 0x14: "REG", # gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT 0x18: "INST", # gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT 0x19: "UTILCTR", # this is the first (8 byte) packet in the bitstream 0x17: "LAYOUT_HEADER", # layout/mode/group + selectors A/B (reversed) # pure time (no extra bits) 0x0F: "TS_DELTA_SHORT", 0x10: "NOP", 0x11: "TS_WAVE_STATE", # almost pure time, has a small flag # not a good name, but seen and understood mostly 0x15: "SNAPSHOT", # small delta + 50-ish bits of snapshot 0x16: "TS_DELTA_OR_MARK", # 36-bit long delta or 36-bit marker # packets we haven't seen / rarely see 0x0b 0x07: "TS_DELTA_S8_W3_7", # shift=8, width=3 (small delta) 0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2 0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer) } # SALU = 0x0 / s_mov_b32 # SMEM = 0x1 / s_load_b* # JUMP = 0x3 / s_cbranch_scc0 # NEXT = 0x4 / s_cbranch_execz # MESSAGE = 0x9 / s_sendmsg # VALU = 0xb / v_(exp,log)_f32_e32 # VALU = 0xd / v_lshlrev_b64 # VALU = 0xe / v_mad_u64_u32 # VMEM = 0x21 / global_load_b32 # VMEM = 0x22 / global_load_b32 # VMEM = 0x24 / global_store_b32 # VMEM = 0x25 / global_store_b64 # VMEM = 0x27 / global_store # VMEM = 0x28 / global_store_b64 # LDS = 0x29 / ds_load_b128 # LDS = 0x2b / ds_store_b32 # LDS = 0x2e / ds_store_b128 # ???? = 0x5a / hidden global_load instruction # ???? = 0x5b / hidden global_load instruction # ???? = 0x5c / hidden global_store instruction # VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST) OPNAME = { 0x0: "SALU", 0x1: "SMEM", 0x3: "JUMP", 0x4: "NEXT", 0x9: "MESSAGE", 0xb: "VALU", 0xd: "VALU", 0xe: "VALU", 0x21: "VMEM_LOAD", 0x22: "VMEM_LOAD", 0x24: "VMEM_STORE", 0x25: "VMEM_STORE", 0x26: "VMEM_STORE", 0x27: "VMEM_STORE", 0x28: "VMEM_STORE", 0x29: "LDS_LOAD", 0x2b: "LDS_STORE", 0x2e: "LDS_STORE", 0x50: "__SIMD_LDS_LOAD", 0x51: "__SIMD_LDS_LOAD", 0x54: "__SIMD_LDS_STORE", 0x5a: "__SIMD_VMEM_LOAD", 0x5b: "__SIMD_VMEM_LOAD", 0x5c: "__SIMD_VMEM_STORE", 0x5d: "__SIMD_VMEM_STORE", 0x5e: "__SIMD_VMEM_STORE", 0x5f: "__SIMD_VMEM_STORE", 0x72: "SALU_OR", 0x73: "VALU_CMPX", } ALUSRC = { 1: "SALU", 2: "VALU", 3: "VALU_ALT", } MEMSRC = { 0: "LDS", 1: "__LDS", 2: "VMEM", 3: "__VMEM", } # these tables are from rocprof trace decoder # rocprof_trace_decoder_parse_data-0x11c6a0 # parse_sqtt_180 = b *rocprof_trace_decoder_parse_data-0x11c6a0+0x110040 # ---------- 1. local_138: 256-byte state->opcode table ---------- STATE_TO_OPCODE: bytes = bytes([ 0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x12, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x16, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x17, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x07, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x19, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x00, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x11, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, 0x10, 0x13, 0x18, 0x01, 0x05, 0x0b, 0x0c, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x09, 0x04, 0x03, 0x02, 0x10, 0x15, 0x18, 0x01, 0x06, 0x08, 0x0d, 0x00, 0x0f, 0x14, 0x18, 0x01, 0x0a, 0x04, 0x03, 0x02, ]) # opcode mask (the bits used to determine the opcode, worked out by looking at the repeats in STATE_TO_OPCODE) opcode_mask = { 0x10: 0b1111, 0x16: 0b1111111, 0x17: 0b1111111, 0x07: 0b1111111, 0x19: 0b1111111, 0x11: 0b1111111, 0x12: 0b11111111, 0x13: 0b11111111, 0x15: 0b1111111, 0x18: 0b111, 0x1: 0b111, 0x5: 0b11111, 0x6: 0b11111, 0xb: 0b11111, 0x8: 0b11111, 0xc: 0b11111, 0xd: 0b11111, 0xf: 0b1111, 0x14: 0b1111, 0x9: 0b11111, 0xa: 0b11111, 0x4: 0b1111, 0x3: 0b1111, 0x2: 0b1111, } # ---------- 2. DAT_0012e280: nibble budget per opcode&0x1F ---------- NIBBLE_BUDGET = [ 0x08, 0x0C, 0x08, 0x08, 0x0C, 0x18, 0x18, 0x40, 0x14, 0x20, 0x30, 0x14, 0x34, 0x1C, 0x30, 0x08, 0x04, 0x18, 0x18, 0x20, 0x40, 0x40, 0x30, 0x40, 0x14, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ] # ---------- 3. delta_map from your hash nodes ---------- # opcode -> (shift, width) DELTA_MAP_DEFAULT = { 0x01: (3, 3), # shift=3, end=6 0x02: (4, 2), # shift=4, end=6 0x03: (4, 2), # shift=4, end=6 0x04: (4, 3), # shift=4, end=7 0x05: (5, 3), # shift=5, end=8 0x06: (5, 3), # shift=5, end=8 0x07: (8, 3), # shift=8, end=11 0x08: (5, 3), # shift=5, end=8 0x09: (5, 2), # shift=5, end=7 0x0A: (5, 2), # shift=5, end=7 0x0B: (5, 3), # shift=5, end=8 0x0C: (5, 3), # shift=5, end=8 0x0D: (5, 3), # shift=5, end=8 # NOTE: 0x0e can never be decoded, it's not in the STATE_TO_OPCODE table #0x0E: (7, 2), # shift=7, end=9 0x0F: (4, 4), # shift=4, end=8 0x10: (0, 0), # shift=0, end=0 (no delta) 0x11: (7, 9), # shift=7, end=16 0x12: (8, 3), # shift=8, end=11 0x13: (8, 3), # shift=8, end=11 0x14: (4, 3), # shift=4, end=7 0x15: (7, 3), # shift=7, end=10 0x16: (12, 36), # shift=12, end=48 (36-bit field, matches the 0x16 special-case) 0x17: (0, 0), # shift=0, end=0 (no delta) 0x18: (4, 3), # shift=4, end=7 0x19: (7, 2), # shift=7, end=9 } # ---------- 4. One-line-per-packet parser ---------- def reg_mask(opcode): nb_bits = NIBBLE_BUDGET[opcode & 0x1F] shift, width = DELTA_MAP_DEFAULT[opcode] delta_mask = ((1 << width) - 1) << shift assert delta_mask & opcode_mask[opcode] == 0, "masks shouldn't overlap" return ((1 << nb_bits) - 1) & ~(delta_mask | opcode_mask[opcode]) def decode_packet_fields(opcode: int, reg: int) -> str: """ Decode packet payloads conservatively, using: - NIBBLE_BUDGET[opcode & 0x1F] to mask reg down to true width. - DELTA_MAP_DEFAULT[opcode] to expose the "primary" field (often delta). - Per-opcode layouts derived from rocprof's decompiled consumers. """ # --- 0. Restrict to real packet bits not used in delta --------------------------------- pkt = reg & reg_mask(opcode) fields: list[str] = [] match opcode: case 0x01: # VALUINST # 6 bit field flag = (pkt >> 6) & 1 wave = pkt >> 7 fields.append(f"wave={wave:x}") if flag: fields.append("flag") case 0x02: # VMEMEXEC # 2 bit field (pipe is a guess) src = pkt>>6 fields.append(f"src={src} [{MEMSRC.get(src, '')}]") case 0x03: # ALUEXEC # 2 bit field src = pkt>>6 fields.append(f"src={src} [{ALUSRC.get(src, '')}]") case 0x04: # IMMEDIATE_4 # 5 bit field (actually 4) wave = pkt >> 7 fields.append(f"wave={wave:x}") case 0x05: # IMMEDIATE_5 # 16 bit field # 1 bit per wave fields.append(f"mask={pkt>>8:016b}") case 0x6: # wave ready FFFF00 # 16 bit field # 1 bit per wave fields.append(f"mask={pkt>>8:016b}") case 0x0d: # 20 bit field fields.append(f"arg = {pkt>>8:X}") case 0x12: fields.append(f"event = {pkt>>11:X}") case 0x15: fields.append(f"snap = {pkt>>10:X}") case 0x19: # wave end fields.append(f"ctr = {pkt>>9:X}") case 0xf: extracted_delta = (reg >> 4) & 0xF fields.append(f"strange_delta=0x{extracted_delta:x}") case 0x11: # DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta. # FF0000 is the mask coarse = pkt >> 16 fields.append(f"coarse=0x{coarse:02x}") # From decomp: # - when layout<3 and coarse&1, it sets a "has interesting wave" flag # - when coarse&8, it marks all live waves as "terminated" if coarse & 0x01: fields.append("flag_wave_interest=1") if coarse & 0x08: fields.append("flag_terminate_all=1") case 0x8: # wave end, this is 20 bits (FFF00) flag7 = (pkt >> 8) & 1 simd = (pkt >> 9) & 3 cu = ((pkt >> 11) & 0x7) | (flag7 << 3) wave = (pkt >> 15) & 0x1f fields.append(f"wave={wave:x}") fields.append(f"simd={simd}") fields.append(f"cu={cu}") case 0x9: # From case 9 (WAVESTART) in multiple consumers: # flag7 = (w >> 7) & 1 (low bit of uVar41) # cls2 = (w >> 8) & 3 (class / group) # slot4 = (w >> 10) & 0xf (slot / group index) # idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path) # idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path) # id7 = (w >> 0x19) & 0x7f (7-bit id) flag7 = (pkt >> 7) & 1 simd = (pkt >> 8) & 3 cu = ((pkt >> 10) & 0x7) | (flag7 << 3) wave = (pkt >> 13) & 0x1F id7 = (pkt >> 17) fields.append(f"wave={wave:x}") fields.append(f"simd={simd}") fields.append(f"cu={cu}") fields.append(f"id7=0x{id7:x}") case 0x18: # FFF88 is the mask # From case 0x18: # low3 = w & 7 # grp3 = (w >> 3) or (w >> 4) & 7 (layout-dependent) # flags = bits 6 (B6) and 7 (B7) # hi8 = (w >> 0xc) & 0xff (layout 4 path) # hi7 = (w >> 0xd) & 0x7f (other layouts) # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index flag1 = (pkt >> 3) & 1 flag2 = (pkt >> 7) & 1 wave = (pkt >> 8) & 0x1F op = (pkt >> 13) fields.append(f"wave={wave:x}") fields.append(f"op=0x{op:02x} [{OPNAME.get(op, '')}]") if flag1: fields.append("flag1") if flag2: fields.append("flag2") case 0x14: subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10) val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20) slot = (pkt >> 7) & 0x7 # index in local_168[...] tables hi_byte = (pkt >> 8) & 0xFF # determines config vs marker fields.append(f"subop=0x{subop:04x}") fields.append(f"slot={slot}") fields.append(f"val32=0x{val32:08x}") if hi_byte & 0x80: # Config flavour: writes config words into per-slot state arrays. fields.append("kind=config") if subop == 0x000C: fields.append("slot=lo") elif subop == 0x000D: fields.append("slot=hi") else: # COR marker: subop 0xC342, payload "COR\0" → start of a COR region. if subop == 0xC342: fields.append("kind=cor_stream") if val32 == 0x434F5200: fields.append("cor_magic='COR\\0'") case 0x16: # Bits: # bit8 -> 0x100 # bit9 -> 0x200 # bits 12..47 -> 36-bit field used as delta or marker bit8 = bool(pkt & 0x100) bit9 = bool(pkt & 0x200) if not bit9: mode = "delta" elif not bit8: mode = "marker" else: mode = "other" # need to use reg here val36 = (reg >> 12) & ((1 << 36) - 1) fields.append(f"mode={mode}") if mode != "delta": fields.append(f"val36=0x{val36:x}") case 0x17: # From decomp (two sites with identical logic): # layout = (w >> 7) & 0x3f # mode = (w >> 0xd) & 3 # group = (w >> 0xf) & 7 # sel_a = (w >> 0x1c) & 0xf # sel_b = (w >> 0x21) & 7 # flag4 = (w >> 0x3b) & 1 (only meaningful when layout == 4) layout = (pkt >> 7) & 0x3F simd = (pkt >> 13) & 0x3 # you can change this by changing traced simd group = (pkt >> 15) & 0x7 sel_a = (pkt >> 0x1C) & 0xF sel_b = (pkt >> 0x21) & 0x7 flag4 = (pkt >> 0x3B) & 0x1 fields.append(f"layout={layout}") fields.append(f"group={group}") fields.append(f"simd={simd}") fields.append(f"sel_a={sel_a}") fields.append(f"sel_b={sel_b}") if layout == 4: fields.append(f"layout4_flag={flag4}") case _: fields.append(f"{pkt:X} & {reg_mask(opcode):X}") return ",".join(fields) FILTER_LEVEL = getenv("FILTER", 1) DEFAULT_FILTER: tuple[int, ...] = tuple() # NOP + pure time + "sample" if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11) # reg + event + sample + marker # TODO: events are probably good if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16) # instruction runs + valuinst if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x01, 0x02, 0x03) # instructions dispatch (inst, immed) if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x4, 0x5, 0x18) # waves if FILTER_LEVEL >= 4: DEFAULT_FILTER += (0x6, 0x8, 0x9) def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -> None: """ Minimal debug: print ONE LINE per decoded token (packet). Now prints only the actual nibbles that belong to each packet, instead of the full 64-bit shift register. """ n = len(data) time = 0 last_printed_time = 0 reg = 0 # shift register offset = 0 # bit offset, in steps of 4 (one nibble) nib_budget = 0x40 flags = 0 token_index = 0 opcodes_seen = set() while (offset >> 3) < n: # 1) Fill register with nibbles according to nib_budget if nib_budget != 0: target = offset + 4 + ((nib_budget - 1) & ~3) while offset != target and (offset >> 3) < n: byte = data[offset >> 3] nib = (byte >> (offset & 4)) & 0xF reg = ((reg >> 4) | (nib << 60)) & ((1 << 64) - 1) offset += 4 if offset != target: break # don't parse past the end # 2) Decode token from low 8 bits opcode = STATE_TO_OPCODE[reg & 0xFF] opcodes_seen.add(opcode) # 4) Set next nibble budget based on opcode nib_budget = NIBBLE_BUDGET[opcode & 0x1F] # 5) Get delta shift, width = DELTA_MAP_DEFAULT[opcode] delta = (reg >> shift) & ((1 << width) - 1) # 6) Update time and handle special opcodes 0xF/0x16 if opcode == 0x16: two_bits = (reg >> 8) & 0x3 if two_bits == 1: flags |= 0x01 # Common 36-bit field at bits [12..47] if (reg & 0x200) == 0: # delta mode: add 36-bit delta to time pass elif (reg & 0x100) == 0: # marker / other modes: no time advance # real marker: bit9=1, bit8=0, non-zero payload # "other" 0x16 variants, ignored for timing delta = 0 else: raise RuntimeError("unknown 0x16 delta") elif opcode == 0x0F: # opcode 0x0F has an offset of 4 to the delta # update: it's actually computed to be 8 to match WAVESTART delta = delta + 8 # Append extra decoded fields into the note string note = decode_packet_fields(opcode, reg) # this delta happens before the instruction time += delta token_index += 1 if verbose and (filter is None or opcode not in filter): print(f"{time:8d} +{time-last_printed_time:8d} : "+colored(f"{OPCODE_NAMES[opcode]:18s} ", OPCODE_COLORS.get(opcode, "white"))+f"{note}") last_printed_time = time # Optional summary at the end print(f"# done: tokens={token_index:_}, final_time={time}, flags=0x{flags:02x}") if verbose: print(f"opcodes({len(opcodes_seen):2d}):", ' '.join([colored(f"{op:2X}", "WHITE" if op in opcodes_seen else "BLACK") for op in sorted(opcode_mask)])) def parse(fn:str): with Timing(f"unpickle {fn}: "): dat = pickle.load(open(fn, "rb")) #if getenv("ROCM", 0): # with Timing(f"decode {fn}: "): ctx = decode(dat) dat_sqtt = [x for x in dat if isinstance(x, ProfileSQTTEvent)] print(f"got {len(dat_sqtt)} SQTT events in {fn}") return dat_sqtt if __name__ == "__main__": fn = "extra/sqtt/examples/profile_gemm_run_0.pkl" dat_sqtt = parse(sys.argv[1] if len(sys.argv) > 1 else fn) for i,dat in enumerate(dat_sqtt): with Timing(f"decode pkt {i} with len {len(dat.blob):_}: "): parse_sqtt_print_packets(dat.blob, verbose=getenv("V", 1))