sqtt: update cdna packet names (#15243)

* sqtt: update cdna packet names

* change

* order
This commit is contained in:
qazal
2026-03-13 01:49:09 +02:00
committed by GitHub
parent 749162bd2f
commit d893b14193
2 changed files with 90 additions and 82 deletions

View File

@@ -10,7 +10,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND)
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_TIMESTAMP)
from test.amd.helpers import TARGET_TO_ARCH
import tinygrad
@@ -122,7 +122,8 @@ class SQTTExamplesTestBase(unittest.TestCase):
print(f"\n=== {name} event {i} ===")
print_packets(packets)
self.assertGreater(len(packets), 0, f"no packets decoded from {name} event {i}")
self.assertIsInstance(packets[0], LAYOUT_HEADER, f"first packet should be LAYOUT_HEADER in {name}")
first_pkt = CDNA_TIMESTAMP if self.target.startswith("gfx9") else LAYOUT_HEADER
self.assertIsInstance(packets[0], first_pkt, f"first packet should be {first_pkt.__name__} in {name}")
def test_packet_types_valid(self):
all_classes = set(PACKET_TYPES_RDNA3.values()) | set(PACKET_TYPES_RDNA4.values()) | set(PACKET_TYPES_CDNA.values())
@@ -224,7 +225,6 @@ class TestSQTTExamplesRDNA4(SQTTExamplesTestBase): target = "gfx1200"
class TestSQTTExamplesCDNA(SQTTExamplesTestBase):
target = "gfx950"
def test_decode_all_examples(self): self.skipTest("TODO: correct deltas in the timestamp packet types, first packet is REGCS_CDNA")
def test_gemm_has_instructions(self): self.skipTest("TODO: decode CDNA inst packets")
def test_rocprof_wave_times_match(self): self.skipTest("TODO: requires timestamp patching")

View File

