diff --git a/extra/sqtt/active_sqtt_parse.py b/extra/sqtt/active_sqtt_parse.py index fdbdbf4717..80f2984849 100644 --- a/extra/sqtt/active_sqtt_parse.py +++ b/extra/sqtt/active_sqtt_parse.py @@ -88,74 +88,34 @@ def run_asm(src, num_workgroups=1, num_waves=1): fxn(buf._buf, global_size=(num_workgroups,1,1), local_size=(WAVE_SIZE*num_waves,1,1), wait=True) if __name__ == "__main__": - """ with save_sqtt() as sqtt: run_asm([ - #"s_barrier", - #"s_nop 0", - #"s_nop 0", - #"s_nop 15", - #"s_nop 0", - #"s_nop 0", - "s_nop 0", - "s_nop 0", + "s_nop 100", + "s_nop 100", "s_load_b64 s[0:1], s[0:1], null", "s_waitcnt lgkmcnt(0)", - "s_nop 0", - "s_nop 0", "s_nop 100", "s_nop 100", - "s_nop 100", - "s_nop 0", - "s_nop 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v0, 0", - "v_mov_b32_e32 v1, 0", - "s_nop 0", - "s_nop 0", + "s_add_i32 s2, s2, 10", + "s_add_i32 s2, s2, 10", "s_nop 100", "s_nop 100", - "s_nop 100", - "s_nop 0", - "s_nop 0", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v3, v0, s[0:1]", - "global_load_b32 v4, v0, s[0:1]", - "global_load_b32 v5, v0, s[0:1]", - "global_load_b32 v5, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v1, s[0:1]", - "global_load_b32 v2, v1, s[0:1]", - "global_load_b32 v2, v1, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "global_load_b32 v2, v0, s[0:1]", - "s_nop 0", - "s_nop 0", - "s_nop 0", - "s_waitcnt vmcnt(0)", + "v_mov_b32_e32 v0, 0", + "v_mov_b32_e32 v0, 0", "s_nop 100", "s_nop 100", + "v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51", + "v_dual_fmac_f32 v2, v48, v24 :: v_dual_fmac_f32 v9, v37, v51", "s_nop 100", - "s_nop 0", - "s_nop 0", - 
#"v_add_f32_e32 v1 v0 v0", - #"s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", + "s_nop 100", + "global_load_b128 v[2:5], v0, s[0:1]", + "global_load_b128 v[2:5], v0, s[0:1]", + "s_nop 100", + "s_nop 100", + "s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)", "s_endpgm", ], num_workgroups=1, num_waves=1) exit(0) - """ with save_sqtt() as sqtt: #(Tensor.empty(16,16) @ Tensor.empty(16,16)).elu().realize() diff --git a/extra/sqtt/attempt_sqtt_parse.py b/extra/sqtt/attempt_sqtt_parse.py index af2b565713..9bf5d1a62f 100644 --- a/extra/sqtt/attempt_sqtt_parse.py +++ b/extra/sqtt/attempt_sqtt_parse.py @@ -22,6 +22,24 @@ from extra.sqtt.roc import decode, ProfileSQTTEvent # NOTE: INST runs before EXEC +OPCODE_COLORS = { + # dispatches are BLACK + 0x1: "BLACK", + 0x18: "BLACK", + + # execs are yellow + 0x2: "yellow", + 0x3: "yellow", + 0x4: "YELLOW", + 0x5: "YELLOW", + + # waves are blue + 0x8: "blue", + 0x9: "blue", + 0x6: "cyan", + 0xb: "cyan", +} + OPCODE_NAMES = { # gated by SQ_TT_TOKEN_EXCLUDE_VALUINST_SHIFT (but others must be enabled for it to show) 0x01: "VALUINST", @@ -31,16 +49,21 @@ OPCODE_NAMES = { 0x03: "ALUEXEC", # gated by SQ_TT_TOKEN_EXCLUDE_IMMEDIATE_SHIFT 0x04: "IMMEDIATE", - 0x05: "IMMEDIATE_MULTIWAVE", + 0x05: "IMMEDIATE_MASK", + # gated by SQ_TT_TOKEN_EXCLUDE_WAVERDY_SHIFT 0x06: "WAVERDY", # gated by SQ_TT_TOKEN_EXCLUDE_WAVESTARTEND_SHIFT 0x08: "WAVEEND", 0x09: "WAVESTART", + # gated by SQ_TT_TOKEN_EXCLUDE_WAVEALLOC_SHIFT + 0x0B: "WAVEALLOC", # FFF00 + # gated by NOT SQ_TT_TOKEN_EXCLUDE_PERF_SHIFT 0x0D: "PERF", # gated by SQ_TT_TOKEN_EXCLUDE_EVENT_SHIFT 0x12: "EVENT", + 0x13: "EVENT_BIG", # FFFFF800 # some gated by SQ_TT_TOKEN_EXCLUDE_REG_SHIFT, some always there. 
something is broken with the timing on this 0x14: "REG", # gated by SQ_TT_TOKEN_EXCLUDE_INST_SHIFT @@ -49,25 +72,91 @@ OPCODE_NAMES = { 0x19: "UTILCTR", # this is the first (8 byte) packet in the bitstream - 0x17: "LAYOUT_MODE_HEADER", # layout/mode/group + selectors A/B (reversed) + 0x17: "LAYOUT_HEADER", # layout/mode/group + selectors A/B (reversed) # pure time (no extra bits) - 0x0F: "TS_DELTA_SHORT_PLUS4", + 0x0F: "TS_DELTA_SHORT", 0x10: "NOP", - 0x11: "TS_WAVE_STATE_SAMPLE", # almost pure time, has a small flag + 0x11: "TS_WAVE_STATE", # almost pure time, has a small flag # not a good name, but seen and understood mostly - 0x15: "PERFCOUNTER_SNAPSHOT", # small delta + 50-ish bits of snapshot - 0x16: "TS_DELTA36_OR_MARK", # 36-bit long delta or 36-bit marker + 0x15: "SNAPSHOT", # small delta + 50-ish bits of snapshot + 0x16: "TS_DELTA_OR_MARK", # 36-bit long delta or 36-bit marker # packets we haven't seen / rarely see 0x0b - 0x07: "TS_DELTA_S8_W3", # shift=8, width=3 (small delta) + 0x07: "TS_DELTA_S8_W3_7", # shift=8, width=3 (small delta) 0x0A: "TS_DELTA_S5_W2_A", # shift=5, width=2 - 0x0B: "TS_DELTA_S5_W3_A", # shift=5, width=3 0x0C: "TS_DELTA_S5_W3_B", # shift=5, width=3 (different consumer) - 0x13: "EVT_SMALL_GENERIC", # same structural family as 0x08/0x12/0x19 } +# SALU = 0x0 / s_mov_b32 +# SMEM = 0x1 / s_load_b* +# JUMP = 0x3 / s_cbranch_scc0 +# NEXT = 0x4 / s_cbranch_execz +# MESSAGE = 0x9 / s_sendmsg +# VALU = 0xb / v_(exp,log)_f32_e32 +# VALU = 0xd / v_lshlrev_b64 +# VALU = 0xe / v_mad_u64_u32 +# VMEM = 0x21 / global_load_b32 +# VMEM = 0x22 / global_load_b32 +# VMEM = 0x24 / global_store_b32 +# VMEM = 0x25 / global_store_b64 +# VMEM = 0x27 / global_store +# VMEM = 0x28 / global_store_b64 +# LDS = 0x29 / ds_load_b128 +# LDS = 0x2b / ds_store_b32 +# LDS = 0x2e / ds_store_b128 +# ???? = 0x5a / hidden global_load instruction +# ???? = 0x5b / hidden global_load instruction +# ???? 
= 0x5c / hidden global_store instruction +# VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST) +OPNAME = { + 0x0: "SALU", + 0x1: "SMEM", + 0x3: "JUMP", + 0x4: "NEXT", + 0x9: "MESSAGE", + 0xb: "VALU", + 0xd: "VALU", + 0xe: "VALU", + 0x10: "__END", + 0x21: "VMEM_LOAD", + 0x22: "VMEM_LOAD", + 0x24: "VMEM_STORE", + 0x25: "VMEM_STORE", + 0x26: "VMEM_STORE", + 0x27: "VMEM_STORE", + 0x28: "VMEM_STORE", + 0x29: "LDS_LOAD", + 0x2b: "LDS_STORE", + 0x2e: "LDS_STORE", + 0x50: "__SIMD_LDS_LOAD", + 0x51: "__SIMD_LDS_LOAD", + 0x54: "__SIMD_LDS_STORE", + 0x5a: "__SIMD_VMEM_LOAD", + 0x5b: "__SIMD_VMEM_LOAD", + 0x5c: "__SIMD_VMEM_STORE", + 0x5d: "__SIMD_VMEM_STORE", + 0x5e: "__SIMD_VMEM_STORE", + 0x5f: "__SIMD_VMEM_STORE", + 0x72: "SALU_OR", + 0x73: "VALU_CMPX", +} + +ALUSRC = { + 1: "SALU", + 2: "VALU", + 3: "VALU_ALT", +} + +MEMSRC = { + 0: "LDS", + 1: "__LDS", + 2: "VMEM", + 3: "__VMEM", +} + + # these tables are from rocprof trace decoder # rocprof_trace_decoder_parse_data-0x11c6a0 # parse_sqtt_180 = b *rocprof_trace_decoder_parse_data-0x11c6a0+0x110040 @@ -193,14 +282,15 @@ def decode_packet_fields(opcode: int, reg: int) -> str: flag = (pkt >> 6) & 1 wave = pkt >> 7 fields.append(f"wave={wave:x}") - assert flag == 0, "non 0 flag in 0x1" - #fields.append(f"flag={flag:X}") + if flag: fields.append("flag") case 0x02: # VMEMEXEC # 2 bit field (pipe is a guess) - fields.append(f"pipe={pkt>>6:X}") + src = pkt>>6 + fields.append(f"src={src} [{MEMSRC.get(src, '')}]") case 0x03: # ALUEXEC - # 2 bit field (pipe is a guess) - fields.append(f"pipe={pkt>>6:X}") + # 2 bit field + src = pkt>>6 + fields.append(f"src={src} [{ALUSRC.get(src, '')}]") case 0x04: # IMMEDIATE_4 # 5 bit field (actually 4) wave = pkt >> 7 @@ -208,12 +298,12 @@ def decode_packet_fields(opcode: int, reg: int) -> str: case 0x05: # IMMEDIATE_5 # 16 bit field # 1 bit per wave - fields.append(f"mask={pkt>>8:16b}") + fields.append(f"mask={pkt>>8:016b}") case 0x6: # wave ready FFFF00 # 16 bit field # 1 bit per wave - 
fields.append(f"mask={pkt>>8:16b}") + fields.append(f"mask={pkt>>8:016b}") case 0x0d: # 20 bit field fields.append(f"arg = {pkt>>8:X}") @@ -274,32 +364,14 @@ def decode_packet_fields(opcode: int, reg: int) -> str: # hi8 = (w >> 0xc) & 0xff (layout 4 path) # hi7 = (w >> 0xd) & 0x7f (other layouts) # idx5 = (w >> 7) or (w >> 8) & 0x1f, used as wave index - flag = (pkt >> 3) & 1 + flag1 = (pkt >> 3) & 1 flag2 = (pkt >> 7) & 1 wave = (pkt >> 8) & 0x1F - hi8 = (pkt >> 13) + op = (pkt >> 13) fields.append(f"wave={wave:x}") - assert flag == 0 and flag2 == 0, "non 0 flags in 0x18" - #fields.append(f"flag={flag:x}") - #fields.append(f"flag2={flag2:x}") - # hi8 values: - # SALU = 0x0 / s_mov_b32 - # SMEM = 0x1 / s_load_b* - # NEXT = 0x4 / s_cbranch_execz - # MESSAGE = 0x9 / s_sendmsg - # VALU = 0xb / v_(exp,log)_f32_e32 - # VALU = 0xd / v_lshlrev_b64 - # VMEM = 0x21 / global_load_b32 - # VMEM = 0x22 / global_load_b32 - # VMEM = 0x24 / global_store_b32 - # VMEM = 0x25 / global_store_b64 - # LDS = 0x29 / ds_load_b128 - # LDS = 0x2b / ds_store_b32 - # ???? = 0x5a / hidden global_load instruction - # ???? = 0x5b / hidden global_load instruction - # ???? = 0x5c / hidden global_store instruction - # VALU = 0x73 / v_cmpx_eq_u32_e32 (not normal VALUINST) - fields.append(f"hi8=0x{hi8:x}") + fields.append(f"op=0x{op:02x} [{OPNAME.get(op, '')}]") + if flag1: fields.append("flag1") + if flag2: fields.append("flag2") case 0x14: subop = (pkt >> 16) & 0xFFFF # (short)(w >> 0x10) val32 = (pkt >> 32) & 0xFFFFFFFF # (uint)(w >> 0x20) @@ -364,21 +436,21 @@ def decode_packet_fields(opcode: int, reg: int) -> str: if layout == 4: fields.append(f"layout4_flag={flag4}") case _: - fields.append(f"& {reg_mask(opcode):X}") + fields.append(f"{pkt:X} & {reg_mask(opcode):X}") return ",".join(fields) FILTER_LEVEL = getenv("FILTER", 1) -DEFAULT_FILTER = tuple() +DEFAULT_FILTER: tuple[int, ...] 
= tuple() # NOP + pure time + "sample" if FILTER_LEVEL >= 0: DEFAULT_FILTER += (0x10, 0xf, 0x11) # reg + event + sample + marker # TODO: events are probably good if FILTER_LEVEL >= 1: DEFAULT_FILTER += (0x14, 0x12, 0x16) -# instruction runs -if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x02, 0x03) -# instructions dispatch (inst, valuinst, immed) -if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x01, 0x4, 0x5, 0x18) +# instruction runs + valuinst +if FILTER_LEVEL >= 2: DEFAULT_FILTER += (0x01, 0x02, 0x03) +# instructions dispatch (inst, immed) +if FILTER_LEVEL >= 3: DEFAULT_FILTER += (0x4, 0x5, 0x18) # waves if FILTER_LEVEL >= 4: DEFAULT_FILTER += (0x6, 0x8, 0x9) @@ -450,12 +522,7 @@ def parse_sqtt_print_packets(data: bytes, filter=DEFAULT_FILTER, verbose=True) - token_index += 1 if verbose and (filter is None or opcode not in filter): - print( - f"{time:8d}+{time-last_printed_time:8d} : " - f"op={opcode:2x} " - f"{OPCODE_NAMES[opcode]:24s} " - f"{reg&reg_mask(opcode):16X} " - f"{note}") + print(f"{time:8d} +{time-last_printed_time:8d} : "+colored(f"{OPCODE_NAMES[opcode]:18s} ", OPCODE_COLORS.get(opcode, "white"))+f"{note}") last_printed_time = time # Optional summary at the end diff --git a/test/test_tiny.py b/test/test_tiny.py index 74aae5676a..d83221a77a 100644 --- a/test/test_tiny.py +++ b/test/test_tiny.py @@ -35,9 +35,9 @@ class TestTiny(unittest.TestCase): out = Tensor.cat(Tensor.ones(8).contiguous(), Tensor.zeros(8).contiguous()) self.assertListEqual(out.tolist(), [1]*8+[0]*8) - def test_sum(self): - out = Tensor.ones(256).contiguous().sum() - self.assertEqual(out.item(), 256) + def test_sum(self, N=getenv("SUM_N", 256)): + out = Tensor.ones(N).contiguous().sum() + self.assertEqual(out.item(), N) def test_gemm(self, N=getenv("GEMM_N", 64), out_dtype=dtypes.float): a = Tensor.ones(N,N).contiguous()