diff --git a/test/amd/test_sqtt_examples.py b/test/amd/test_sqtt_examples.py index b48ffea783..dfbf36c29d 100644 --- a/test/amd/test_sqtt_examples.py +++ b/test/amd/test_sqtt_examples.py @@ -10,7 +10,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST, IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART, - InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND) + InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_TIMESTAMP) from test.amd.helpers import TARGET_TO_ARCH import tinygrad @@ -122,7 +122,8 @@ class SQTTExamplesTestBase(unittest.TestCase): print(f"\n=== {name} event {i} ===") print_packets(packets) self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}") - self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}") + first_pkt = CDNA_TIMESTAMP if self.target.startswith("gfx9") else LAYOUT_HEADER + self.assertIsInstance(packets[0], first_pkt, f"first packet should be {first_pkt.__name__} in {name}") def test_packet_types_valid(self): all_classes = set(PACKET_TYPES_RDNA3.values()) | set(PACKET_TYPES_RDNA4.values()) | set(PACKET_TYPES_CDNA.values()) @@ -224,7 +225,6 @@ class TestSQTTExamplesRDNA4(SQTTExamplesTestBase): target = "gfx1200" class TestSQTTExamplesCDNA(SQTTExamplesTestBase): target = "gfx950" - def test_decode_all_examples(self): self.skipTest("TODO: correct deltas in the timestamp packet types, first packet is REGCS_CDNA") def test_gemm_has_instructions(self): self.skipTest("TODO: decode CDNA inst packets") def test_rocprof_wave_times_match(self): self.skipTest("TODO: requires timestamp patching") diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index ff251f9fd3..f2eca5d6e3 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -381,25 +381,27 @@ PACKET_TYPES_RDNA4: dict[int, type[PacketType]] = { # CDNA PACKET TYPE DEFINITIONS # ═══════════════════════════════════════════════════════════════════════════════ -class CDNA_DELTA(PacketType): - """pkt_fmt=0: 16-bit timestamp delta packet""" +class CDNA_MISC(PacketType): + """pkt_fmt=0: 16-bit (Misc)""" encoding = bits[3:0] == 0 - delta = bits[11:4] # (data >> 4) & 0xff - unk_0 = bits[12:12] # (data >> 0xc) & 1 - unk_1 = bits[15:13] # (data >> 0xd) + delta = bits[11:4] + sh = bits[12:12] + misc_type = bits[15:13] class CDNA_TIMESTAMP(PacketType): """pkt_fmt=1: 64-bit timestamp packet (case 0x0)""" encoding = bits[3:0] == 1 - unk_0 = bits[15:4] + _reserved = bits[15:4] timestamp = bits[63:16] # stored as (data_word >> 0x10) in low 46 bits of local_58 -class CDNA_PKT_2(PacketType): - """pkt_fmt=2: 64-bit packet (case 0x4)""" +class CDNA_REG(PacketType): + """pkt_fmt=2: 64-bit (Reg)""" encoding = bits[3:0] == 2 - unk_0 = bits[6:5] # (data >> 5) & 3 - unk_1 = bits[7:7] # (data >> 7) + 1 & 1 - unk_padding = bits[63:8] + pipe = bits[6:5] + _me_raw = bits[8:7] + _reserved = bits[15:9] + regaddr = bits[31:16] + regdata = bits[63:32] class CDNA_WAVESTART(PacketType): """type 3: 32-bit wave start (Wave/group_id)""" @@ -410,19 +412,19 @@ class CDNA_WAVESTART(PacketType): simd = bits[15:14] pipe = bits[17:16] me = bits[19:18] - _gap = bits[21:20] + _reserved = bits[21:20] count = bits[28:22] _padding = bits[31:29] -class CDNA_PKT_4(PacketType): - """pkt_fmt=4: 16-bit packet (case 0xc, same as 0x8/0x14)""" +class CDNA_WAVEALLOC(PacketType): + """pkt_fmt=4: 16-bit (Wave)""" encoding = bits[3:0] == 4 - unk_0 = bits[5:5] # (data_word >> 5) & 1 - unk_1 = bits[9:6] # (data_word >> 6) & 0xf - unk_2 = bits[13:10] # (data_word >> 10) & 0xf - unk_3 = bits[15:14] # (data_word >> 0xe) + sh = bits[5:5] + cu = bits[9:6] + wave = bits[13:10] + simd = bits[15:14] -class REGCS_CDNA(PacketType): +class CDNA_REG_CS(PacketType): """type 5: 48-bit register CS write (RegCs)""" encoding = bits[3:0] == 5 pipe = bits[6:5] @@ -438,80 +440,86 @@ class CDNA_WAVEEND(PacketType): wave = bits[13:10] simd = bits[15:14] -class CDNA_EXEC(PacketType): - """pkt_fmt=10: 16-bit EXEC packet (case 0x24)""" - encoding = bits[3:0] == 10 - unk_0 = bits[8:5] # (data_word >> 5) & 0xf - unk_1 = bits[10:9] # (data_word >> 9) & 3 - unk_2 = bits[15:11] # (data_word >> 0xb) - -class CDNA_PKT_11(PacketType): - """pkt_fmt=11: 64-bit packet (case 0x28)""" - encoding = bits[3:0] == 11 - unk_0 = bits[8:5] # (data_word >> 5) & 0xf - unk_1 = bits[10:9] # (data_word >> 9) & 3 - unk_2 = bits[15:15] # (data_word >> 0xf) & 1 - unk_padding = bits[63:16] - class CDNA_INST(PacketType): - """pkt_fmt=13: 32-bit INST packet (case 0x30)""" + """pkt_fmt=10: 16-bit (MsgInst)""" + encoding = bits[3:0] == 10 + wave = bits[8:5] + simd = bits[10:9] + inst_type = bits[15:11] + +class CDNA_INST_PC(PacketType): + """pkt_fmt=11: 64-bit (MsgInstPc)""" + encoding = bits[3:0] == 11 + wave = bits[8:5] + simd = bits[10:9] + _reserved = bits[14:11] + err = bits[15:15] + pc = bits[63:16] + +class CDNA_ISSUE(PacketType): + """pkt_fmt=13: 32-bit (Issue)""" encoding = bits[3:0] == 13 - unk_0 = bits[6:5] # (data >> 5) & 3 - unk_1 = bits[9:8] # (data >> 8) & 3 - unk_2 = bits[11:10] # (data >> 10) & 3 - unk_3 = bits[13:12] # (data >> 0xc) & 3 - unk_4 = bits[15:14] # (data >> 0xe) & 3 - unk_5 = bits[19:18] # (data >> 0x12) & 3 - unk_6 = bits[21:20] # (data >> 0x14) & 3 - unk_7 = bits[23:22] # (data >> 0x16) & 3 - unk_8 = bits[25:24] # (data >> 0x18) & 3 - unk_9 = bits[27:26] # (data >> 0x1a) & 3 - unk_padding = bits[31:28] + simd = bits[6:5] + _gap = bits[7:7] + inst0 = bits[9:8] + inst1 = bits[11:10] + inst2 = bits[13:12] + inst3 = bits[15:14] + inst4 = bits[17:16] + inst5 = bits[19:18] + inst6 = bits[21:20] + inst7 = bits[23:22] + inst8 = bits[25:24] + inst9 = bits[27:26] + _padding = bits[31:28] -class CDNA_PKT_14(PacketType): - """pkt_fmt=14: 64-bit packet (case 0x34)""" +class CDNA_PERF(PacketType): + """pkt_fmt=14: 64-bit (MsgPerf)""" encoding = bits[3:0] == 14 - unk_0 = bits[5:5] # (data >> 5) & 1 - unk_1 = bits[9:6] # (data >> 6) & 0xf - unk_2 = bits[11:10] # (data >> 10) & 3 - unk_3 = bits[24:12] # (data >> 0xc) & 0x1fff - unk_4 = bits[37:25] # (data >> 0x19) & 0x1fff - unk_5 = bits[50:38] # (data >> 0x26) & 0x1fff - unk_6 = bits[51:51] # (data >> 0x33) & 1 - unk_padding = bits[63:52] + sh = bits[5:5] + cu = bits[9:6] + cntr_bank = bits[11:10] + cntr0 = bits[24:12] + cntr1 = bits[37:25] + cntr2 = bits[50:38] + cntr3 = bits[63:51] -class CDNA_PKT_7(PacketType): - """pkt_fmt=7: 16-bit packet""" +class CDNA_EVENT(PacketType): + """pkt_fmt=7: 16-bit""" encoding = bits[3:0] == 7 - unk_padding = bits[15:4] + _reserved = bits[15:4] -class CDNA_PKT_8(PacketType): - """pkt_fmt=8: 16-bit packet""" +class CDNA_EVENT_CS(PacketType): + """pkt_fmt=8: 16-bit""" encoding = bits[3:0] == 8 - unk_padding = bits[15:4] + _reserved = bits[15:4] -class CDNA_PKT_9(PacketType): - """pkt_fmt=9: 16-bit packet""" +class CDNA_EVENT_GFX1(PacketType): + """pkt_fmt=9: 16-bit""" encoding = bits[3:0] == 9 - unk_padding = bits[15:4] + _reserved = bits[15:4] -class CDNA_PKT_12(PacketType): - """pkt_fmt=12: 48-bit packet""" +class CDNA_USERDATA(PacketType): + """pkt_fmt=12: 48-bit (UserData)""" encoding = bits[3:0] == 12 - unk_padding = bits[47:4] + sh = bits[5:5] + cu = bits[9:6] + wave = bits[13:10] + simd = bits[15:14] + data = bits[47:16] -class CDNA_PKT_15(PacketType): - """pkt_fmt=15: 48-bit packet (case 0x38, same as 0x10)""" +class CDNA_REG_CS_PRIV(PacketType): + """pkt_fmt=15: 48-bit (RegCs)""" encoding = bits[3:0] == 15 - unk_0 = bits[6:5] # (data >> 5) & 3 - unk_1 = bits[7:7] # (data >> 7) + 1 & 1 - unk_2 = bits[15:9] # (data >> 9) & 0x7f - unk_padding = bits[47:16] + pipe = bits[6:5] + _me_raw = bits[8:7] + regaddr = bits[15:9] + regdata = bits[47:16] PACKET_TYPES_CDNA: dict[int, type[PacketType]] = { - 0: CDNA_DELTA, 1: CDNA_TIMESTAMP, 2: CDNA_PKT_2, 3: CDNA_WAVESTART, 4: CDNA_PKT_4, 5: REGCS_CDNA, 6: CDNA_WAVEEND, - 7: CDNA_PKT_7, 8: CDNA_PKT_8, 9: CDNA_PKT_9, 10: CDNA_EXEC, 11: CDNA_PKT_11, 12: CDNA_PKT_12, - 13: CDNA_INST, 14: CDNA_PKT_14, 15: CDNA_PKT_15, + 0: CDNA_MISC, 1: CDNA_TIMESTAMP, 2: CDNA_REG, 3: CDNA_WAVESTART, 4: CDNA_WAVEALLOC, 5: CDNA_REG_CS, 6: CDNA_WAVEEND, + 7: CDNA_EVENT, 8: CDNA_EVENT_CS, 9: CDNA_EVENT_GFX1, 10: CDNA_INST, 11: CDNA_INST_PC, 12: CDNA_USERDATA, + 13: CDNA_ISSUE, 14: CDNA_PERF, 15: CDNA_REG_CS_PRIV, } # ═══════════════════════════════════════════════════════════════════════════════ @@ -523,8 +531,8 @@ def _build_decode_tables(packet_types: dict[int, type[PacketType]]) -> tuple[dic sorted_types = sorted(packet_types.items(), key=lambda x: (-bin(x[1].encoding.mask).count('1'), x[0] == 16)) state_table = bytes(next((op for op, cls in sorted_types if (b & cls.encoding.mask) == cls.encoding.default), 16) for b in range(256)) # Build decode info: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case) - # special_case: 0=none, 1=TS_DELTA_OR_MARK (check is_marker), 2=TS_DELTA_SHORT (add 8), 3=CDNA_DELTA (*4), 4=CDNA_TIMESTAMP (absolute) - _special = {TS_DELTA_OR_MARK: 1, TS_DELTA_OR_MARK_RDNA4: 1, TS_DELTA_SHORT: 2, CDNA_DELTA: 3, CDNA_TIMESTAMP: 4} + # special_case: 0=none, 1=TS_DELTA_OR_MARK (check is_marker), 2=TS_DELTA_SHORT (add 8), 3=CDNA_MISC (*4), 4=CDNA_TIMESTAMP (absolute) + _special = {TS_DELTA_OR_MARK: 1, TS_DELTA_OR_MARK_RDNA4: 1, TS_DELTA_SHORT: 2, CDNA_MISC: 3, CDNA_TIMESTAMP: 4} decode_info = {} for opcode, pkt_cls in packet_types.items(): delta_field = getattr(pkt_cls, 'delta', None)