@@ -381,25 +381,27 @@ PACKET_TYPES_RDNA4: dict[int, type[PacketType]] = {
# CDNA PACKET TYPE DEFINITIONS
# ═══════════════════════════════════════════════════════════════════════════════
class CDNA_DELTA(PacketType):
"""pkt_fmt=0: 16-bit timestamp delta packet"""
class CDNA_MISC(PacketType):
"""pkt_fmt=0: 16-bit (Misc)"""
encoding = bits[3:0] == 0
delta = bits[11:4] # (data >> 4) & 0xff
unk_0 = bits[12:12] # (data >> 0xc) & 1
unk_1 = bits[15:13] # (data >> 0xd)
delta = bits[11:4]
sh = bits[12:12]
misc_type = bits[15:13]
class CDNA_TIMESTAMP(PacketType):
"""pkt_fmt=1: 64-bit timestamp packet (case 0x0)"""
encoding = bits[3:0] == 1
unk_0 = bits[15:4]
_reserved = bits[15:4]
timestamp = bits[63:16] # stored as (data_word >> 0x10) in low 46 bits of local_58
class CDNA_PKT_2(PacketType):
"""pkt_fmt=2: 64-bit packet (case 0x4)"""
class CDNA_REG(PacketType):
"""pkt_fmt=2: 64-bit (Reg)"""
encoding = bits[3:0] == 2
unk_0 = bits[6:5] # (data >> 5) & 3
unk_1 = bits[7:7] # (data >> 7) + 1 & 1
unk_padding = bits[63:8]
pipe = bits[6:5]
_me_raw = bits[8:7]
_reserved = bits[15:9]
regaddr = bits[31:16]
regdata = bits[63:32]
class CDNA_WAVESTART(PacketType):
"""type 3: 32-bit wave start (Wave/group_id)"""
@@ -410,19 +412,19 @@ class CDNA_WAVESTART(PacketType):
simd = bits[15:14]
pipe = bits[17:16]
me = bits[19:18]
_gap = bits[21:20]
_reserved = bits[21:20]
count = bits[28:22]
_padding = bits[31:29]
class CDNA_PKT_4(PacketType):
"""pkt_fmt=4: 16-bit packet (case 0xc, same as 0x8/0x14)"""
class CDNA_WAVEALLOC(PacketType):
"""pkt_fmt=4: 16-bit (Wave)"""
encoding = bits[3:0] == 4
unk_0 = bits[5:5] # (data_word >> 5) & 1
unk_1 = bits[9:6] # (data_word >> 6) & 0xf
unk_2 = bits[13:10] # (data_word >> 10) & 0xf
unk_3 = bits[15:14] # (data_word >> 0xe)
sh = bits[5:5]
cu = bits[9:6]
wave = bits[13:10]
simd = bits[15:14]
class REGCS_CDNA(PacketType):
class CDNA_REG_CS(PacketType):
"""type 5: 48-bit register CS write (RegCs)"""
encoding = bits[3:0] == 5
pipe = bits[6:5]
@@ -438,80 +440,86 @@ class CDNA_WAVEEND(PacketType):
wave = bits[13:10]
simd = bits[15:14]
class CDNA_EXEC(PacketType):
"""pkt_fmt=10: 16-bit EXEC packet (case 0x24)"""
encoding = bits[3:0] == 10
unk_0 = bits[8:5] # (data_word >> 5) & 0xf
unk_1 = bits[10:9] # (data_word >> 9) & 3
unk_2 = bits[15:11] # (data_word >> 0xb)
class CDNA_PKT_11(PacketType):
"""pkt_fmt=11: 64-bit packet (case 0x28)"""
encoding = bits[3:0] == 11
unk_0 = bits[8:5] # (data_word >> 5) & 0xf
unk_1 = bits[10:9] # (data_word >> 9) & 3
unk_2 = bits[15:15] # (data_word >> 0xf) & 1
unk_padding = bits[63:16]
class CDNA_INST(PacketType):
"""pkt_fmt=13: 32-bit INST packet (case 0x30)"""
"""pkt_fmt=10: 16-bit (MsgInst)"""
encoding = bits[3:0] == 10
wave = bits[8:5]
simd = bits[10:9]
inst_type = bits[15:11]
class CDNA_INST_PC(PacketType):
"""pkt_fmt=11: 64-bit (MsgInstPc)"""
encoding = bits[3:0] == 11
wave = bits[8:5]
simd = bits[10:9]
_reserved = bits[14:11]
err = bits[15:15]
pc = bits[63:16]
class CDNA_ISSUE(PacketType):
"""pkt_fmt=13: 32-bit (Issue)"""
encoding = bits[3:0] == 13
unk_0 = bits[6:5] # (data >> 5) & 3
unk_1 = bits[9:8] # (data >> 8) & 3
unk_2 = bits[11:10] # (data >> 10) & 3
unk_3 = bits[13:12] # (data >> 0xc) & 3
unk_4 = bits[15:14] # (data >> 0xe) & 3
unk_5 = bits[19:18] # (data >> 0x12) & 3
unk_6 = bits[21:20] # (data >> 0x14) & 3
unk_7 = bits[23:22] # (data >> 0x16) & 3
unk_8 = bits[25:24] # (data >> 0x18) & 3
unk_9 = bits[27:26] # (data >> 0x1a) & 3
unk_padding = bits[31:28]
simd = bits[6:5]
_gap = bits[7:7]
inst0 = bits[9:8]
inst1 = bits[11:10]
inst2 = bits[13:12]
inst3 = bits[15:14]
inst4 = bits[17:16]
inst5 = bits[19:18]
inst6 = bits[21:20]
inst7 = bits[23:22]
inst8 = bits[25:24]
inst9 = bits[27:26]
_padding = bits[31:28]
class CDNA_PKT_14(PacketType):
"""pkt_fmt=14: 64-bit packet (case 0x34)"""
class CDNA_PERF(PacketType):
"""pkt_fmt=14: 64-bit (MsgPerf)"""
encoding = bits[3:0] == 14
unk_0 = bits[5:5] # (data >> 5) & 1
unk_1 = bits[9:6] # (data >> 6) & 0xf
unk_2 = bits[11:10] # (data >> 10) & 3
unk_3 = bits[24:12] # (data >> 0xc) & 0x1fff
unk_4 = bits[37:25] # (data >> 0x19) & 0x1fff
unk_5 = bits[50:38] # (data >> 0x26) & 0x1fff
unk_6 = bits[51:51] # (data >> 0x33) & 1
unk_padding = bits[63:52]
sh = bits[5:5]
cu = bits[9:6]
cntr_bank = bits[11:10]
cntr0 = bits[24:12]
cntr1 = bits[37:25]
cntr2 = bits[50:38]
cntr3 = bits[63:51]
class CDNA_PKT_7(PacketType):
"""pkt_fmt=7: 16-bit packet"""
class CDNA_EVENT(PacketType):
"""pkt_fmt=7: 16-bit"""
encoding = bits[3:0] == 7
unk_padding = bits[15:4]
_reserved = bits[15:4]
class CDNA_PKT_8(PacketType):
"""pkt_fmt=8: 16-bit packet"""
class CDNA_EVENT_CS(PacketType):
"""pkt_fmt=8: 16-bit"""
encoding = bits[3:0] == 8
unk_padding = bits[15:4]
_reserved = bits[15:4]
class CDNA_PKT_9(PacketType):
"""pkt_fmt=9: 16-bit packet"""
class CDNA_EVENT_GFX1(PacketType):
"""pkt_fmt=9: 16-bit"""
encoding = bits[3:0] == 9
unk_padding = bits[15:4]
_reserved = bits[15:4]
class CDNA_PKT_12(PacketType):
"""pkt_fmt=12: 48-bit packet"""
class CDNA_USERDATA(PacketType):
"""pkt_fmt=12: 48-bit (UserData)"""
encoding = bits[3:0] == 12
unk_padding = bits[47:4]
sh = bits[5:5]
cu = bits[9:6]
wave = bits[13:10]
simd = bits[15:14]
data = bits[47:16]
class CDNA_PKT_15(PacketType):
"""pkt_fmt=15: 48-bit packet (case 0x38, same as 0x10)"""
class CDNA_REG_CS_PRIV(PacketType):
"""pkt_fmt=15: 48-bit (RegCs)"""
encoding = bits[3:0] == 15
unk_0 = bits[6:5] # (data >> 5) & 3
unk_1 = bits[7:7] # (data >> 7) + 1 & 1
unk_2 = bits[15:9] # (data >> 9) & 0x7f
unk_padding = bits[47:16]
pipe = bits[6:5]
_me_raw = bits[8:7]
regaddr = bits[15:9]
regdata = bits[47:16]
PACKET_TYPES_CDNA: dict[int, type[PacketType]] = {
0: CDNA_DELTA, 1: CDNA_TIMESTAMP, 2: CDNA_PKT_2, 3: CDNA_WAVESTART, 4: CDNA_PKT_4, 5: REGCS_CDNA, 6: CDNA_WAVEEND,
7: CDNA_PKT_7, 8: CDNA_PKT_8, 9: CDNA_PKT_9, 10: CDNA_EXEC, 11: CDNA_PKT_11, 12: CDNA_PKT_12,
13: CDNA_INST, 14: CDNA_PKT_14, 15: CDNA_PKT_15,
0: CDNA_MISC, 1: CDNA_TIMESTAMP, 2: CDNA_REG, 3: CDNA_WAVESTART, 4: CDNA_WAVEALLOC, 5: CDNA_REG_CS, 6: CDNA_WAVEEND,
7: CDNA_EVENT, 8: CDNA_EVENT_CS, 9: CDNA_EVENT_GFX1, 10: CDNA_INST, 11: CDNA_INST_PC, 12: CDNA_USERDATA,
13: CDNA_ISSUE, 14: CDNA_PERF, 15: CDNA_REG_CS_PRIV,
}
# ═══════════════════════════════════════════════════════════════════════════════
@@ -523,8 +531,8 @@ def _build_decode_tables(packet_types: dict[int, type[PacketType]]) -> tuple[dic
sorted_types = sorted(packet_types.items(), key=lambda x: (-bin(x[1].encoding.mask).count('1'), x[0] == 16))
state_table = bytes(next((op for op, cls in sorted_types if (b & cls.encoding.mask) == cls.encoding.default), 16) for b in range(256))
# Build decode info: opcode -> (pkt_cls, nib_count, delta_lo, delta_mask, special_case)
# special_case: 0=none, 1=TS_DELTA_OR_MARK (check is_marker), 2=TS_DELTA_SHORT (add 8), 3=CDNA_DELTA (*4), 4=CDNA_TIMESTAMP (absolute)
_special = {TS_DELTA_OR_MARK: 1, TS_DELTA_OR_MARK_RDNA4: 1, TS_DELTA_SHORT: 2, CDNA_DELTA: 3, CDNA_TIMESTAMP: 4}
# special_case: 0=none, 1=TS_DELTA_OR_MARK (check is_marker), 2=TS_DELTA_SHORT (add 8), 3=CDNA_MISC (*4), 4=CDNA_TIMESTAMP (absolute)
_special = {TS_DELTA_OR_MARK: 1, TS_DELTA_OR_MARK_RDNA4: 1, TS_DELTA_SHORT: 2, CDNA_MISC: 3, CDNA_TIMESTAMP: 4}
decode_info = {}
for opcode, pkt_cls in packet_types.items():
delta_field = getattr(pkt_cls, 'delta', None)