cleanup sqtt raw parser (#13309)

* cleanup sqtt raw parser

* better names (don't merge yet)

* clean up amd

* a few more names

* one more filter
This commit is contained in:
George Hotz
2025-11-16 13:11:51 -08:00
committed by GitHub
parent cabd4add48
commit 55be95da15
2 changed files with 83 additions and 49 deletions

View File

@@ -7,6 +7,7 @@ os.environ["AMD_LLVM"] = "0"
from dataclasses import replace
import atexit, contextlib
from tinygrad import Tensor
from tinygrad.helpers import system, getenv
from tinygrad.runtime.ops_amd import AMDProgram
from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent
@@ -29,16 +30,15 @@ def save_sqtt():
yield sqtt
events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())]
rctx = decode(events)
assert len(rctx.inst_execs) > 0, "empty sqtt output"
sqtt.update(rctx.inst_execs)
#rctx = decode(events)
#assert len(rctx.inst_execs) > 0, "empty sqtt output"
#sqtt.update(rctx.inst_execs)
for e in events:
if isinstance(e, ProfileSQTTEvent):
print(replace(e, blob=b''))
if e.se == 0:
parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None)
parse_sqtt_print_packets(e.blob)
template = """.text
.globl matmul
@@ -51,6 +51,7 @@ matmul:
.rodata
.p2align 6
.amdhsa_kernel matmul
.amdhsa_user_sgpr_kernarg_segment_ptr 1
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
.amdhsa_wavefront_size32 1
@@ -64,14 +65,21 @@ amdhsa.version:
amdhsa.kernels:
- .name: matmul
.symbol: matmul.kd
.kernarg_segment_size: 0
.group_segment_fixed_size: 0
.private_segment_fixed_size: 0
.kernarg_segment_align: 4
.wavefront_size: 32
.sgpr_count: 8
.vgpr_count: 32
.max_flat_workgroup_size: 1024
.kernarg_segment_align: 8
.kernarg_segment_size: 8
.args:
- .address_space: global
.name: a
.offset: 0
.size: 8
.type_name: 'float*'
.value_kind: global_buffer
...
.end_amdgpu_metadata
"""
@@ -80,20 +88,42 @@ def run_asm(src):
NUM_WORKGROUPS = 1
WAVE_SIZE = 32
NUM_WAVES = 1
t = Tensor.empty(0x1000).realize()
buf = t.uop.buffer.ensure_allocated()
lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src)))
dev.compiler.disassemble(lib)
fxn = AMDProgram(dev, "matmul", lib)
fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
fxn(buf._buf, global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
if __name__ == "__main__":
with save_sqtt() as sqtt:
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()
Tensor.empty(1).elu().realize()
exit(0)
with save_sqtt() as sqtt:
# what's in v0?
run_asm([
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v1, 0",
"s_clause 0x1",
"s_load_b64 s[0:1], s[0:1], null",
"s_waitcnt lgkmcnt(0)",
]+[
"global_load_b32 v1, v0, s[0:1]",
]*10+[
"global_load_b32 v10, v1, s[0:1]",
"s_waitcnt vmcnt(0)",
#"v_rcp_f32 v1, v0"
"v_add_f32_e32 v1 v0 v0",
"v_add_f32_e32 v3 v2 v2",
"v_add_f32_e32 v5 v4 v4",
"v_add_f32_e32 v7 v6 v6",
#"v_add_f32_e32 v1 v0 v0",
#"v_add_f32_e32 v5 v4 v4",
#"v_add_f32_e32 v7 v6 v6",
#"v_add_f32_e32 v1 v0 v0",
#"v_add_f32_e32 v2 v1 v1",
#"s_nop 1"
]*1)
]*5+[
"v_add_f32_e32 v3 v2 v2",
]*5+[
"v_mul_f32_e32 v3 v2 v2",
]*7)

View File

