continue work on parse sqtt, enable with SQTT_PARSE (#13425)

* continue work on parse sqtt, enable with SQTT_PARSE

* fix timing

* delta is pre instruction

* hi8 values

* a few more

* a bit more

* let it crash if you enabled it

* figure out simd

* hide 0x11
This commit is contained in:
George Hotz
2025-11-22 19:03:17 -08:00
committed by GitHub
parent 92170d0ff1
commit 5110409339
5 changed files with 113 additions and 84 deletions

View File

@@ -88,6 +88,7 @@ def run_asm(src, num_workgroups=1, num_waves=1):
fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True)
if __name__ == "__main__":
"""
with save_sqtt() as sqtt:
run_asm([
#"s_barrier",
@@ -117,6 +118,7 @@ if __name__ == "__main__":
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v1, 0",
"s_nop 0",
"s_nop 0",
"s_nop 100",
@@ -124,21 +126,21 @@ if __name__ == "__main__":
"s_nop 100",
"s_nop 0",
"s_nop 0",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v3, v0, s[0:1]",
"global_load_b32 v4, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v1, v0, s[0:1]",
"global_load_b32 v5, v0, s[0:1]",
"global_load_b32 v5, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"s_nop 0",
"s_nop 0",
"s_nop 0",
@@ -153,11 +155,12 @@ if __name__ == "__main__":
"s_endpgm",
], num_workgroups=1, num_waves=1)
exit(0)
"""
with save_sqtt() as sqtt:
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()
Tensor.empty(1, 64).sum(axis=1).realize()
#Tensor.empty(1).exp().realize()
#Tensor.empty(1, 64).sum(axis=1).realize()
Tensor.empty(1).log2().realize()
exit(0)
with save_sqtt() as sqtt:

View File

