mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
beautiful SQTT_PARSE=1 with color (#13428)
* beautiful SQTT_PARSE=1 with color
* linter
* linter 2
* a few more labels
* filter and or
* wave alloc
* a few more
This commit is contained in:
@@ -88,74 +88,34 @@ def run_asm(src, num_workgroups=1, num_waves=1):
|
||||
fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True)
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""
|
||||
with save_sqtt() as sqtt:
|
||||
run_asm([
|
||||
#"s_barrier",
|
||||
#"s_nop 0",
|
||||
#"s_nop 0",
|
||||
#"s_nop 15",
|
||||
#"s_nop 0",
|
||||
#"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_load_b64 s[0:1], s[0:1], null",
|
||||
"s_waitcnt lgkmcnt(0)",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v1, 0",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_add_i32 s2, s2, 10",
|
||||
"s_add_i32 s2, s2, 10",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v3, v0, s[0:1]",
|
||||
"global_load_b32 v4, v0, s[0:1]",
|
||||
"global_load_b32 v5, v0, s[0:1]",
|
||||
"global_load_b32 v5, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v1, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"global_load_b32 v2, v0, s[0:1]",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
"s_waitcnt vmcnt(0)",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"v_mov_b32_e32 v0, 0",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51",
|
||||
"v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51",
|
||||
"s_nop 100",
|
||||
"s_nop 0",
|
||||
"s_nop 0",
|
||||
#"v_add_f32_e32 v1 v0 v0",
|
||||
#"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)",
|
||||
"s_nop 100",
|
||||
"global_load_b128 v[2:5], v0, s[0:1]",
|
||||
"global_load_b128 v[2:5], v0, s[0:1]",
|
||||
"s_nop 100",
|
||||
"s_nop 100",
|
||||
"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)",
|
||||
"s_endpgm",
|
||||
], num_workgroups=1, num_waves=1)
|
||||
exit(0)
|
||||
"""
|
||||
|
||||
with save_sqtt() as sqtt:
|
||||
#(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize()
|
||||
|
||||
@@ -22,6 +22,24 @@ from extra.sqtt.roc import decode, ProfileSQTTEvent
|
||||
|
||||
# NOTE: INST runs before EXEC
|
||||
|
||||
OPCODE_COLORS = {
|
||||
# dispatches are BLACK
|
||||
0x1: "BLACK",
|
||||
0x18: "BLACK",
|
||||
|
||||
# execs are yellow
|
||||
0x2: "yellow",
|
||||
0x3: "yellow",
|
||||
0x4: "YELLOW",
|
||||
0x5: "YELLOW",
|
||||
|
||||
# waves are blue
|
||||
0x8: "blue",
|
||||
0x9: "blue",
|
||||
0x6: "cyan",
|
||||
0xb: "cyan",
|
||||
}
|
||||
|
||||
OPCODE_NAMES = {
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show)
|
||||
0x01: "VALUINST",
|
||||
@@ -31,16 +49,21 @@ OPCODE_NAMES = {
|
||||
0x03: "ALUEXEC",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT
|
||||
0x04: "IMMEDIATE",
|
||||
0x05: "IMMEDIATE_MULTIWAVE",
|
||||
0x05: "IMMEDIATE_MASK",
|
||||
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT
|
||||
0x06: "WAVERDY",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT
|
||||
0x08: "WAVEEND",
|
||||
0x09: "WAVESTART",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_WAVEALLOC_SHIFT
|
||||
0x0B: "WAVEALLOC", # FFF00
|
||||
|
||||
# gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT
|
||||
0x0D: "PERF",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT
|
||||
0x12: "EVENT",
|
||||
0x13: "EVENT_BIG", # FFFFF800
|
||||
# some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. something is broken with the timing on this
|
||||
0x14: "REG",
|
||||
# gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT
|
||||
@@ -49,25 +72,91 @@ OPCODE_NAMES = {
|
||||
0x19: "UTILCTR",
|
||||
|
||||
# this is the first (8 byte) packet in the bitstream
|
||||
0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B (reversed)
|
||||
0x17: "LAYOUT_HEADER", # layout/mode/group + selectors A/B (reversed)
|
||||
|
||||
# pure time (no extra bits)
|
||||
0x0F: "TS_DELTA_SHORT_PLUS4",
|
||||
0x0F: "TS_DELTA_SHORT",
|
||||
0x10: "NOP",
|
||||
0x11: "TS_WAVE_STATE_SAMPLE", # almost pure time, has a small flag
|
||||
0x11: "TS_WAVE_STATE", # almost pure time, has a small flag
|
||||
|
||||
# not a good name, but seen and understood mostly
|
||||
0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
0x15: "SNAPSHOT", # small delta + 50-ish bits of snapshot
|
||||
0x16: "TS_DELTA_OR_MARK", # 36-bit long delta or 36-bit marker
|
||||
|
||||
# packets we haven't seen / rarely see 0x0b
|
||||
0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta)
|
||||
0x07: "TS_DELTA_S8_W3_7", # shift=8, width=3 (small delta)
|
||||
0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2
|
||||
0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3
|
||||
0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer)
|
||||
0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19
|
||||
}
|
||||
|
||||
# SALU = 0x0 / s_mov_b32
|
||||
# SMEM = 0x1 / s_load_b*
|
||||
# JUMP = 0x3 / s_cbranch_scc0
|
||||
# NEXT = 0x4 / s_cbranch_execz
|
||||
# MESSAGE = 0x9 / s_sendmsg
|
||||
# VALU = 0xb / v_(exp,log)_f32_e32
|
||||
# VALU = 0xd / v_lshlrev_b64
|
||||
# VALU = 0xe / v_mad_u64_u32
|
||||
# VMEM = 0x21 / global_load_b32
|
||||
# VMEM = 0x22 / global_load_b32
|
||||
# VMEM = 0x24 / global_store_b32
|
||||
# VMEM = 0x25 / global_store_b64
|
||||
# VMEM = 0x27 / global_store
|
||||
# VMEM = 0x28 / global_store_b64
|
||||
# LDS = 0x29 / ds_load_b128
|
||||
# LDS = 0x2b / ds_store_b32
|
||||
# LDS = 0x2e / ds_store_b128
|
||||
# ???? = 0x5a / hidden global_load instruction
|
||||
# ???? = 0x5b / hidden global_load instruction
|
||||
# ???? = 0x5c / hidden global_store instruction
|
||||
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
|
||||
OPNAME = {
|
||||
0x0: "SALU",
|
||||
0x1: "SMEM",
|
||||
0x3: "JUMP",
|
||||
0x4: "NEXT",
|
||||
0x9: "MESSAGE",
|
||||
0xb: "VALU",
|
||||
0xd: "VALU",
|
||||
0xe: "VALU",
|
||||
0x10: "__END",
|
||||
0x21: "VMEM_LOAD",
|
||||
0x22: "VMEM_LOAD",
|
||||
0x24: "VMEM_STORE",
|
||||
0x25: "VMEM_STORE",
|
||||
0x26: "VMEM_STORE",
|
||||
0x27: "VMEM_STORE",
|
||||
0x28: "VMEM_STORE",
|
||||
0x29: "LDS_LOAD",
|
||||
0x2b: "LDS_STORE",
|
||||
0x2e: "LDS_STORE",
|
||||
0x50: "__SIMD_LDS_LOAD",
|
||||
0x51: "__SIMD_LDS_LOAD",
|
||||
0x54: "__SIMD_LDS_STORE",
|
||||
0x5a: "__SIMD_VMEM_LOAD",
|
||||
0x5b: "__SIMD_VMEM_LOAD",
|
||||
0x5c: "__SIMD_VMEM_STORE",
|
||||
0x5d: "__SIMD_VMEM_STORE",
|
||||
0x5e: "__SIMD_VMEM_STORE",
|
||||
0x5f: "__SIMD_VMEM_STORE",
|
||||
0x72: "SALU_OR",
|
||||
0x73: "VALU_CMPX",
|
||||
}
|
||||
|
||||
ALUSRC = {
|
||||
1: "SALU",
|
||||
2: "VALU",
|
||||
3: "VALU_ALT",
|
||||
}
|
||||
|
||||
MEMSRC = {
|
||||
0: "LDS",
|
||||
1: "__LDS",
|
||||
2: "VMEM",
|
||||
3: "__VMEM",
|
||||
}
|
||||
|
||||
|
||||
# these tables are from rocprof trace decoder
|
||||
# rocprof_trace_decoder_parse_data-0x11c6a0
|
||||
# parse_sqtt_180 = b *rocprof_trace_decoder_parse_data-0x11c6a0+0x110040
|
||||
@@ -193,14 +282,15 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
flag = (pkt >> 6) & 1
|
||||
wave = pkt >> 7
|
||||
fields.append(f"wave={wave:x}")
|
||||
assert flag == 0, "non 0 flag in 0x1"
|
||||
#fields.append(f"flag={flag:X}")
|
||||
if flag: fields.append("flag")
|
||||
case 0x02: # VMEMEXEC
|
||||
# 2 bit field (pipe is a guess)
|
||||
fields.append(f"pipe={pkt>>6:X}")
|
||||
src = pkt>>6
|
||||
fields.append(f"src={src} [{MEMSRC.get(src, '')}]")
|
||||
case 0x03: # ALUEXEC
|
||||
# 2 bit field (pipe is a guess)
|
||||
fields.append(f"pipe={pkt>>6:X}")
|
||||
# 2 bit field
|
||||
src = pkt>>6
|
||||
fields.append(f"src={src} [{ALUSRC.get(src, '')}]")
|
||||
case 0x04: # IMMEDIATE_4
|
||||
# 5 bit field (actually 4)
|
||||
wave = pkt >> 7
|
||||
@@ -208,12 +298,12 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
case 0x05: # IMMEDIATE_5
|
||||
# 16 bit field
|
||||
# 1 bit per wave
|
||||
fields.append(f"mask={pkt>>8:16b}")
|
||||
fields.append(f"mask={pkt>>8:016b}")
|
||||
case 0x6:
|
||||
# wave ready FFFF00
|
||||
# 16 bit field
|
||||
# 1 bit per wave
|
||||
fields.append(f"mask={pkt>>8:16b}")
|
||||
fields.append(f"mask={pkt>>8:016b}")
|
||||
case 0x0d:
|
||||
# 20 bit field
|
||||
fields.append(f"arg = {pkt>>8:X}")
|
||||
@@ -274,32 +364,14 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
# hi8 = (w >> 0xc) & 0xff (layout 4 path)
|
||||
# hi7 = (w >> 0xd) & 0x7f (other layouts)
|
||||
# idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index
|
||||
flag = (pkt >> 3) & 1
|
||||
flag1 = (pkt >> 3) & 1
|
||||
flag2 = (pkt >> 7) & 1
|
||||
wave = (pkt >> 8) & 0x1F
|
||||
hi8 = (pkt >> 13)
|
||||
op = (pkt >> 13)
|
||||
fields.append(f"wave={wave:x}")
|
||||
assert flag == 0 and flag2 == 0, "non 0 flags in 0x18"
|
||||
#fields.append(f"flag={flag:x}")
|
||||
#fields.append(f"flag2={flag2:x}")
|
||||
# hi8 values:
|
||||
# SALU = 0x0 / s_mov_b32
|
||||
# SMEM = 0x1 / s_load_b*
|
||||
# NEXT = 0x4 / s_cbranch_execz
|
||||
# MESSAGE = 0x9 / s_sendmsg
|
||||
# VALU = 0xb / v_(exp,log)_f32_e32
|
||||
# VALU = 0xd / v_lshlrev_b64
|
||||
# VMEM = 0x21 / global_load_b32
|
||||
# VMEM = 0x22 / global_load_b32
|
||||
# VMEM = 0x24 / global_store_b32
|
||||
# VMEM = 0x25 / global_store_b64
|
||||
# LDS = 0x29 / ds_load_b128
|
||||
# LDS = 0x2b / ds_store_b32
|
||||
# ???? = 0x5a / hidden global_load instruction
|
||||
# ???? = 0x5b / hidden global_load instruction
|
||||
# ???? = 0x5c / hidden global_store instruction
|
||||
# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST)
|
||||
fields.append(f"hi8=0x{hi8:x}")
|
||||
fields.append(f"op=0x{op:02x} [{OPNAME.get(op, '')}]")
|
||||
if flag1: fields.append("flag1")
|
||||
if flag2: fields.append("flag2")
|
||||
case 0x14:
|
||||
subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10)
|
||||
val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20)
|
||||
@@ -364,21 +436,21 @@ def decode_packet_fields(opcode: int, reg: int) -> str:
|
||||
if layout == 4:
|
||||
fields.append(f"layout4_flag={flag4}")
|
||||
case _:
|
||||
fields.append(f"& {reg_mask(opcode):X}")
|
||||
fields.append(f"{pkt:X} & {reg_mask(opcode):X}")
|
||||
return ",".join(fields)
|
||||
|
||||
FILTER_LEVEL = getenv("FILTER", 1)
|
||||
|
||||
DEFAULT_FILTER = tuple()
|
||||
DEFAULT_FILTER: tuple[int, ...] = tuple()
|
||||
# NOP + pure time + "sample"
|
||||
if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11)
|
||||
# reg + event + sample + marker
|
||||
# TODO: events are probably good
|
||||
if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16)
|
||||
# instruction runs
|
||||
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03)
|
||||
# instructions dispatch (inst, valuinst, immed)
|
||||
if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x01, 0x4, 0x5, 0x18)
|
||||
# instruction runs + valuinst
|
||||
if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x01, 0x02, 0x03)
|
||||
# instructions dispatch (inst, immed)
|
||||
if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x4, 0x5, 0x18)
|
||||
# waves
|
||||
if FILTER_LEVEL >= 4: DEFAULT_FILTER += (0x6, 0x8, 0x9)
|
||||
|
||||
@@ -450,12 +522,7 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) -
|
||||
token_index += 1
|
||||
|
||||
if verbose and (filter is None or opcode not in filter):
|
||||
print(
|
||||
f"{time:8d}+{time-last_printed_time:8d} : "
|
||||
f"op={opcode:2x} "
|
||||
f"{OPCODE_NAMES[opcode]:24s} "
|
||||
f"{reg&reg_mask(opcode):16X} "
|
||||
f"{note}")
|
||||
print(f"{time:8d} +{time-last_printed_time:8d} : "+colored(f"{OPCODE_NAMES[opcode]:18s} ", OPCODE_COLORS.get(opcode, "white"))+f"{note}")
|
||||
last_printed_time = time
|
||||
|
||||
# Optional summary at the end
|
||||
|
||||
@@ -35,9 +35,9 @@ class TestTiny(unittest.TestCase):
|
||||
out = Tensor.cat(Tensor.ones(8).contiguous(), Tensor.zeros(8).contiguous())
|
||||
self.assertListEqual(out.tolist(), [1]*8+[0]*8)
|
||||
|
||||
def test_sum(self):
|
||||
out = Tensor.ones(256).contiguous().sum()
|
||||
self.assertEqual(out.item(), 256)
|
||||
def test_sum(self, N=getenv("SUM_N", 256)):
|
||||
out = Tensor.ones(N).contiguous().sum()
|
||||
self.assertEqual(out.item(), N)
|
||||
|
||||
def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float):
|
||||
a = Tensor.ones(N,N).contiguous()
|
||||
|
||||
Reference in New Issue
Block a user