sqtt: rdna4 decoder work (#15434)

* sqtt: rdna4 decoder work

* diff cleanup

* more diff

* test

* work

* works

* TS_DELTA_SHORT
This commit is contained in:
qazal
2026-03-23 20:49:32 +02:00
committed by GitHub
parent 109472c37e
commit a590eded87
2 changed files with 45 additions and 14 deletions

View File

@@ -10,18 +10,11 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_INST)
print_packets, CDNA_WAVEEND, CDNA_INST)
from test.amd.helpers import TARGET_TO_ARCH
import tinygrad
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
# INST ops for non-traced SIMDs (excluded from instruction count)
OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LDS_STORE_64, InstOp.OTHER_LDS_STORE_128,
InstOp.OTHER_FLAT_LOAD, InstOp.OTHER_FLAT_STORE, InstOp.OTHER_FLAT_STORE_64, InstOp.OTHER_FLAT_STORE_96,
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5, InstOpRDNA4.OTHER_LDS_1, InstOpRDNA4.OTHER_LDS_2}
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER
@@ -207,8 +200,8 @@ class SQTTExamplesTestBase(unittest.TestCase):
our_insts: list[int] = []
for event in events:
for p in decode(event.blob):
if isinstance(p, INST) and p.op not in OTHER_SIMD_OPS: our_insts.append(p._time)
elif isinstance(p, INST_RDNA4) and p.op not in OTHER_SIMD_OPS_RDNA4: our_insts.append(p._time)
# INST ops for non-traced SIMDs (excluded from instruction count)
if isinstance(p, (INST, INST_RDNA4)) and not p.op.name.startswith("OTHER_"): our_insts.append(p._time)
elif isinstance(p, VALUINST): our_insts.append(p._time)
elif isinstance(p, IMMEDIATE): our_insts.append(p._time)
elif isinstance(p, IMMEDIATE_MASK):

View File

@@ -105,14 +105,25 @@ class InstOpRDNA4(Enum):
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
SALU = 0x0
SMEM = 0x1
SMEM_WR = 0x2
JUMP = 0x3
JUMP_NO = 0x4
CALL = 0x5
SALU_NO_EXEC = 0x7
MESSAGE = 0x9
VALU_1 = 0xa
VALU_TRANS = 0xb
VALU_B1 = 0xc
VALU_B2 = 0xd
VALU_B4 = 0xe
VALU_B16 = 0xf
VINTERP = 0x12
BARRIER_WAIT = 0x13
FLAT_RD_2 = 0x1c
FLAT_WR_3 = 0x1d
FLAT_WR_4 = 0x1e
FLAT_WR_5 = 0x1f
FLAT_WR_6 = 0x20
VMEM_RD_1 = 0x21
VMEM_RD_2 = 0x22
VMEM_WR_1 = 0x23
@@ -127,18 +138,45 @@ class InstOpRDNA4(Enum):
LDS_WR_3 = 0x2c
LDS_WR_4 = 0x2d
LDS_WR_5 = 0x2e
BUF_RD_1 = 0x2f
BUF_RD_2 = 0x30
BUF_WR_1 = 0x31
BUF_WR_2 = 0x32
BUF_WR_3 = 0x33
BUF_WR_4 = 0x34
BUF_WR_5 = 0x35
BUF_WR_6 = 0x36
OTHER_LDS_1 = 0x50
OTHER_LDS_2 = 0x51
OTHER_LDS_3 = 0x52
OTHER_LDS_4 = 0x53
OTHER_LDS_5 = 0x54
OTHER_FLAT_2 = 0x55
OTHER_FLAT_3 = 0x56
OTHER_FLAT_4 = 0x57
OTHER_FLAT_5 = 0x58
OTHER_FLAT_6 = 0x59
LDS_DIR_LOAD = 0x6e
LDS_PARAM_LOAD = 0x6f
SALU_WR_EXEC = 0x72
VALU1_WR_EXEC = 0x73
VALU_B2_WR_EXEC = 0x74
OTHER_LDS_6 = 0x77
OTHER_LDS_10 = 0x78
BARRIER_SIGNAL = 0x7a
DYN_VGPR = 0x87
BARRIER_JOIN = 0x8a
WMMA_8 = 0x8c
WMMA_16 = 0x8d
WMMA_32 = 0x8e
WMMA_64 = 0x8f
VALU_DPFP = 0x92
SALU_FLOAT3 = 0x98
VALU_SCL_TRANS = 0x99
SALU_2 = 0x9b
SALU_5 = 0x9c
OTHER_VMEM = 0xbd
OTHER_VMEM_5 = 0xc1
OTHER_VMEM = 0xbc # 0xbc-0xdd: vmem_other_simd
for _i in range(34): InstOpRDNA4._value2member_map_[0xbc + _i] = InstOpRDNA4.OTHER_VMEM
class InstOpCDNA(Enum):
SMEM_RD = 0
@@ -650,8 +688,8 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
yield (p, InstructionInfo(pc, wave, inst))
elif isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
inst = pc_map[pc:=wave_pc[p.wave]]
# s_delay_alu and s_wait_alu instructions are skipped
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
# s_delay_alu, s_wait_alu and s_barrier_wait instructions are skipped
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU", "S_BARRIER_WAIT"}:
wave_pc[p.wave] += inst.size()
inst = pc_map[pc:=wave_pc[p.wave]]
# assert branch always has a JUMP packet