@@ -22,7 +22,7 @@ from extra.sqtt.roc import decode, ProfileSQTTEvent
# NOTE: INST runs before EXEC
GOOD_OPCODE_NAMES = {
OPCODE_NAMES = {
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
0x01: "VALUINST",
# gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT
@@ -39,41 +39,33 @@ GOOD_OPCODE_NAMES = {
0x09: "WAVESTART",
# gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT
0x0D: "PERF",
# pure time
0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate
0x10: "NOP",
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
0x12: "EVENT",
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. NOTE: the timing (delta) decoded for this packet looks broken
0x14: "REG",
# marker
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
# this is the first packet
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
0x18: "INST",
# gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT
0x19: "UTILCTR",
}
OPCODE_NAMES = {
**GOOD_OPCODE_NAMES,
# this is the first (8 byte) packet in the bitstream
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B (reversed)
# ------------------------------------------------------------------------
# 0x07-0x0F: pure timestamp-ish deltas
# ------------------------------------------------------------------------
# pure time (no extra bits)
0x0F: "TS_DELTA_SHORT_PLUS4",
0x10: "NOP",
0x11: "TS_WAVE_STATE_SAMPLE", # almost pure time, has a small flag
# not a good name, but seen and understood mostly
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
# packets we haven't seen / rarely see 0x0b
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
# ------------------------------------------------------------------------
# 0x10-0x19: timestamps, layout headers, events, perf
# ------------------------------------------------------------------------
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
}
# these tables are from rocprof trace decoder
@@ -201,7 +193,8 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
flag = (pkt >> 6) & 1
wave = pkt >> 7
fields.append(f"wave={wave:x}")
fields.append(f"flag={flag:X}")
assert flag == 0, "non 0 flag in 0x1"
#fields.append(f"flag={flag:X}")
case 0x02: # VMEMEXEC
# 2 bit field (pipe is a guess)
fields.append(f"pipe={pkt>>6:X}")
@@ -216,6 +209,11 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
# 16 bit field
# 1 bit per wave
fields.append(f"mask={pkt>>8:16b}")
case 0x6:
# wave ready FFFF00
# 16 bit field
# 1 bit per wave
fields.append(f"mask={pkt>>8:16b}")
case 0x0d:
# 20 bit field
fields.append(f"arg = {pkt>>8:X}")
@@ -226,8 +224,12 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
case 0x19:
# wave end
fields.append(f"ctr = {pkt>>9:X}")
case 0xf:
extracted_delta = (reg >> 4) & 0xF
fields.append(f"strange_delta=0x{extracted_delta:x}")
case 0x11:
# DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta.
# FF0000 is the mask
coarse = pkt >> 16
fields.append(f"coarse=0x{coarse:02x}")
# From decomp:
@@ -237,19 +239,16 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
fields.append("flag_wave_interest=1")
if coarse & 0x08:
fields.append("flag_terminate_all=1")
case 0x6:
# wave ready
fields.append(f"wave = {pkt>>8:X}")
case 0x8:
# wave end, this is 20 bits (FFF00)
flag7 = (pkt >> 8) & 0x3
wgp = (pkt >> 10) & 1
flag7 = (pkt >> 8) & 1
simd = (pkt >> 9) & 3
slot4 = (pkt >> 11) & 0xF
wave = (pkt >> 15) & 0x1f
assert flag7 == 0, "flag7 should be 0"
assert slot4 == 0, "slot4 should be 0"
fields.append(f"wave={wave:x}")
fields.append(f"wgp={wgp}")
fields.append(f"flag7={flag7}")
fields.append(f"slot4={slot4:x}")
fields.append(f"simd={simd}")
case 0x9:
# From case 9 (WAVESTART) in multiple consumers:
# flag7 = (w >> 7) & 1 (low bit of uVar41)
@@ -258,15 +257,15 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
# idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path)
# idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path)
# id7 = (w >> 0x19) & 0x7f (7-bit id)
flag7 = (pkt >> 7) & 3
wgp = (pkt >> 9) & 1
flag7 = (pkt >> 7) & 1
simd = (pkt >> 8) & 2
slot3 = (pkt >> 10) & 0x7 # NOTE: this isn't 4!
wave = (pkt >> 13) & 0x1F
id7 = (pkt >> 17)
assert flag7 == 0, "flag7 should be 0"
assert slot3 == 0, "slot3 should be 0"
fields.append(f"wave={wave:x}")
fields.append(f"flag7={flag7}")
fields.append(f"wgp={wgp}")
fields.append(f"slot3={slot3:x}")
fields.append(f"simd={simd}")
fields.append(f"id7=0x{id7:x}")
case 0x18:
# FFF88 is the mask
@@ -282,8 +281,26 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
wave = (pkt >> 8) & 0x1F
hi8 = (pkt >> 13)
fields.append(f"wave={wave:x}")
fields.append(f"flag={flag:x}")
fields.append(f"flag2={flag2:x}")
assert flag == 0 and flag2 == 0, "non 0 flags in 0x18"
#fields.append(f"flag={flag:x}")
#fields.append(f"flag2={flag2:x}")
# hi8 values:
# SALU = 0x0 / s_mov_b32
# SMEM = 0x1 / s_load_b*
# NEXT = 0x4 / s_cbranch_execz
# MESSAGE = 0x9 / s_sendmsg
# VALU = 0xb / v_(exp,log)_f32_e32
# VALU = 0xd / v_lshlrev_b64
# VMEM = 0x21 / global_load_b32
# VMEM = 0x22 / global_load_b32
# VMEM = 0x24 / global_store_b32
# VMEM = 0x25 / global_store_b64
# LDS = 0x29 / ds_load_b128
# LDS = 0x2b / ds_store_b32
# ???? = 0x5a / hidden global_load instruction
# ???? = 0x5b / hidden global_load instruction
# ???? = 0x5c / hidden global_store instruction
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
fields.append(f"hi8=0x{hi8:x}")
case 0x14:
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
@@ -352,14 +369,14 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
fields.append(f"& {reg_mask(opcode):X}")
return ",".join(fields)
FILTER_LEVEL = getenv("FILTER", 2)
FILTER_LEVEL = getenv("FILTER", 1)
DEFAULT_FILTER = tuple()
# NOP + pure time
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf)
# NOP + pure time + "sample"
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11)
# reg + event + sample + marker
# TODO: events are probably good
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x11, 0x14, 0x16, 0x12)
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16)
# instruction runs
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03)
# instructions dispatch (inst, valuinst, immed)
@@ -401,54 +418,53 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
# 4) Set next nibble budget based on opcode
nib_budget = NIBBLE_BUDGET[opcode & 0x1F]
# 5) Update time and handle special opcodes 0xF/0x16
# 5) Get delta
shift, width = DELTA_MAP_DEFAULT[opcode]
delta = (reg >> shift) & ((1 << width) - 1)
# 6) Update time and handle special opcodes 0xF/0x16
if opcode == 0x16:
two_bits = (reg >> 8) & 0x3
if two_bits == 1:
flags |= 0x01
# Common 36-bit field at bits [12..47]
if (reg & 0x200) == 0:
# delta mode: add 36-bit delta to time
delta = (reg >> 12) & ((1 << 36) - 1)
else:
pass
elif (reg & 0x100) == 0:
# marker / other modes: no time advance
if (reg & 0x100) == 0:
# real marker: bit9=1, bit8=0, non-zero payload
# "other" 0x16 variants, ignored for timing
delta = 0
else:
# 6) Generic opcode (including 0x0F)
shift, width = DELTA_MAP_DEFAULT[opcode]
delta = (reg >> shift) & ((1 << width) - 1)
# real marker: bit9=1, bit8=0, non-zero payload
# "other" 0x16 variants, ignored for timing
delta = 0
else:
raise RuntimeError("unknown 0x16 delta")
elif opcode == 0x0F:
# opcode 0x0F has an offset of 4 to the delta
if opcode == 0x0F:
delta = delta + 4
# update: it's actually computed to be 8 to match WAVESTART
delta = delta + 8
# Append extra decoded fields into the note string
note = decode_packet_fields(opcode, reg)
# this delta happens before the instruction
time += delta
token_index += 1
if verbose and (filter is None or opcode not in filter):
print(
f"{token_index:4d} "
f"time={time:8d}+{delta+(time-last_printed_time):8d} "
f"{time:8d}+{time-last_printed_time:8d} : "
f"op={opcode:2x} "
f"{OPCODE_NAMES[opcode]:24s} "
f"{reg&reg_mask(opcode):16X} "
f"{note}"
)
#f"off={offset//4:5d} "
last_printed_time = time+delta
time += delta
token_index += 1
f"{note}")
last_printed_time = time
# Optional summary at the end
print(f"# done: tokens={token_index:_}, final_time={time}, flags=0x{flags:02x}")
if verbose:
print(f"opcodes({len(opcodes_seen):2d}):", ' '.join([colored(f"{op:2X}", "white" if op in GOOD_OPCODE_NAMES else "red") for op in opcodes_seen]))
print(f"opcodes({len(opcodes_seen):2d}):",
' '.join([colored(f"{op:2X}", "WHITE" if op in opcodes_seen else "BLACK") for op in sorted(opcode_mask)]))
def parse(fn:str):

