diff --git a/test/amd/test_sqtt_examples.py b/test/amd/test_sqtt_examples.py index 58ad0926ed..1f2634be8d 100644 --- a/test/amd/test_sqtt_examples.py +++ b/test/amd/test_sqtt_examples.py @@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR, InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128, InstOp.OTHER_GLOBAL_STORE_VADDR_128} -OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.UNK_60} +OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_STORE} # ═══════════════════════════════════════════════════════════════════════════════ # ROCPROF DECODER diff --git a/test/amd/test_sqttmap.py b/test/amd/test_sqttmap.py index b7c760fd55..e9c5d42753 100644 --- a/test/amd/test_sqttmap.py +++ b/test/amd/test_sqttmap.py @@ -72,7 +72,6 @@ class TestSQTTMapBase(unittest.TestCase): class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100" -@unittest.skip("this doesn't work") class TestSQTTMapRDNA4(TestSQTTMapBase): target = "gfx1200" if __name__ == "__main__": diff --git a/tinygrad/renderer/amd/sqtt.py b/tinygrad/renderer/amd/sqtt.py index 99f1b0be86..474a67a959 100644 --- a/tinygrad/renderer/amd/sqtt.py +++ b/tinygrad/renderer/amd/sqtt.py @@ -9,8 +9,7 @@ from dataclasses import dataclass from typing import Iterator from enum import Enum from tinygrad.renderer.amd.dsl import BitField, FixedBitField, Inst, bits -from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP, s_endpgm -from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp +from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm # same encoding as RDNA4 # ═══════════════════════════════════════════════════════════════════════════════ # FIELD ENUMS @@ -102,16 +101,16 @@ class InstOpRDNA4(Enum): """SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3.""" # TODO: we need to do discovery of all of these from instructions SALU = 0x0 - SMEM = 0x1 - UNK_02 = 0x2 - JUMP_NO = 0x4 - UNK_06 = 0x6 + JUMP = 0x1 + NEXT = 0x2 + MESSAGE = 0x4 + VALU_64 = 0x6 VMEM = 0x10 - UNK_11 = 0x11 - VINTERP = 0x12 - UNK_14 = 0x14 + VMEM_128 = 0x11 + VMEM_STORE = 0x12 + VMEM_STORE_128 = 0x14 OTHER_VMEM = 0x5e - UNK_60 = 0x60 + OTHER_VMEM_STORE = 0x60 # ═══════════════════════════════════════════════════════════════════════════════ # PACKET TYPE BASE CLASS @@ -343,8 +342,12 @@ class INST_RDNA4(PacketType): # Layout 4: different delta position and InstOp e delta = bits[5:3] flag1 = bits[6:6] flag2 = bits[7:7] - wave = bits[12:8] + wave_pair = bits[11:8] + flag3 = bits[12:12] op = bits[19:13].enum(InstOpRDNA4) + # INST_RDNA4 wave_pair field (4 bits) addresses wave pairs, flag2 selects even/odd wave + @property + def wave(self): return self.wave_pair * 2 + self.flag2 class UTILCTR(PacketType): encoding = bits[6:0] == 0b0110001 @@ -586,7 +589,7 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I def simd_select(p) -> bool: return getattr(p, "cu", 0) == 0 and getattr(p, "simd", 0) == 0 for p in decode(data): if not simd_select(p): continue - if isinstance(p, WAVESTART): + if isinstance(p, (WAVESTART, WAVESTART_RDNA4)): assert p.wave not in wave_pc, "only one inflight wave per unit" wave_pc[p.wave] = next(iter(pc_map)) continue @@ -595,33 +598,35 @@ def map_insts(data:bytes, lib:bytes, target:str) -> Iterator[tuple[PacketType, I yield (p, InstructionInfo(pc, p.wave, s_endpgm())) continue # skip OTHER_ instructions, they don't belong to this unit - if isinstance(p, INST) and p.op.name.startswith("OTHER_"): continue + if isinstance(p, (INST, INST_RDNA4)) and p.op.name.startswith("OTHER_"): continue if isinstance(p, IMMEDIATE_MASK): # immediate mask may yield multiple times per packet for wave in range(16): if p.mask & (1 << wave): inst = pc_map[pc:=wave_pc[wave]] # can this assert be more strict? - assert isinstance(inst, SOPP), f"IMMEDIATE_MASK packet must map to SOPP, got {inst}" + assert type(inst).__name__ == "SOPP", f"IMMEDIATE_MASK packet must map to SOPP, got {inst}" wave_pc[wave] += inst.size() yield (p, InstructionInfo(pc, wave, inst)) continue - if isinstance(p, (VALUINST, INST, IMMEDIATE)): + if isinstance(p, (VALUINST, INST, INST_RDNA4, IMMEDIATE)): inst = pc_map[pc:=wave_pc[p.wave]] # s_delay_alu doesn't get a packet? - if isinstance(inst, SOPP) and inst.op in {SOPPOp.S_DELAY_ALU}: + while (inst_op:=getattr(inst, 'op_name', '')) in {"S_DELAY_ALU", "S_WAIT_ALU"}: wave_pc[p.wave] += inst.size() inst = pc_map[pc:=wave_pc[p.wave]] # identify a branch instruction, only used for asserts - branch_inst = inst if isinstance(inst, SOPP) and "BRANCH" in inst.op_name else None - if branch_inst is not None: assert isinstance(p, INST) and p.op in {InstOp.JUMP_NO, InstOp.JUMP}, f"branch can only be folowed by JUMP, got {p}" + branch_inst = inst if "BRANCH" in inst_op else None + if branch_inst is not None: + assert isinstance(p, (INST, INST_RDNA4)) and p.op.name in {"JUMP_NO", "JUMP", "NEXT"}, f"branch can only be folowed by JUMP, got {p}" # JUMP handling - if isinstance(p, INST) and p.op is InstOp.JUMP: - assert branch_inst is not None, f"JUMP packet must map to a branch instruction, got {inst}" - x = branch_inst.simm16 & 0xffff + if (isinstance(p, INST) and p.op is InstOp.JUMP) or (isinstance(p, INST_RDNA4) and branch_inst is not None and p.flag3): + simm16 = getattr(branch_inst, 'simm16') + assert branch_inst is not None and simm16 is not None, f"JUMP packet must map to a branch instruction, got {inst}" + x = simm16 & 0xffff wave_pc[p.wave] += branch_inst.size() + (x - 0x10000 if x & 0x8000 else x)*4 else: - if branch_inst is not None: assert branch_inst.op != SOPPOp.S_BRANCH, f"S_BRANCH must have a JUMP packet, got {p}" + if branch_inst is not None: assert inst_op != "S_BRANCH", f"S_BRANCH must have a JUMP packet, got {p}" wave_pc[p.wave] += inst.size() yield (p, InstructionInfo(pc, p.wave, inst)) continue