mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
continue work on parse sqtt, enable with SQTT_PARSE (#13425)
* continue work on parse sqtt, enable with SQTT_PARSE
* fix timing
* delta is pre instruction
* hi8 values
* a few more
* a bit more
* let it crash if you enabled it
* figure out simd
* hide 0x11
This commit is contained in:
@@ -88,6 +88,7 @@ def run_asm(src, num_workgroups=1, num_waves=1):
|
||||
fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
with save_sqtt() as sqtt:
|
||||
run_asm([
|
||||
#"s_barrier",
|
||||
@@ -117,6 +118,7 @@ if __name__ == "__main__":
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v1, 0",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 100",
|
||||
@@ -124,21 +126,21 @@ if __name__ == "__main__":
|
||||
"s_nop 100",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v3, v0, s[0:1]",
|
||||
"global_load_b32 v4, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
"global_load_b32 v5, v0, s[0:1]",
|
||||
"global_load_b32 v5, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
@@ -153,11 +155,12 @@ if __name__ == "__main__":
|
||||
"s_endpgm",
|
||||
], num_workgroups=1, num_waves=1)
|
||||
exit(0)
|
||||
"""
|
||||
|
||||
with save_sqtt() as sqtt:
|
||||
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()
|
||||
Tensor.empty(1, 64).sum(axis=1).realize()
|
||||
#Tensor.empty(1).exp().realize()
|
||||
#Tensor.empty(1, 64).sum(axis=1).realize()
|
||||
Tensor.empty(1).log2().realize()
|
||||
exit(0)
|
||||
|
||||
with save_sqtt() as sqtt:
|
||||
|
||||
@@ -22,7 +22,7 @@ from extra.sqtt.roc import decode, ProfileSQTTEvent
|
||||
|
||||
# NOTE: INST runs before EXEC
|
||||
|
||||
GOOD_OPCODE_NAMES = {
|
||||
OPCODE_NAMES = {
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
|
||||
0x01: "VALUINST",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT
|
||||
@@ -39,41 +39,33 @@ GOOD_OPCODE_NAMES = {
|
||||
0x09: "WAVESTART",
|
||||
# gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT
|
||||
0x0D: "PERF",
|
||||
# pure time
|
||||
0x0F: "TS_DELTA_SHORT_PLUS4", # short delta; ROCm adds +4 before accumulate
|
||||
0x10: "NOP",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
|
||||
0x12: "EVENT",
|
||||
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there
|
||||
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. something is broken with the timing on this
|
||||
0x14: "REG",
|
||||
# marker
|
||||
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
# this is the first packet
|
||||
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
|
||||
0x18: "INST",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT
|
||||
0x19: "UTILCTR",
|
||||
}
|
||||
|
||||
OPCODE_NAMES = {
|
||||
**GOOD_OPCODE_NAMES,
|
||||
# this is the first (8 byte) packet in the bitstream
|
||||
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B (reversed)
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x07–0x0F: pure timestamp-ish deltas
|
||||
# ------------------------------------------------------------------------
|
||||
# pure time (no extra bits)
|
||||
0x0F: "TS_DELTA_SHORT_PLUS4",
|
||||
0x10: "NOP",
|
||||
0x11: "TS_WAVE_STATE_SAMPLE", # almost pure time, has a small flag
|
||||
|
||||
# not a good name, but seen and understood mostly
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
|
||||
# packets we haven't seen / rarely see 0x0b
|
||||
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
|
||||
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
|
||||
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
|
||||
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x10–0x19: timestamps, layout headers, events, perf
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
|
||||
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
}
|
||||
|
||||
# these tables are from rocprof trace decoder
|
||||
@@ -201,7 +193,8 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
flag = (pkt >> 6) & 1
|
||||
wave = pkt >> 7
|
||||
fields.append(f"wave={wave:x}")
|
||||
fields.append(f"flag={flag:X}")
|
||||
assert flag == 0, "non 0 flag in 0x1"
|
||||
#fields.append(f"flag={flag:X}")
|
||||
case 0x02: # VMEMEXEC
|
||||
# 2 bit field (pipe is a guess)
|
||||
fields.append(f"pipe={pkt>>6:X}")
|
||||
@@ -216,6 +209,11 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
# 16 bit field
|
||||
# 1 bit per wave
|
||||
fields.append(f"mask={pkt>>8:16b}")
|
||||
case 0x6:
|
||||
# wave ready FFFF00
|
||||
# 16 bit field
|
||||
# 1 bit per wave
|
||||
fields.append(f"mask={pkt>>8:16b}")
|
||||
case 0x0d:
|
||||
# 20 bit field
|
||||
fields.append(f"arg = {pkt>>8:X}")
|
||||
@@ -226,8 +224,12 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
case 0x19:
|
||||
# wave end
|
||||
fields.append(f"ctr = {pkt>>9:X}")
|
||||
case 0xf:
|
||||
extracted_delta = (reg >> 4) & 0xF
|
||||
fields.append(f"strange_delta=0x{extracted_delta:x}")
|
||||
case 0x11:
|
||||
# DELTA_MAP_DEFAULT: shift=7, width=9 -> small delta.
|
||||
# FF0000 is the mask
|
||||
coarse = pkt >> 16
|
||||
fields.append(f"coarse=0x{coarse:02x}")
|
||||
# From decomp:
|
||||
@@ -237,19 +239,16 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
fields.append("flag_wave_interest=1")
|
||||
if coarse & 0x08:
|
||||
fields.append("flag_terminate_all=1")
|
||||
case 0x6:
|
||||
# wave ready
|
||||
fields.append(f"wave = {pkt>>8:X}")
|
||||
case 0x8:
|
||||
# wave end, this is 20 bits (FFF00)
|
||||
flag7 = (pkt >> 8) & 0x3
|
||||
wgp = (pkt >> 10) & 1
|
||||
flag7 = (pkt >> 8) & 1
|
||||
simd = (pkt >> 9) & 3
|
||||
slot4 = (pkt >> 11) & 0xF
|
||||
wave = (pkt >> 15) & 0x1f
|
||||
assert flag7 == 0, "flag7 should be 0"
|
||||
assert slot4 == 0, "slot4 should be 0"
|
||||
fields.append(f"wave={wave:x}")
|
||||
fields.append(f"wgp={wgp}")
|
||||
fields.append(f"flag7={flag7}")
|
||||
fields.append(f"slot4={slot4:x}")
|
||||
fields.append(f"simd={simd}")
|
||||
case 0x9:
|
||||
# From case 9 (WAVESTART) in multiple consumers:
|
||||
# flag7 = (w >> 7) & 1 (low bit of uVar41)
|
||||
@@ -258,15 +257,15 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
# idx_lo = (w >> 0xd) & 0x1f (low index, layout<4 path)
|
||||
# idx_hi = (w >> 0xf) & 0x1f (high index, layout>=4 path)
|
||||
# id7 = (w >> 0x19) & 0x7f (7-bit id)
|
||||
flag7 = (pkt >> 7) & 3
|
||||
wgp = (pkt >> 9) & 1
|
||||
flag7 = (pkt >> 7) & 1
|
||||
simd = (pkt >> 8) & 2
|
||||
slot3 = (pkt >> 10) & 0x7 # NOTE: this isn't 4!
|
||||
wave = (pkt >> 13) & 0x1F
|
||||
id7 = (pkt >> 17)
|
||||
assert flag7 == 0, "flag7 should be 0"
|
||||
assert slot3 == 0, "slot3 should be 0"
|
||||
fields.append(f"wave={wave:x}")
|
||||
fields.append(f"flag7={flag7}")
|
||||
fields.append(f"wgp={wgp}")
|
||||
fields.append(f"slot3={slot3:x}")
|
||||
fields.append(f"simd={simd}")
|
||||
fields.append(f"id7=0x{id7:x}")
|
||||
case 0x18:
|
||||
# FFF88 is the mask
|
||||
@@ -282,8 +281,26 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
wave = (pkt >> 8) & 0x1F
|
||||
hi8 = (pkt >> 13)
|
||||
fields.append(f"wave={wave:x}")
|
||||
fields.append(f"flag={flag:x}")
|
||||
fields.append(f"flag2={flag2:x}")
|
||||
assert flag == 0 and flag2 == 0, "non 0 flags in 0x18"
|
||||
#fields.append(f"flag={flag:x}")
|
||||
#fields.append(f"flag2={flag2:x}")
|
||||
# hi8 values:
|
||||
# SALU = 0x0 / s_mov_b32
|
||||
# SMEM = 0x1 / s_load_b*
|
||||
# NEXT = 0x4 / s_cbranch_execz
|
||||
# MESSAGE = 0x9 / s_sendmsg
|
||||
# VALU = 0xb / v_(exp,log)_f32_e32
|
||||
# VALU = 0xd / v_lshlrev_b64
|
||||
# VMEM = 0x21 / global_load_b32
|
||||
# VMEM = 0x22 / global_load_b32
|
||||
# VMEM = 0x24 / global_store_b32
|
||||
# VMEM = 0x25 / global_store_b64
|
||||
# LDS = 0x29 / ds_load_b128
|
||||
# LDS = 0x2b / ds_store_b32
|
||||
# ???? = 0x5a / hidden global_load instruction
|
||||
# ???? = 0x5b / hidden global_load instruction
|
||||
# ???? = 0x5c / hidden global_store instruction
|
||||
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
|
||||
fields.append(f"hi8=0x{hi8:x}")
|
||||
case 0x14:
|
||||
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
|
||||
@@ -352,14 +369,14 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
fields.append(f"& {reg_mask(opcode):X}")
|
||||
return ",".join(fields)
|
||||
|
||||
FILTER_LEVEL = getenv("FILTER", 2)
|
||||
FILTER_LEVEL = getenv("FILTER", 1)
|
||||
|
||||
DEFAULT_FILTER = tuple()
|
||||
# NOP + pure time
|
||||
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf)
|
||||
# NOP + pure time + "sample"
|
||||
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11)
|
||||
# reg + event + sample + marker
|
||||
# TODO: events are probably good
|
||||
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x11, 0x14, 0x16, 0x12)
|
||||
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16)
|
||||
# instruction runs
|
||||
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03)
|
||||
# instructions dispatch (inst, valuinst, immed)
|
||||
@@ -401,54 +418,53 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
|
||||
# 4) Set next nibble budget based on opcode
|
||||
nib_budget = NIBBLE_BUDGET[opcode & 0x1F]
|
||||
|
||||
# 5) Update time and handle special opcodes 0xF/0x16
|
||||
# 5) Get delta
|
||||
shift, width = DELTA_MAP_DEFAULT[opcode]
|
||||
delta = (reg >> shift) & ((1 << width) - 1)
|
||||
|
||||
# 6) Update time and handle special opcodes 0xF/0x16
|
||||
if opcode == 0x16:
|
||||
two_bits = (reg >> 8) & 0x3
|
||||
if two_bits == 1:
|
||||
flags |= 0x01
|
||||
|
||||
# Common 36-bit field at bits [12..47]
|
||||
|
||||
if (reg & 0x200) == 0:
|
||||
# delta mode: add 36-bit delta to time
|
||||
delta = (reg >> 12) & ((1 << 36) - 1)
|
||||
else:
|
||||
pass
|
||||
elif (reg & 0x100) == 0:
|
||||
# marker / other modes: no time advance
|
||||
if (reg & 0x100) == 0:
|
||||
# real marker: bit9=1, bit8=0, non-zero payload
|
||||
# "other" 0x16 variants, ignored for timing
|
||||
delta = 0
|
||||
else:
|
||||
# 6) Generic opcode (including 0x0F)
|
||||
shift, width = DELTA_MAP_DEFAULT[opcode]
|
||||
delta = (reg >> shift) & ((1 << width) - 1)
|
||||
|
||||
# real marker: bit9=1, bit8=0, non-zero payload
|
||||
# "other" 0x16 variants, ignored for timing
|
||||
delta = 0
|
||||
else:
|
||||
raise RuntimeError("unknown 0x16 delta")
|
||||
elif opcode == 0x0F:
|
||||
# opcode 0x0F has an offset of 4 to the delta
|
||||
if opcode == 0x0F:
|
||||
delta = delta + 4
|
||||
# update: it's actually computed to be 8 to match WAVESTART
|
||||
delta = delta + 8
|
||||
|
||||
# Append extra decoded fields into the note string
|
||||
note = decode_packet_fields(opcode, reg)
|
||||
|
||||
# this delta happens before the instruction
|
||||
time += delta
|
||||
token_index += 1
|
||||
|
||||
if verbose and (filter is None or opcode not in filter):
|
||||
print(
|
||||
f"{token_index:4d} "
|
||||
f"time={time:8d}+{delta+(time-last_printed_time):8d} "
|
||||
f"{time:8d}+{time-last_printed_time:8d} : "
|
||||
f"op={opcode:2x} "
|
||||
f"{OPCODE_NAMES[opcode]:24s} "
|
||||
f"{reg&reg_mask(opcode):16X} "
|
||||
f"{note}"
|
||||
)
|
||||
#f"off={offset//4:5d} "
|
||||
last_printed_time = time+delta
|
||||
|
||||
time += delta
|
||||
token_index += 1
|
||||
f"{note}")
|
||||
last_printed_time = time
|
||||
|
||||
# Optional summary at the end
|
||||
print(f"# done: tokens={token_index:_}, final_time={time}, flags=0x{flags:02x}")
|
||||
if verbose:
|
||||
print(f"opcodes({len(opcodes_seen):2d}):", ' '.join([colored(f"{op:2X}", "white" if op in GOOD_OPCODE_NAMES else "red") for op in opcodes_seen]))
|
||||
print(f"opcodes({len(opcodes_seen):2d}):",
|
||||
' '.join([colored(f"{op:2X}", "WHITE" if op in opcodes_seen else "BLACK") for op in sorted(opcode_mask)]))
|
||||
|
||||
|
||||
def parse(fn:str):
|
||||
|
||||
@@ -72,10 +72,10 @@ class _ROCParseCtx:
|
||||
return self.active_blob
|
||||
|
||||
def on_occupancy_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_occupancy_t):
|
||||
if DEBUG >= 5: print("OCC", ev.time, self.active_se, ev.cu, ev.simd, ev.wave_id, ev.start)
|
||||
if DEBUG >= 5: print(f"OCC {ev.time=} {self.active_se=} {ev.cu=} {ev.simd=} {ev.wave_id=} {ev.start=}")
|
||||
|
||||
def on_wave_ev(self, ev:rocprof.rocprofiler_thread_trace_decoder_wave_t):
|
||||
if DEBUG >= 5: print("WAVE", ev.wave_id, self.active_se, ev.cu, ev.simd, ev.contexts, ev.begin_time, ev.end_time)
|
||||
if DEBUG >= 5: print(f"WAVE {ev.wave_id=} {self.active_se=} {ev.cu=} {ev.simd=} {ev.contexts=} {ev.begin_time=} {ev.end_time=}")
|
||||
|
||||
insts_blob = bytearray(sz:=ev.instructions_size * ctypes.sizeof(rocprof.rocprofiler_thread_trace_decoder_inst_t))
|
||||
ctypes.memmove((ctypes.c_char * sz).from_buffer(insts_blob), ev.instructions_array, sz)
|
||||
@@ -109,6 +109,10 @@ def decode(profile:list[ProfileEvent]) -> _ROCParseCtx:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_occupancy_t * n).from_address(events_ptr): ROCParseCtx.on_occupancy_ev(ev)
|
||||
case rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_WAVE:
|
||||
for ev in (rocprof.rocprofiler_thread_trace_decoder_wave_t * n).from_address(events_ptr): ROCParseCtx.on_wave_ev(ev)
|
||||
case rocprof.ROCPROFILER_THREAD_TRACE_DECODER_RECORD_REALTIME:
|
||||
if DEBUG >= 5:
|
||||
pairs = [(ev.shader_clock, ev.realtime_clock) for ev in (rocprof.rocprofiler_thread_trace_decoder_realtime_t * n).from_address(events_ptr)]
|
||||
print(f"REALTIME {pairs}")
|
||||
case _:
|
||||
if DEBUG >= 5: print(rocprof.enum_rocprofiler_thread_trace_decoder_record_type_t.get(record_type), events_ptr, n)
|
||||
return rocprof.ROCPROFILER_THREAD_TRACE_DECODER_STATUS_SUCCESS
|
||||
|
||||
@@ -199,7 +199,10 @@ class AMDComputeQueue(HWQueue):
|
||||
for xcc in range(self.dev.xccs):
|
||||
with self.pred_exec(xcc_mask=1 << xcc):
|
||||
for i in range(8 if prg.dev.target >= (11,0,0) else 4):
|
||||
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), min(0xffffffff, (1 << (se_cap + (1 if i == 0 else 0))) - 1))
|
||||
if SQTT_LIMIT_SE > 1: # only run unmasked shader engines
|
||||
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), 1 if SQTT_ITRACE_SE_MASK.value & (1 << i) else 0)
|
||||
else:
|
||||
self.wreg(getattr(self.gc, f'regCOMPUTE_STATIC_THREAD_MGMT_SE{i}'), min(0xffffffff, (1 << (se_cap + (1 if i == 0 else 0))) - 1))
|
||||
|
||||
def sqtt_userdata(self, data, *extra_dwords):
|
||||
data_ints = [x[0] for x in struct.iter_unpack('<I', bytes(data))] + list(extra_dwords)
|
||||
|
||||
@@ -222,6 +222,9 @@ def load_sqtt(profile:list[ProfileEvent]) -> None:
|
||||
except Exception: return err("DECODER IMPORT ISSUE")
|
||||
try: rctx = decode(profile)
|
||||
except Exception: return err("DECODER ERROR")
|
||||
if getenv("SQTT_PARSE"):
|
||||
from extra.sqtt.attempt_sqtt_parse import parse_sqtt_print_packets
|
||||
for e in sqtt_events: parse_sqtt_print_packets(e.blob)
|
||||
if not rctx.inst_execs: return err("EMPTY SQTT OUTPUT", f"{len(sqtt_events)} SQTT events recorded, none got decoded")
|
||||
steps:list[dict] = []
|
||||
for name,waves in rctx.inst_execs.items():
|
||||
|
||||
Reference in New Issue
Block a user