mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
cleanup sqtt raw parser (#13309)
* cleanup sqtt raw parser * better names (don't merge yet) * clean up amd * a few more names * one more filter
This commit is contained in:
@@ -7,6 +7,7 @@ os.environ["AMD_LLVM"] = "0"
|
||||
|
||||
from dataclasses import replace
|
||||
import atexit, contextlib
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import system, getenv
|
||||
from tinygrad.runtime.ops_amd import AMDProgram
|
||||
from extra.sqtt.roc import decode, WaveExec, ProfileSQTTEvent
|
||||
@@ -29,16 +30,15 @@ def save_sqtt():
|
||||
yield sqtt
|
||||
events = dev.profile_events+[ProfileDeviceEvent("AMD", props=dev.device_props())]
|
||||
|
||||
rctx = decode(events)
|
||||
assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
sqtt.update(rctx.inst_execs)
|
||||
#rctx = decode(events)
|
||||
#assert len(rctx.inst_execs) > 0, "empty sqtt output"
|
||||
#sqtt.update(rctx.inst_execs)
|
||||
|
||||
for e in events:
|
||||
if isinstance(e, ProfileSQTTEvent):
|
||||
print(replace(e, blob=b''))
|
||||
if e.se == 0:
|
||||
parse_sqtt_print_packets(e.blob, filter=[0xf, 0x11, 0x12, 0x14] if getenv("FILTER", 1) else None)
|
||||
|
||||
parse_sqtt_print_packets(e.blob)
|
||||
|
||||
template = """.text
|
||||
.globl matmul
|
||||
@@ -51,6 +51,7 @@ matmul:
|
||||
.rodata
|
||||
.p2align 6
|
||||
.amdhsa_kernel matmul
|
||||
.amdhsa_user_sgpr_kernarg_segment_ptr 1
|
||||
.amdhsa_next_free_vgpr .amdgcn.next_free_vgpr
|
||||
.amdhsa_next_free_sgpr .amdgcn.next_free_sgpr
|
||||
.amdhsa_wavefront_size32 1
|
||||
@@ -64,14 +65,21 @@ amdhsa.version:
|
||||
amdhsa.kernels:
|
||||
- .name: matmul
|
||||
.symbol: matmul.kd
|
||||
.kernarg_segment_size: 0
|
||||
.group_segment_fixed_size: 0
|
||||
.private_segment_fixed_size: 0
|
||||
.kernarg_segment_align: 4
|
||||
.wavefront_size: 32
|
||||
.sgpr_count: 8
|
||||
.vgpr_count: 32
|
||||
.max_flat_workgroup_size: 1024
|
||||
.kernarg_segment_align: 8
|
||||
.kernarg_segment_size: 8
|
||||
.args:
|
||||
- .address_space: global
|
||||
.name: a
|
||||
.offset: 0
|
||||
.size: 8
|
||||
.type_name: 'float*'
|
||||
.value_kind: global_buffer
|
||||
...
|
||||
.end_amdgpu_metadata
|
||||
"""
|
||||
@@ -80,20 +88,42 @@ def run_asm(src):
|
||||
NUM_WORKGROUPS = 1
|
||||
WAVE_SIZE = 32
|
||||
NUM_WAVES = 1
|
||||
t = Tensor.empty(0x1000).realize()
|
||||
buf = t.uop.buffer.ensure_allocated()
|
||||
lib = dev.compiler.compile(template.replace("INSTRUCTION", '\n'.join(src)))
|
||||
dev.compiler.disassemble(lib)
|
||||
fxn = AMDProgram(dev, "matmul", lib)
|
||||
fxn(global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
|
||||
fxn(buf._buf, global_size=(NUM_WORKGROUPS,1,1), local_size=(WAVE_SIZE*NUM_WAVES,1,1), wait=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
with save_sqtt() as sqtt:
|
||||
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()
|
||||
Tensor.empty(1).elu().realize()
|
||||
exit(0)
|
||||
|
||||
with save_sqtt() as sqtt:
|
||||
# what's in v0?
|
||||
run_asm([
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v1, 0",
|
||||
"s_clause 0x1",
|
||||
"s_load_b64 s[0:1], s[0:1], null",
|
||||
"s_waitcnt lgkmcnt(0)",
|
||||
]+[
|
||||
"global_load_b32 v1, v0, s[0:1]",
|
||||
]*10+[
|
||||
"global_load_b32 v10, v1, s[0:1]",
|
||||
"s_waitcnt vmcnt(0)",
|
||||
|
||||
#"v_rcp_f32 v1, v0"
|
||||
"v_add_f32_e32 v1 v0 v0",
|
||||
"v_add_f32_e32 v3 v2 v2",
|
||||
"v_add_f32_e32 v5 v4 v4",
|
||||
"v_add_f32_e32 v7 v6 v6",
|
||||
#"v_add_f32_e32 v1 v0 v0",
|
||||
#"v_add_f32_e32 v5 v4 v4",
|
||||
#"v_add_f32_e32 v7 v6 v6",
|
||||
#"v_add_f32_e32 v1 v0 v0",
|
||||
#"v_add_f32_e32 v2 v1 v1",
|
||||
#"s_nop 1"
|
||||
]*1)
|
||||
]*5+[
|
||||
"v_add_f32_e32 v3 v2 v2",
|
||||
]*5+[
|
||||
"v_mul_f32_e32 v3 v2 v2",
|
||||
]*7)
|
||||
|
||||
@@ -1,26 +1,39 @@
|
||||
import pickle
|
||||
from tinygrad.helpers import getenv
|
||||
from extra.sqtt.roc import decode, ProfileSQTTEvent
|
||||
|
||||
# Instruction packets (one per ISA op)
|
||||
# NOTE: these are bad guesses and may be wrong! feel free to update if you know better
|
||||
# some names were taken from SQ_TT_TOKEN_MASK_TOKEN_EXCLUDE_SHIFT
|
||||
|
||||
OPCODE_NAMES = {
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x01–0x06: small “meta + maybe tiny delta” packets
|
||||
# ------------------------------------------------------------------------
|
||||
0x01: "META_ID12_TS_SMALL", # 12-bit ID + 3-bit delta field
|
||||
0x02: "META_FLAG8_TS_SMALL", # 8-bit flag/mode + small delta
|
||||
0x03: "META_SUBEVENT8_TS_SMALL", # 8-bit subevent/class + small delta
|
||||
0x04: "META_BASE_INDEX12_TS", # 12-bit base index + small delta
|
||||
0x05: "META_DESC24_TS_A", # 24-bit descriptor-ish + delta field
|
||||
0x06: "META_DESC24_TS_B", # second flavour, 24-bit, delta field
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_VMEMEXEC_SHIFT
|
||||
0x02: "VMEMEXEC",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_ALUEXEC_SHIFT
|
||||
0x03: "ALUEXEC",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
|
||||
0x01: "VALUINST",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT
|
||||
0x06: "WAVERDY",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT
|
||||
0x08: "WAVEEND",
|
||||
0x09: "WAVESTART",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT
|
||||
0x04: "IMMEDIATE_4",
|
||||
0x05: "IMMEDIATE_5",
|
||||
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there
|
||||
0x14: "REG",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
|
||||
0x12: "EVENT",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
|
||||
0x18: "INST",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_UTILCTR_SHIFT
|
||||
0x19: "UTILCTR",
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# 0x07–0x0F: pure timestamp-ish deltas
|
||||
# ------------------------------------------------------------------------
|
||||
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
|
||||
0x08: "EVT_MATCH_SMALL", # event-ish, see fields below
|
||||
0x09: "PERF_ROUTE_CONFIG", # routing/indirection config
|
||||
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
|
||||
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
|
||||
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
|
||||
@@ -34,15 +47,11 @@ OPCODE_NAMES = {
|
||||
0x10: "PSEUDO_NEED_MORE_BITS", # not a real packet; decoder refill hint
|
||||
|
||||
0x11: "TS_WAVE_STATE_SAMPLE", # wave stall/termination sample (byte at +10)
|
||||
0x12: "EVT_SECONDARY_METRIC24", # 24-bit secondary timing/perf metric
|
||||
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
|
||||
|
||||
0x14: "INST_EXEC_OR_CFG", # instruction exec record / config write / COR marker
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B
|
||||
0x18: "PERF_EVENT_SELECT", # packed selector → FUN_0010aba0
|
||||
0x19: "EVT_SUMMARY_48B", # 6-byte summary/aggregate metric
|
||||
}
|
||||
|
||||
# these tables are from rocprof trace decoder
|
||||
@@ -181,9 +190,8 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
||||
mode = "other"
|
||||
val36 = (pkt >> 12) & ((1 << 36) - 1)
|
||||
fields.append(f"mode={mode}")
|
||||
fields.append(f"val36=0x{val36:x}")
|
||||
if mode == "delta":
|
||||
fields.append(f"delta36={delta}")
|
||||
if mode != "delta":
|
||||
fields.append(f"val36=0x{val36:x}")
|
||||
return ", ".join(fields)
|
||||
|
||||
# For 0x07, 0x0A–0x0E, we know they drive time (via DELTA_MAP_DEFAULT),
|
||||
@@ -408,7 +416,15 @@ def decode_packet_fields(opcode: int, reg: int, delta: int) -> str:
|
||||
|
||||
return ", ".join(fields)
|
||||
|
||||
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None) -> None:
|
||||
# 0xb is time something
|
||||
# 0xd is time something
|
||||
# 0xf is small time advance
|
||||
# 0x11 is time advance
|
||||
# 0x16 is big time advance + markers
|
||||
# 0x14 is REG
|
||||
DEFAULT_FILTER = (0xb, 0xd, 0xf, 0x11, 0x16, 0x14) if getenv("FILTER", 1) else None
|
||||
|
||||
def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=DEFAULT_FILTER) -> None:
|
||||
"""
|
||||
Minimal debug: print ONE LINE per decoded token (packet).
|
||||
|
||||
@@ -466,23 +482,17 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None)
|
||||
flags |= 0x01
|
||||
|
||||
# Common 36-bit field at bits [12..47]
|
||||
val36 = (reg >> 12) & ((1 << 36) - 1)
|
||||
|
||||
if (reg & 0x200) == 0:
|
||||
# delta mode: add 36-bit delta to time
|
||||
delta = val36
|
||||
delta = (reg >> 12) & ((1 << 36) - 1)
|
||||
time += delta
|
||||
note = "0x16-delta"
|
||||
else:
|
||||
# marker / other modes: no time advance
|
||||
if (reg & 0x100) == 0 and val36 != 0:
|
||||
if (reg & 0x100) == 0:
|
||||
# real marker: bit9=1, bit8=0, non-zero payload
|
||||
delta = 0
|
||||
note = f"0x16-marker val=0x{val36:x}"
|
||||
else:
|
||||
# "other" 0x16 variants, ignored for timing
|
||||
delta = 0
|
||||
note = "0x16-other"
|
||||
else:
|
||||
# 6) Generic opcode (including 0x0F)
|
||||
shift, width = DELTA_MAP_DEFAULT[opcode]
|
||||
@@ -492,19 +502,13 @@ def parse_sqtt_print_packets(data: bytes, max_tokens: int = 100000, filter=None)
|
||||
# TODO: add more opcode parsers here that add notes to other opcodes
|
||||
if opcode == 0x0F:
|
||||
delta_with_fix = delta + 4
|
||||
note = f"0x0f (+4) raw_delta={delta}"
|
||||
time += delta_with_fix
|
||||
delta = delta_with_fix
|
||||
else:
|
||||
time += delta
|
||||
|
||||
# ONE-LINE PRINT PER PACKET
|
||||
#assert last_real_offset%8 == 0
|
||||
#assert (offset)%8 == 0, f"misalign offset {offset}"
|
||||
|
||||
# Append extra decoded fields into the note string
|
||||
extra = decode_packet_fields(opcode, reg, delta)
|
||||
if extra: note = (note + " ; " + extra) if note else extra
|
||||
note = decode_packet_fields(opcode, reg, delta)
|
||||
|
||||
if filter is None or opcode not in filter:
|
||||
my_reg = reg
|
||||
@@ -533,7 +537,7 @@ def parse(fn:str):
|
||||
|
||||
if __name__ == "__main__":
|
||||
#dat_sqtt = parse("extra/sqtt/examples/profile_empty_run_0.pkl")
|
||||
dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl")
|
||||
#dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl")
|
||||
#dat_sqtt = parse("extra/sqtt/examples/profile_plus_run_0.pkl")
|
||||
dat_sqtt = parse("extra/sqtt/examples/profile_gemm_run_0.pkl")
|
||||
blob_0 = dat_sqtt[0].blob
|
||||
parse_sqtt_print_packets(blob_0[8:])
|
||||
|
||||
Reference in New Issue
Block a user