@@ -1,26 +1,39 @@
import pickle
from tinygrad.helpers import getenv
from extra.sqtt.roc import decode, ProfileSQTTEvent
# Instruction packets (one per ISA op)
# NOTE: these are bad guesses and may be wrong! feel free to update if you know better
# some names were taken from SQ_TT_TOKEN_MASK_TOKEN_EXCLUDE_SHIFT
OPCODE_NAMES = {
# ------------------------------------------------------------------------
# 0x010x06: small “meta + maybe tiny delta” packets
# ------------------------------------------------------------------------
0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field
0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta
0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta
0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta
0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field
0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field
# gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT
0x02: "VMEMEXEC",
# gated by SQ_TT_TOKEN_EXCLUDE_ALUEXEC_SHIFT
0x03: "ALUEXEC",
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
0x01: "VALUINST",
# gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT
0x06: "WAVERDY",
# gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT
0x08: "WAVEEND",
0x09: "WAVESTART",
# gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT
0x04: "IMMEDIATE_4",
0x05: "IMMEDIATE_5",
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there
0x14: "REG",
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
0x12: "EVENT",
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
0x18: "INST",
# gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT
0x19: "UTILCTR",
# ------------------------------------------------------------------------
# 0x070x0F: pure timestamp-ish deltas
# ------------------------------------------------------------------------
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
0x08: "EVT_MATCH_SMALL", # event-ish, see fields below
0x09: "PERF_ROUTE_CONFIG", # routing/indirection config
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
@@ -34,15 +47,11 @@ OPCODE_NAMES = {
0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0
0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric
}
# these tables are from rocprof trace decoder
@@ -181,9 +190,8 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
mode = "other"
val36 = (pkt >> 12) & ((1 << 36) - 1)
fields.append(f"mode={mode}")
fields.append(f"val36=0x{val36:x}")
if mode == "delta":
fields.append(f"delta36={delta}")
if mode != "delta":
fields.append(f"val36=0x{val36:x}")
return ", ".join(fields)
# For 0x07, 0x0A0x0E, we know they drive time (via DELTA_MAP_DEFAULT),
@@ -408,7 +416,15 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
return ", ".join(fields)
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None:
# 0xb is time something
# 0xd is time something
# 0xf is small time advance
# 0x11 is time advance
# 0x16 is big time advance + markers
# 0x14 is REG
DEFAULT_FILTER = (0xb, 0xd, 0xf, 0x11, 0x16, 0x14) if getenv("FILTER", 1) else None
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAULT_FILTER) -> None:
"""
Minimal debug: print ONE LINE per decoded token (packet).
@@ -466,23 +482,17 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None)
flags |= 0x01
# Common 36-bit field at bits [12..47]
val36 = (reg >> 12) & ((1 << 36) - 1)
if (reg & 0x200) == 0:
# delta mode: add 36-bit delta to time
delta = val36
delta = (reg >> 12) & ((1 << 36) - 1)
time += delta
note = "0x16-delta"
else:
# marker / other modes: no time advance
if (reg & 0x100) == 0 and val36 != 0:
if (reg & 0x100) == 0:
# real marker: bit9=1, bit8=0, non-zero payload
delta = 0
note = f"0x16-marker val=0x{val36:x}"
else:
# "other" 0x16 variants, ignored for timing
delta = 0
note = "0x16-other"
else:
# 6) Generic opcode (including 0x0F)
shift, width = DELTA_MAP_DEFAULT[opcode]
@@ -492,19 +502,13 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None)
# TODO: add more opcode parsers here that add notes to other opcodes
if opcode == 0x0F:
delta_with_fix = delta + 4
note = f"0x0f (+4) raw_delta={delta}"
time += delta_with_fix
delta = delta_with_fix
else:
time += delta
# ONE-LINE PRINT PER PACKET
#assert last_real_offset%8 == 0
#assert (offset)%8 == 0, f"misalign offset {offset}"
# Append extra decoded fields into the note string
extra = decode_packet_fields(opcode, reg, delta)
if extra: note = (note + " ; " + extra) if note else extra
note = decode_packet_fields(opcode, reg, delta)
if filter is None or opcode not in filter:
my_reg = reg
@@ -533,7 +537,7 @@ def parse(fn:str):
if __name__ == "__main__":
#dat_sqtt = parse("extra/sqtt/examples/profile_empty_run_0.pkl")
dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl")
#dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl")
#dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl")
dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl")
blob_0 = dat_sqtt[0].blob
parse_sqtt_print_packets(blob_0[8:])