sqtt: CDNA inst decodes (#15274)

* sqtt: CDNA inst decodes

* JUMP packets other way

* cdna insts

* r3

* r4

* lds from simd1 and simd2
This commit is contained in:
qazal
2026-03-14 14:03:46 +02:00
committed by GitHub
parent d753c5d7e5
commit 3858bfc83d
27 changed files with 18 additions and 15 deletions

View File

@@ -8,7 +8,7 @@ PROFILE_PATH = Path(temp("profile.pkl", append_user=True))
# Named example workloads: each value is a command-line fragment (a test id, a `-c` snippet, or a script path).
# NOTE(review): how these strings are launched is not visible here — confirm against the caller.
# NOTE(review): "gemm" is listed twice (N:=64, then N:=32); in a dict literal the later entry replaces the earlier one.
# This view looks like a rendered diff with a removed and an added line side by side — confirm which one is current.
EXAMPLES = {
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
"plus":"test/test_tiny.py TestTiny.test_plus",
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=64, N)@Tensor.empty(N, N)).realize()\"",
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
"ops":"extra/sqtt/examples/discover_ops.py"
}

View File

@@ -10,7 +10,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND)
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_INST)
from test.amd.helpers import TARGET_TO_ARCH
import tinygrad
@@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5}
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5, InstOpRDNA4.OTHER_LDS_1, InstOpRDNA4.OTHER_LDS_2}
# ═══════════════════════════════════════════════════════════════════════════════
# ROCPROF DECODER
@@ -153,9 +153,10 @@ class SQTTExamplesTestBase(unittest.TestCase):
if "gemm" not in name: continue
with self.subTest(example=name):
all_packets = [p for e in events for p in decode(e.blob)]
inst_names = [p.op.name for p in all_packets if isinstance(p, (INST, INST_RDNA4))]
self.assertGreater(len(inst_names), 0, f"no INST packets in {name}")
self.assertGreater(len([n for n in inst_names if n.startswith("JUMP")]), 0, f"no JUMP packets in {name}")
inst_packets = [p for p in all_packets if isinstance(p, (INST, INST_RDNA4, CDNA_INST))]
self.assertGreater(len(inst_packets), 0, f"no INST packets in {name}")
if isinstance(inst_packets[0], (INST, INST_RDNA4)):
self.assertGreater(len([p for p in inst_packets if p.op.name.startswith("JUMP")]), 0, f"no JUMP packets in {name}")
expected: dict[str, list[int]] = {} # override in subclasses
def test_packet_counts(self):
@@ -210,22 +211,22 @@ class SQTTExamplesTestBase(unittest.TestCase):
# SQTT example tests for the "gfx1100" target.
# `expected` maps a profile run name to a list of six integers.
# NOTE(review): presumably these are per-run packet counts checked by test_packet_counts — confirm against the base class.
# NOTE(review): every key below appears twice; in a dict literal the later entry replaces the earlier one, so only the
# second group of eight takes effect. This looks like removed and added diff lines shown together — confirm which set
# is current before relying on either.
class TestSQTTExamplesRDNA3(SQTTExamplesTestBase):
target = "gfx1100"
expected = {
"profile_empty_run_0": [1974, 1961, 2014, 2065, 2092, 1998],
"profile_empty_run_1": [1979, 1972, 2019, 2070, 2097, 2003],
"profile_gemm_run_0": [2038, 11076, 2324, 2129, 2156, 2062],
"profile_gemm_run_1": [2038, 11037, 2318, 2129, 2156, 2062],
"profile_ops_run_0": [2038, 5070, 2078, 2129, 2156, 2062],
"profile_ops_run_1": [2038, 5007, 2078, 2129, 2156, 2062],
"profile_plus_run_0": [1979, 1979, 2030, 2070, 2097, 2003],
"profile_plus_run_1": [1979, 2043, 2030, 2070, 2097, 2003],
"profile_empty_run_0": [1880, 1867, 1920, 1971, 1998, 1904],
"profile_empty_run_1": [1880, 1867, 1920, 1971, 1998, 1904],
"profile_gemm_run_0": [3275, 3278, 2426, 2475, 2511, 2431],
"profile_gemm_run_1": [3264, 3268, 2420, 2469, 2504, 2401],
"profile_ops_run_0": [1944, 4903, 1984, 2035, 2062, 1968],
"profile_ops_run_1": [1944, 4918, 1984, 2035, 2062, 1968],
"profile_plus_run_0": [1938, 1932, 1978, 2029, 2056, 1962],
"profile_plus_run_1": [1891, 1874, 1931, 1982, 2009, 1915],
}
class TestSQTTExamplesRDNA4(SQTTExamplesTestBase):
  # The class body only selects the "gfx1200" target; everything else comes from the base class.
  target = "gfx1200"
class TestSQTTExamplesCDNA(SQTTExamplesTestBase):
  # SQTT example tests for the "gfx950" target; three tests are defined here as unconditional skips.
  target = "gfx950"

  def test_gemm_has_instructions(self):
    self.skipTest("TODO: decode CDNA inst packets")

  def test_rocprof_wave_times_match(self):
    self.skipTest("TODO: requires timestamp patching")

  def test_rocprof_inst_times_match(self):
    self.skipTest("TODO: requires timestamp patching")
if __name__ == "__main__":
unittest.main()

View File

@@ -126,6 +126,8 @@ class InstOpRDNA4(Enum):
LDS_WR_3 = 0x2c
LDS_WR_4 = 0x2d
LDS_WR_5 = 0x2e
OTHER_LDS_1 = 0x50
OTHER_LDS_2 = 0x51
WMMA_8 = 0x8c
WMMA_16 = 0x8d
VALU_DPFP = 0x92