beautiful SQTT_PARSE=1 with color (#13428)

* beautiful SQTT_PARSE=1 with color

* linter

* linter 2

* a few more labels

* filter and or

* wave alloc

* a few more
This commit is contained in:
George Hotz
2025-11-23 01:05:14 -08:00
committed by GitHub
parent 474a631877
commit 9d7a17ee39
3 changed files with 135 additions and 108 deletions

View File

@@ -88,74 +88,34 @@ def run_asm(src, num_workgroups=1, num_waves=1):
fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True)
if __name__ == "__main__":
"""
with save_sqtt() as sqtt:
run_asm([
#"s_barrier",
#"s_nop 0",
#"s_nop 0",
#"s_nop 15",
#"s_nop 0",
#"s_nop 0",
"s_nop 0",
"s_nop 0",
"s_nop 100",
"s_nop 100",
"s_load_b64 s[0:1], s[0:1], null",
"s_waitcnt lgkmcnt(0)",
"s_nop 0",
"s_nop 0",
"s_nop 100",
"s_nop 100",
"s_nop 100",
"s_nop 0",
"s_nop 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v1, 0",
"s_nop 0",
"s_nop 0",
"s_add_i32 s2, s2, 10",
"s_add_i32 s2, s2, 10",
"s_nop 100",
"s_nop 100",
"s_nop 100",
"s_nop 0",
"s_nop 0",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v3, v0, s[0:1]",
"global_load_b32 v4, v0, s[0:1]",
"global_load_b32 v5, v0, s[0:1]",
"global_load_b32 v5, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v1, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"global_load_b32 v2, v0, s[0:1]",
"s_nop 0",
"s_nop 0",
"s_nop 0",
"s_waitcnt vmcnt(0)",
"v_mov_b32_e32 v0, 0",
"v_mov_b32_e32 v0, 0",
"s_nop 100",
"s_nop 100",
"v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51",
"v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51",
"s_nop 100",
"s_nop 0",
"s_nop 0",
#"v_add_f32_e32 v1 v0 v0",
#"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)",
"s_nop 100",
"global_load_b128 v[2:5], v0, s[0:1]",
"global_load_b128 v[2:5], v0, s[0:1]",
"s_nop 100",
"s_nop 100",
"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)",
"s_endpgm",
], num_workgroups=1, num_waves=1)
exit(0)
"""
with save_sqtt() as sqtt:
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()

View File

