mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
sqtt: CDNA inst decodes (#15274)
* sqtt: CDNA inst decodes * JUMP packets other way * cdna insts * r3 * r4 * lds from simd1 and simd2
This commit is contained in:
@@ -8,7 +8,7 @@ PROFILE_PATH = Path(temp("profile.pkl", append_user=True))
|
||||
EXAMPLES = {
|
||||
"empty":"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
|
||||
"plus":"test/test_tiny.py TestTiny.test_plus",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=64, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"gemm":"-c \"from tinygrad import Tensor; (Tensor.empty(N:=32, N)@Tensor.empty(N, N)).realize()\"",
|
||||
"ops":"extra/sqtt/examples/discover_ops.py"
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -10,7 +10,7 @@ from tinygrad.runtime.autogen.amd.rdna3.ins import SOPP
|
||||
from tinygrad.runtime.autogen.amd.rdna3.enum import SOPPOp
|
||||
from tinygrad.renderer.amd.sqtt import (decode, LAYOUT_HEADER, WAVESTART, WAVESTART_RDNA4, WAVEEND, INST, INST_RDNA4, VALUINST,
|
||||
IMMEDIATE, IMMEDIATE_MASK, PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4, PACKET_TYPES_CDNA, CDNA_WAVESTART,
|
||||
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND)
|
||||
InstOp, InstOpRDNA4, print_packets, CDNA_WAVEEND, CDNA_INST)
|
||||
from test.amd.helpers import TARGET_TO_ARCH
|
||||
|
||||
import tinygrad
|
||||
@@ -21,7 +21,7 @@ OTHER_SIMD_OPS = {InstOp.OTHER_LDS_LOAD, InstOp.OTHER_LDS_STORE, InstOp.OTHER_LD
|
||||
InstOp.OTHER_FLAT_STORE_128, InstOp.OTHER_GLOBAL_LOAD, InstOp.OTHER_GLOBAL_LOAD_VADDR,
|
||||
InstOp.OTHER_GLOBAL_STORE_64, InstOp.OTHER_GLOBAL_STORE_96, InstOp.OTHER_GLOBAL_STORE_128,
|
||||
InstOp.OTHER_GLOBAL_STORE_VADDR_128}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5}
|
||||
OTHER_SIMD_OPS_RDNA4 = {InstOpRDNA4.OTHER_VMEM, InstOpRDNA4.OTHER_VMEM_5, InstOpRDNA4.OTHER_LDS_1, InstOpRDNA4.OTHER_LDS_2}
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# ROCPROF DECODER
|
||||
@@ -153,9 +153,10 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
||||
if "gemm" not in name: continue
|
||||
with self.subTest(example=name):
|
||||
all_packets = [p for e in events for p in decode(e.blob)]
|
||||
inst_names = [p.op.name for p in all_packets if isinstance(p, (INST, INST_RDNA4))]
|
||||
self.assertGreater(len(inst_names), 0, f"no INST packets in {name}")
|
||||
self.assertGreater(len([n for n in inst_names if n.startswith("JUMP")]), 0, f"no JUMP packets in {name}")
|
||||
inst_packets = [p for p in all_packets if isinstance(p, (INST, INST_RDNA4, CDNA_INST))]
|
||||
self.assertGreater(len(inst_packets), 0, f"no INST packets in {name}")
|
||||
if isinstance(inst_packets[0], (INST, INST_RDNA4)):
|
||||
self.assertGreater(len([p for p in inst_packets if p.op.name.startswith("JUMP")]), 0, f"no JUMP packets in {name}")
|
||||
|
||||
expected: dict[str, list[int]] = {} # override in subclasses
|
||||
def test_packet_counts(self):
|
||||
@@ -210,22 +211,22 @@ class SQTTExamplesTestBase(unittest.TestCase):
|
||||
class TestSQTTExamplesRDNA3(SQTTExamplesTestBase):
|
||||
target = "gfx1100"
|
||||
expected = {
|
||||
"profile_empty_run_0": [1974, 1961, 2014, 2065, 2092, 1998],
|
||||
"profile_empty_run_1": [1979, 1972, 2019, 2070, 2097, 2003],
|
||||
"profile_gemm_run_0": [2038, 11076, 2324, 2129, 2156, 2062],
|
||||
"profile_gemm_run_1": [2038, 11037, 2318, 2129, 2156, 2062],
|
||||
"profile_ops_run_0": [2038, 5070, 2078, 2129, 2156, 2062],
|
||||
"profile_ops_run_1": [2038, 5007, 2078, 2129, 2156, 2062],
|
||||
"profile_plus_run_0": [1979, 1979, 2030, 2070, 2097, 2003],
|
||||
"profile_plus_run_1": [1979, 2043, 2030, 2070, 2097, 2003],
|
||||
"profile_empty_run_0": [1880, 1867, 1920, 1971, 1998, 1904],
|
||||
"profile_empty_run_1": [1880, 1867, 1920, 1971, 1998, 1904],
|
||||
"profile_gemm_run_0": [3275, 3278, 2426, 2475, 2511, 2431],
|
||||
"profile_gemm_run_1": [3264, 3268, 2420, 2469, 2504, 2401],
|
||||
"profile_ops_run_0": [1944, 4903, 1984, 2035, 2062, 1968],
|
||||
"profile_ops_run_1": [1944, 4918, 1984, 2035, 2062, 1968],
|
||||
"profile_plus_run_0": [1938, 1932, 1978, 2029, 2056, 1962],
|
||||
"profile_plus_run_1": [1891, 1874, 1931, 1982, 2009, 1915],
|
||||
}
|
||||
|
||||
class TestSQTTExamplesRDNA4(SQTTExamplesTestBase): target = "gfx1200"
|
||||
|
||||
class TestSQTTExamplesCDNA(SQTTExamplesTestBase):
|
||||
target = "gfx950"
|
||||
def test_gemm_has_instructions(self): self.skipTest("TODO: decode CDNA inst packets")
|
||||
def test_rocprof_wave_times_match(self): self.skipTest("TODO: requires timestamp patching")
|
||||
def test_rocprof_inst_times_match(self): self.skipTest("TODO: requires timestamp patching")
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -126,6 +126,8 @@ class InstOpRDNA4(Enum):
|
||||
LDS_WR_3 = 0x2c
|
||||
LDS_WR_4 = 0x2d
|
||||
LDS_WR_5 = 0x2e
|
||||
OTHER_LDS_1 = 0x50
|
||||
OTHER_LDS_2 = 0x51
|
||||
WMMA_8 = 0x8c
|
||||
WMMA_16 = 0x8d
|
||||
VALU_DPFP = 0x92
|
||||
|
||||
Reference in New Issue
Block a user