finish rdna4 sqtt (#14903)

* unskip

* it's a wave pair in rdna4

* work

* that

* hidden archive

* generic s_delay, mystery InstOpRDNA4.UNK_60

* branch failing test

* UNK_60 is OTHER_VMEM_STORE

* rdna4 has both s_delay_alu and s_wait_alu

* real branch failing test

* rdna4 doesn't have JUMP_NO, it's NEXT with a flag for no jump

* make inst_delay skips recursive

* all rdna4 tests pass

* simm16 unwraps

* that has a name
This commit is contained in:
qazal
2026-02-20 15:06:13 +08:00
committed by GitHub
parent 52b51a0324
commit 16ae96fa58
3 changed files with 28 additions and 24 deletions

View File

@@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.UNK_60}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_STORE}
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER

View File

@@ -72,7 +72,6 @@ class TestSQTTMapBase(unittest.TestCase):
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
@unittest.skip("this doesn't work")
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
if __name__ == "__main__":

View File

@@ -9,8 +9,7 @@ from dataclasses import dataclass
from typing import Iterator
from enum import Enum
from tinygrad.renderer.amd.dsl import BitField, FixedBitField, Inst, bits
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, s_endpgm
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm # same encoding as RDNA4
# ═══════════════════════════════════════════════════════════════════════════════
# FIELD ENUMS
@@ -102,16 +101,16 @@ class InstOpRDNA4(Enum):
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
# TODO: we need to do discovery of all of these from instructions
SALU = 0x0
SMEM = 0x1
UNK_02 = 0x2
JUMP_NO = 0x4
UNK_06 = 0x6
JUMP = 0x1
NEXT = 0x2
MESSAGE = 0x4
VALU_64 = 0x6
VMEM = 0x10
UNK_11 = 0x11
VINTERP = 0x12
UNK_14 = 0x14
VMEM_128 = 0x11
VMEM_STORE = 0x12
VMEM_STORE_128 = 0x14
OTHER_VMEM = 0x5e
UNK_60 = 0x60
OTHER_VMEM_STORE = 0x60
# ═══════════════════════════════════════════════════════════════════════════════
# PACKET TYPE BASE CLASS
@@ -343,8 +342,12 @@ class INST_RDNA4(PacketType): # Layout 4: different delta position and InstOp e
delta = bits[5:3]
flag1 = bits[6:6]
flag2 = bits[7:7]
wave = bits[12:8]
wave_pair = bits[11:8]
flag3 = bits[12:12]
op = bits[19:13].enum(InstOpRDNA4)
# INST_RDNA4 wave_pair field (4 bits) addresses wave pairs, flag2 selects even/odd wave
@property
def wave(self): return self.wave_pair * 2 + self.flag2
class UTILCTR(PacketType):
encoding = bits[6:0] == 0b0110001
@@ -586,7 +589,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
for p in decode(data):
if not simd_select(p): continue
if isinstance(p, WAVESTART):
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
assert p.wave not in wave_pc, "only one inflight wave per unit"
wave_pc[p.wave] = next(iter(pc_map))
continue
@@ -595,33 +598,35 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
yield (p, InstructionInfo(pc, p.wave, s_endpgm()))
continue
# skip OTHER_ instructions, they don't belong to this unit
if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue
if isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): continue
if isinstance(p, IMMEDIATE_MASK):
# immediate mask may yield multiple times per packet
for wave in range(16):
if p.mask & (1 << wave):
inst = pc_map[pc:=wave_pc[wave]]
# can this assert be more strict?
assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
wave_pc[wave] += inst.size()
yield (p, InstructionInfo(pc, wave, inst))
continue
if isinstance(p, (VALUINST, INST, IMMEDIATE)):
if isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
inst = pc_map[pc:=wave_pc[p.wave]]
# s_delay_alu / s_wait_alu don't get a packet? skip past them to the next instruction
if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}:
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
wave_pc[p.wave] += inst.size()
inst = pc_map[pc:=wave_pc[p.wave]]
# identify a branch instruction, only used for asserts
branch_inst = inst if isinstance(inst, SOPP) and "BRANCH" in inst.op_name else None
if branch_inst is not None: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by JUMP, got {p}"
branch_inst = inst if "BRANCH" in inst_op else None
if branch_inst is not None:
assert isinstance(p, (INST, INST_RDNA4)) and p.op.name in {"JUMP_NO", "JUMP", "NEXT"}, f"branch can only be folowed by JUMP, got {p}"
# JUMP handling
if isinstance(p, INST) and p.op is InstOp.JUMP:
assert branch_inst is not None, f"JUMP packet must map to a branch instruction, got {inst}"
x = branch_inst.simm16 & 0xffff
if (isinstance(p, INST) and p.op is InstOp.JUMP) or (isinstance(p, INST_RDNA4) and branch_inst is not None and p.flag3):
simm16 = getattr(branch_inst, 'simm16')
assert branch_inst is not None and simm16 is not None, f"JUMP packet must map to a branch instruction, got {inst}"
x = simm16 & 0xffff
wave_pc[p.wave] += branch_inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
else:
if branch_inst is not None: assert branch_inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}"
if branch_inst is not None: assert inst_op != "S_BRANCH", f"S_BRANCH must have a JUMP packet, got {p}"
wave_pc[p.wave] += inst.size()
yield (p, InstructionInfo(pc, p.wave, inst))
continue