@@ -22,6 +22,24 @@ from extra.sqtt.roc import decode, ProfileSQTTEvent
# NOTE: INST runs before EXEC
# Color name per packet opcode, handed to colored() when a packet line is printed.
# Opcodes that are missing here are left to the caller's .get() fallback.
OPCODE_COLORS = {
  # dispatches (VALUINST, INST) are BLACK
  **dict.fromkeys((0x1, 0x18), "BLACK"),
  # execs (VMEMEXEC, ALUEXEC) are yellow
  **dict.fromkeys((0x2, 0x3), "yellow"),
  # immediates are YELLOW
  **dict.fromkeys((0x4, 0x5), "YELLOW"),
  # wave end / wave start are blue
  **dict.fromkeys((0x8, 0x9), "blue"),
  # wave ready / wave alloc are cyan
  **dict.fromkeys((0x6, 0xb), "cyan"),
}
OPCODE_NAMES = {
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
0x01: "VALUINST",
@@ -31,16 +49,21 @@ OPCODE_NAMES = {
0x03: "ALUEXEC",
# gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT
0x04: "IMMEDIATE",
0x05: "IMMEDIATE_MULTIWAVE",
0x05: "IMMEDIATE_MASK",
# gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT
0x06: "WAVERDY",
# gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT
0x08: "WAVEEND",
0x09: "WAVESTART",
# gated by SQ_TT_TOKEN_EXCLUDE_WAVEALLOC_SHIFT
0x0B: "WAVEALLOC", # FFF00
# gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT
0x0D: "PERF",
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
0x12: "EVENT",
0x13: "EVENT_BIG", # FFFFF800
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. something is broken with the timing on this
0x14: "REG",
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
@@ -49,25 +72,91 @@ OPCODE_NAMES = {
0x19: "UTILCTR",
# this is the first (8 byte) packet in the bitstream
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B (reversed)
0x17: "LAYOUT_HEADER", # layout/mode/group + selectors A/B (reversed)
# pure time (no extra bits)
0x0F: "TS_DELTA_SHORT_PLUS4",
0x0F: "TS_DELTA_SHORT",
0x10: "NOP",
0x11: "TS_WAVE_STATE_SAMPLE", # almost pure time, has a small flag
0x11: "TS_WAVE_STATE", # almost pure time, has a small flag
# not a good name, but seen and understood mostly
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
0x15: "SNAPSHOT", # small delta + 50-ish bits of snapshot
0x16: "TS_DELTA_OR_MARK", # 36-bit long delta or 36-bit marker
# packets we haven't seen / rarely see 0x0b
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
0x07: "TS_DELTA_S8_W3_7", # shift=8, width=3 (small delta)
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
}
# SALU = 0x0 / s_mov_b32
# SMEM = 0x1 / s_load_b*
# JUMP = 0x3 / s_cbranch_scc0
# NEXT = 0x4 / s_cbranch_execz
# MESSAGE = 0x9 / s_sendmsg
# VALU = 0xb / v_(exp,log)_f32_e32
# VALU = 0xd / v_lshlrev_b64
# VALU = 0xe / v_mad_u64_u32
# VMEM = 0x21 / global_load_b32
# VMEM = 0x22 / global_load_b32
# VMEM = 0x24 / global_store_b32
# VMEM = 0x25 / global_store_b64
# VMEM = 0x27 / global_store
# VMEM = 0x28 / global_store_b64
# LDS = 0x29 / ds_load_b128
# LDS = 0x2b / ds_store_b32
# LDS = 0x2e / ds_store_b128
# ???? = 0x5a / hidden global_load instruction
# ???? = 0x5b / hidden global_load instruction
# ???? = 0x5c / hidden global_store instruction
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
# Human-readable label for the `op` field of 0x18 packets (op = pkt >> 13).
# Several raw op values share one label; values not listed here are printed
# with an empty label by the OPNAME.get(op, '') lookup in the decoder.
OPNAME = {
  0x0: "SALU",
  0x1: "SMEM",
  0x3: "JUMP",
  0x4: "NEXT",
  0x9: "MESSAGE",
  **dict.fromkeys((0xb, 0xd, 0xe), "VALU"),
  0x10: "__END",
  **dict.fromkeys((0x21, 0x22), "VMEM_LOAD"),
  **dict.fromkeys((0x24, 0x25, 0x26, 0x27, 0x28), "VMEM_STORE"),
  0x29: "LDS_LOAD",
  **dict.fromkeys((0x2b, 0x2e), "LDS_STORE"),
  **dict.fromkeys((0x50, 0x51), "__SIMD_LDS_LOAD"),
  0x54: "__SIMD_LDS_STORE",
  **dict.fromkeys((0x5a, 0x5b), "__SIMD_VMEM_LOAD"),
  **dict.fromkeys((0x5c, 0x5d, 0x5e, 0x5f), "__SIMD_VMEM_STORE"),
  0x72: "SALU_OR",
  0x73: "VALU_CMPX",
}
# Labels for the src field (pkt >> 6) of ALUEXEC (0x03) packets, numbered from 1.
# A src value with no entry is printed with an empty label.
ALUSRC = dict(enumerate(("SALU", "VALU", "VALU_ALT"), start=1))
# Labels for the src field (pkt >> 6) of VMEMEXEC (0x02) packets, numbered from 0.
MEMSRC = dict(enumerate(("LDS", "__LDS", "VMEM", "__VMEM")))
# these tables are from rocprof trace decoder
# rocprof_trace_decoder_parse_data-0x11c6a0
# parse_sqtt_180 = b *rocprof_trace_decoder_parse_data-0x11c6a0+0x110040
@@ -193,14 +282,15 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
flag = (pkt >> 6) & 1
wave = pkt >> 7
fields.append(f"wave={wave:x}")
assert flag == 0, "non 0 flag in 0x1"
#fields.append(f"flag={flag:X}")
if flag: fields.append("flag")
case 0x02: # VMEMEXEC
# 2 bit field (pipe is a guess)
fields.append(f"pipe={pkt>>6:X}")
src = pkt>>6
fields.append(f"src={src} [{MEMSRC.get(src, '')}]")
case 0x03: # ALUEXEC
# 2 bit field (pipe is a guess)
fields.append(f"pipe={pkt>>6:X}")
# 2 bit field
src = pkt>>6
fields.append(f"src={src} [{ALUSRC.get(src, '')}]")
case 0x04: # IMMEDIATE_4
# 5 bit field (actually 4)
wave = pkt >> 7
@@ -208,12 +298,12 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
case 0x05: # IMMEDIATE_5
# 16 bit field
# 1 bit per wave
fields.append(f"mask={pkt>>8:16b}")
fields.append(f"mask={pkt>>8:016b}")
case 0x6:
# wave ready FFFF00
# 16 bit field
# 1 bit per wave
fields.append(f"mask={pkt>>8:16b}")
fields.append(f"mask={pkt>>8:016b}")
case 0x0d:
# 20 bit field
fields.append(f"arg = {pkt>>8:X}")
@@ -274,32 +364,14 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
# hi8 = (w >> 0xc) & 0xff (layout 4 path)
# hi7 = (w >> 0xd) & 0x7f (other layouts)
# idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index
flag = (pkt >> 3) & 1
flag1 = (pkt >> 3) & 1
flag2 = (pkt >> 7) & 1
wave = (pkt >> 8) & 0x1F
hi8 = (pkt >> 13)
op = (pkt >> 13)
fields.append(f"wave={wave:x}")
assert flag == 0 and flag2 == 0, "non 0 flags in 0x18"
#fields.append(f"flag={flag:x}")
#fields.append(f"flag2={flag2:x}")
# hi8 values:
# SALU = 0x0 / s_mov_b32
# SMEM = 0x1 / s_load_b*
# NEXT = 0x4 / s_cbranch_execz
# MESSAGE = 0x9 / s_sendmsg
# VALU = 0xb / v_(exp,log)_f32_e32
# VALU = 0xd / v_lshlrev_b64
# VMEM = 0x21 / global_load_b32
# VMEM = 0x22 / global_load_b32
# VMEM = 0x24 / global_store_b32
# VMEM = 0x25 / global_store_b64
# LDS = 0x29 / ds_load_b128
# LDS = 0x2b / ds_store_b32
# ???? = 0x5a / hidden global_load instruction
# ???? = 0x5b / hidden global_load instruction
# ???? = 0x5c / hidden global_store instruction
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
fields.append(f"hi8=0x{hi8:x}")
fields.append(f"op=0x{op:02x} [{OPNAME.get(op, '')}]")
if flag1: fields.append("flag1")
if flag2: fields.append("flag2")
case 0x14:
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20)
@@ -364,21 +436,21 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
if layout == 4:
fields.append(f"layout4_flag={flag4}")
case _:
fields.append(f"& {reg_mask(opcode):X}")
fields.append(f"{pkt:X} & {reg_mask(opcode):X}")
return ",".join(fields)
FILTER_LEVEL = getenv("FILTER", 1)
DEFAULT_FILTER = tuple()
DEFAULT_FILTER: tuple[int, ...] = tuple()
# NOP + pure time + "sample"
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11)
# reg + event + sample + marker
# TODO: events are probably good
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16)
# instruction runs
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03)
# instructions dispatch (inst, valuinst, immed)
if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x01, 0x4, 0x5, 0x18)
# instruction runs + valuinst
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x01, 0x02, 0x03)
# instructions dispatch (inst, immed)
if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x4, 0x5, 0x18)
# waves
if FILTER_LEVEL >= 4: DEFAULT_FILTER += (0x6, 0x8, 0x9)
@@ -450,12 +522,7 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
token_index += 1
if verbose and (filter is None or opcode not in filter):
print(
f"{time:8d}+{time-last_printed_time:8d} : "
f"op={opcode:2x} "
f"{OPCODE_NAMES[opcode]:24s} "
f"{reg&reg_mask(opcode):16X} "
f"{note}")
print(f"{time:8d} +{time-last_printed_time:8d} : "+colored(f"{OPCODE_NAMES[opcode]:18s} ", OPCODE_COLORS.get(opcode, "white"))+f"{note}")
last_printed_time = time
# Optional summary at the end

View File

@@ -35,9 +35,9 @@ class TestTiny(unittest.TestCase):
out = Tensor.cat(Tensor.ones(8).contiguous(), Tensor.zeros(8).contiguous())
self.assertListEqual(out.tolist(), [1]*8+[0]*8)
def test_sum(self):
out = Tensor.ones(256).contiguous().sum()
self.assertEqual(out.item(), 256)
def test_sum(self, N=getenv("SUM_N", 256)):
  # Summing N ones must give exactly N; N defaults to 256 and is read via getenv("SUM_N", 256).
  total = Tensor.ones(N).contiguous().sum()
  self.assertEqual(total.item(), N)
def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float):
a = Tensor.ones(N,N).contiguous()