View File

@@ -72,10 +72,10 @@ class _ROCParseCtx:
return self.active_blob
def on_occupancy_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_occupancy_t):
if DEBUG >= 5: print("OCC", ev.time, self.active_se, ev.cu, ev.simd, ev.wave_id, ev.start)
if DEBUG >= 5: print(f"OCC {ev.time=} {self.active_se=} {ev.cu=} {ev.simd=} {ev.wave_id=} {ev.start=}")
def on_wave_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_wave_t):
if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time)
if DEBUG >= 5: print(f"WAVE {ev.wave_id=} {self.active_se=} {ev.cu=} {ev.simd=} {ev.contexts=} {ev.begin_time=} {ev.end_time=}")
insts_blob = bytearray(sz:=ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t))
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
@@ -109,6 +109,10 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
for ev in (rocprof.rocprofiler_thread_trace_decoder_occupancy_t * n).from_address(events_ptr): ROCParseCtx.on_occupancy_ev(ev)
case rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_WAVE:
for ev in (rocprof.rocprofiler_thread_trace_decoder_wave_t * n).from_address(events_ptr): ROCParseCtx.on_wave_ev(ev)
case rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_REALTIME:
if DEBUG >= 5:
pairs = [(ev.shader_clock, ev.realtime_clock) for ev in (rocprof.rocprofiler_thread_trace_decoder_realtime_t * n).from_address(events_ptr)]
print(f"REALTIME {pairs}")
case _:
if DEBUG >= 5: print(rocprof.enum_rocprofiler_thread_trace_decoder_record_type_t.get(record_type), events_ptr, n)
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS

View File

@@ -199,7 +199,10 @@ class AMDComputeQueue(HWQueue):
for xcc in range(self.dev.xccs):
with self.pred_exec(xcc_mask=1 << xcc):
for i in range(8 if prg.dev.target >= (11,0,0) else 4):
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), min(0xffffffff, (1 << (se_cap + (1 if i == 0 else 0))) - 1))
if SQTT_LIMIT_SE > 1: # only run unmasked shader engines
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), 1 if SQTT_ITRACE_SE_MASK.value & (1 << i) else 0)
else:
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), min(0xffffffff, (1 << (se_cap + (1 if i == 0 else 0))) - 1))
def sqtt_userdata(self, data, *extra_dwords):
data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)

View File

@@ -222,6 +222,9 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
except Exception: return err("DECODER IMPORT ISSUE")
try: rctx = decode(profile)
except Exception: return err("DECODER ERROR")
if getenv("SQTT_PARSE"):
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
for e in sqtt_events: parse_sqtt_print_packets(e.blob)
if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
steps:list[dict] = []
for name,waves in rctx.inst_execs.items():