mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
finish rdna4 sqtt (#14903)
* unskip * it's a wave pair in rdna4 * work * that * hidden archive * generic s_delay, mystery InstOpRDNA4.UNK_60 * branch failing test * UNK_60 is OTHER_VMEM_STORE * rdna4 has both s_delay_alu and s_wait_alu * real branch failing test * rdna4 doesn't have JUMP_NO, it's NEXT with a flag for no jump * make inst_delay skips recursive * all rdna4 tests pass * simm16 unwraps * that has a name
This commit is contained in:
@@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
||||
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
|
||||
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
|
||||
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.UNK_60}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_STORE}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
|
||||
@@ -72,7 +72,6 @@ class TestSQTTMapBase(unittest.TestCase):
|
||||
|
||||
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"
|
||||
|
||||
@unittest.skip("this doesn't work")
|
||||
class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200"
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -9,8 +9,7 @@ from dataclasses import dataclass
|
||||
from typing import Iterator
|
||||
from enum import Enum
|
||||
from tinygrad.renderer.amd.dsl import BitField, FixedBitField, Inst, bits
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, s_endpgm
|
||||
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
|
||||
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm # same encoding as RDNA4
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# FIELD ENUMS
|
||||
@@ -102,16 +101,16 @@ class InstOpRDNA4(Enum):
|
||||
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
|
||||
# TODO: we need to do discovery of all of these from instructions
|
||||
SALU = 0x0
|
||||
SMEM = 0x1
|
||||
UNK_02 = 0x2
|
||||
JUMP_NO = 0x4
|
||||
UNK_06 = 0x6
|
||||
JUMP = 0x1
|
||||
NEXT = 0x2
|
||||
MESSAGE = 0x4
|
||||
VALU_64 = 0x6
|
||||
VMEM = 0x10
|
||||
UNK_11 = 0x11
|
||||
VINTERP = 0x12
|
||||
UNK_14 = 0x14
|
||||
VMEM_128 = 0x11
|
||||
VMEM_STORE = 0x12
|
||||
VMEM_STORE_128 = 0x14
|
||||
OTHER_VMEM = 0x5e
|
||||
UNK_60 = 0x60
|
||||
OTHER_VMEM_STORE = 0x60
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PACKET TYPE BASE CLASS
|
||||
@@ -343,8 +342,12 @@ class INST_RDNA4(PacketType): # Layout 4: different delta position and InstOp e
|
||||
delta = bits[5:3]
|
||||
flag1 = bits[6:6]
|
||||
flag2 = bits[7:7]
|
||||
wave = bits[12:8]
|
||||
wave_pair = bits[11:8]
|
||||
flag3 = bits[12:12]
|
||||
op = bits[19:13].enum(InstOpRDNA4)
|
||||
# INST_RDNA4 wave_pair field (4 bits) addresses wave pairs, flag2 selects even/odd wave
|
||||
@property
|
||||
def wave(self): return self.wave_pair * 2 + self.flag2
|
||||
|
||||
class UTILCTR(PacketType):
|
||||
encoding = bits[6:0] == 0b0110001
|
||||
@@ -586,7 +589,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0
|
||||
for p in decode(data):
|
||||
if not simd_select(p): continue
|
||||
if isinstance(p, WAVESTART):
|
||||
if isinstance(p, (WAVESTART, WAVESTART_RDNA4)):
|
||||
assert p.wave not in wave_pc, "only one inflight wave per unit"
|
||||
wave_pc[p.wave] = next(iter(pc_map))
|
||||
continue
|
||||
@@ -595,33 +598,35 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I
|
||||
yield (p, InstructionInfo(pc, p.wave, s_endpgm()))
|
||||
continue
|
||||
# skip OTHER_ instructions, they don't belong to this unit
|
||||
if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue
|
||||
if isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): continue
|
||||
if isinstance(p, IMMEDIATE_MASK):
|
||||
# immediate mask may yield multiple times per packet
|
||||
for wave in range(16):
|
||||
if p.mask & (1 << wave):
|
||||
inst = pc_map[pc:=wave_pc[wave]]
|
||||
# can this assert be more strict?
|
||||
assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
|
||||
assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}"
|
||||
wave_pc[wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, wave, inst))
|
||||
continue
|
||||
if isinstance(p, (VALUINST, INST, IMMEDIATE)):
|
||||
if isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)):
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# s_delay_alu doesn't get a packet?
|
||||
if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}:
|
||||
while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}:
|
||||
wave_pc[p.wave] += inst.size()
|
||||
inst = pc_map[pc:=wave_pc[p.wave]]
|
||||
# identify a branch instruction, only used for asserts
|
||||
branch_inst = inst if isinstance(inst, SOPP) and "BRANCH" in inst.op_name else None
|
||||
if branch_inst is not None: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by JUMP, got {p}"
|
||||
branch_inst = inst if "BRANCH" in inst_op else None
|
||||
if branch_inst is not None:
|
||||
assert isinstance(p, (INST, INST_RDNA4)) and p.op.name in {"JUMP_NO", "JUMP", "NEXT"}, f"branch can only be folowed by JUMP, got {p}"
|
||||
# JUMP handling
|
||||
if isinstance(p, INST) and p.op is InstOp.JUMP:
|
||||
assert branch_inst is not None, f"JUMP packet must map to a branch instruction, got {inst}"
|
||||
x = branch_inst.simm16 & 0xffff
|
||||
if (isinstance(p, INST) and p.op is InstOp.JUMP) or (isinstance(p, INST_RDNA4) and branch_inst is not None and p.flag3):
|
||||
simm16 = getattr(branch_inst, 'simm16')
|
||||
assert branch_inst is not None and simm16 is not None, f"JUMP packet must map to a branch instruction, got {inst}"
|
||||
x = simm16 & 0xffff
|
||||
wave_pc[p.wave] += branch_inst.size() + (x - 0x10000 if x & 0x8000 else x)*4
|
||||
else:
|
||||
if branch_inst is not None: assert branch_inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}"
|
||||
if branch_inst is not None: assert inst_op != "S_BRANCH", f"S_BRANCH must have a JUMP packet, got {p}"
|
||||
wave_pc[p.wave] += inst.size()
|
||||
yield (p, InstructionInfo(pc, p.wave, inst))
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user