sqtt: new packet types, add discovery script (#14960)

This commit is contained in:
qazal
2026-02-27 21:27:27 +02:00
committed by GitHub
parent 4e12fc3fe6
commit b8a55d5f68
9 changed files with 174 additions and 12 deletions

View File

@@ -0,0 +1,148 @@
#!/usr/bin/env python3
# Run all ALU and memory instructions in the ISA
import functools, inspect
from enum import Enum
from tinygrad import Tensor, Device, dtypes
from tinygrad.uop.ops import UOp, Ops, KernelInfo, AddrSpace
from tinygrad.renderer.amd.dsl import Inst, Reg, OPERANDS, SrcField, VGPRField, SGPRField, SSrcField, SBaseField, AlignedSGPRField, BitField
from tinygrad.renderer.amd.dsl import FixedBitField, EnumBitField, s, v, NULL, VCC_LO
from extra.gemm.amd_asm_matmul import Kernel
# skip instructions that mutate wave state (PC, EXEC, allocations, signals)
SKIP = {"S_SETPC_B64", "S_SWAPPC_B64", "S_RFE_B64", "S_BARRIER_SIGNAL_ISFIRST", "S_GET_BARRIER_STATE", "S_ALLOC_VGPR", "S_SLEEP_VAR", "S_GETPC_B64",
"S_SENDMSG_RTN_B32", "S_SENDMSG_RTN_B64"}
# skip barriers, s_waits, wrap level atomics, and ray tracing (bvh)
SKIP_SUBSTR = ["SAVEEXEC", "CMPX", "WREXEC", "MOVREL", "ATOMIC", "S_BUFFER_", "S_ATC_PROBE", "BARRIER", "S_WAITCNT", "BVH",
"DS_CMPSTORE_RTN", "DS_WRAP_RTN_B32", "DS_ORDERED_COUNT", "DS_GWS", "GS_REG", "GLOBAL_LOAD_LDS", "GLOBAL_STORE_BLOCK"]
ALU_FORMATS = {"VOP1", "VOP1_LIT", "VOP1_SDST", "VOP2", "VOP2_LIT", "VOP3", "VOP3_SDST", "VOP3SD", "VOP3P", "VOP3P_MFMA", "VOP3PX2",
"VOPC", "SOP1", "SOP1_LIT", "SOP2", "SOP2_LIT", "SOPC", "SOPC_LIT", "SOPK", "SOPK_LIT", "VINTERP"}
# intentionally not testing scratch memory ops
MEM_FORMATS = {"VGLOBAL", "GLOBAL", "SMEM", "DS"}
def should_skip(op:Enum) -> bool: return (name:=op.name) in SKIP or any(sub in name for sub in SKIP_SUBSTR)
# ** named register assignments
# ALU operands
ALU_VGPR_STRIDE = 16 # v[0], v[16], v[32], ... per ALU operand slot
ALU_SGPR_STRIDE = 4 # s[0], s[4], s[8], ... per ALU operand slot
# memory address registers
S_KERNARG_PTR = (0, 1)
S_BUF_PTR = (2, 3)
V_VADDR = (0, 1)
V_DS_ADDR = 0
# memory data registers
MEM_VGPR_BASE = 32 # v[32], v[48], ... for vdst/vdata/vsrc
MEM_VGPR_STRIDE = 16 # spacing between memory data vgpr slots
MEM_SGPR_BASE = 8 # s[8], s[10], ... for SMEM sdata
MEM_SGPR_STRIDE = 2 # spacing between memory data sgpr slots
# ** create an ALU instruction based on the operands
def create_alu_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
inst_cls, operands, slot = builder.func, OPERANDS[op], 0
kwargs:dict[str, Reg|int] = {}
for name, field in inst_cls._fields:
if isinstance(field, (FixedBitField, EnumBitField)): continue
nregs = max(1, operands[name][1] // 32) if name in operands else 1
is_sreg = name in operands and "SREG" in str(operands[name][2])
base_v, base_s = slot * ALU_VGPR_STRIDE, slot * ALU_SGPR_STRIDE
if name == "sdst" and isinstance(field, SGPRField): reg = VCC_LO
elif is_sreg and not isinstance(field, VGPRField): reg = VCC_LO
elif isinstance(field, VGPRField): reg = v[base_v:base_v+nregs-1] if nregs > 1 else v[base_v]
elif isinstance(field, SSrcField): reg = VCC_LO if nregs <= 2 else s[base_s:base_s+nregs-1] if nregs > 1 else s[base_s]
elif isinstance(field, SGPRField): reg = s[base_s:base_s+nregs-1] if nregs > 1 else s[base_s]
elif isinstance(field, SrcField): reg = v[base_v:base_v+nregs-1] if nregs > 1 else v[base_v]
else: reg = None
if reg is not None: kwargs[name] = reg; slot += 1
elif isinstance(field, BitField): kwargs[name] = field.default
return builder(**kwargs)
# ** create a memory instruction with pre set address registers
MEM_PRESET_REGS:dict[str, dict[str, Reg]] = {
"VGLOBAL":{"saddr":s[S_BUF_PTR[0]:S_BUF_PTR[1]], "vaddr":v[V_VADDR[0]:V_VADDR[1]]},
"GLOBAL":{"saddr":s[S_BUF_PTR[0]:S_BUF_PTR[1]], "addr":v[V_DS_ADDR]}, # addr is 32-bit offset when saddr is valid SGPR
"DS":{"addr":v[V_DS_ADDR]},
"SMEM":{"sbase":s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]], "soffset":NULL},
}
def create_mem_inst(op:Enum, builder:functools.partial[Inst]) -> Inst:
inst_cls, operands, field_map = builder.func, OPERANDS.get(op, {}), MEM_PRESET_REGS.get(builder.func.__name__, {})
kwargs:dict[str, Reg|int] = {}
vslot, sslot = 0, 0
for name, field in inst_cls._fields:
if isinstance(field, (FixedBitField, EnumBitField)): continue
if name in field_map:
kwargs[name] = field_map[name]
continue
nregs = max(1, operands[name][1] // 32) if name in operands else 1
if isinstance(field, VGPRField):
vi = MEM_VGPR_BASE + vslot * MEM_VGPR_STRIDE
kwargs[name] = v[vi:vi+nregs-1] if nregs > 1 else v[vi]
vslot += 1
elif isinstance(field, (SGPRField, AlignedSGPRField, SBaseField)):
si = MEM_SGPR_BASE + sslot * MEM_SGPR_STRIDE
kwargs[name] = s[si:si+nregs-1] if nregs > 1 else s[si]
sslot += 1
elif isinstance(field, BitField): kwargs[name] = field.default
return builder(**kwargs)
# ** collect all memory and ALU instructions from the ISA autogen
def collect_instructions() -> tuple[list[Inst], list[Inst], list[str]]:
op_map:dict[Enum, functools.partial[Inst]] = {}
for name, obj in inspect.getmembers(all_insts):
if isinstance(obj, functools.partial) and len(obj.args) == 1: op_map[obj.args[0]] = obj
alu_insts:list[Inst] = []
mem_insts:list[Inst] = []
skipped:list[str] = []
for op_enum, builder in op_map.items():
if should_skip(op_enum) or op_enum not in OPERANDS: skipped.append(op_enum.name); continue
fmt = builder.func.__name__
if fmt in ALU_FORMATS: alu_insts.append(create_alu_inst(op_enum, builder))
elif fmt in MEM_FORMATS: mem_insts.append(create_mem_inst(op_enum, builder))
return alu_insts, mem_insts, skipped
def exec_insts(insts:list):
k = Kernel(arch)
# ** prologue for global memory
k.emit(s_load_b64(sdata=s[S_BUF_PTR[0]:S_BUF_PTR[1]], sbase=s[S_KERNARG_PTR[0]:S_KERNARG_PTR[1]], soffset=NULL))
k.waitcnt(lgkm=0)
k.emit(v_mov_b32_e32(v[V_VADDR[0]], 0))
k.emit(v_mov_b32_e32(v[V_VADDR[1]], 0))
# ** emit
for inst in insts: k.emit(inst)
k.emit(s_endpgm())
# ** run
NUM_THREADS, NUM_GRIDS, BUF_SIZE = 32, 1, 1024*1024
def fxn(A:UOp, B:UOp, C:UOp) -> UOp:
lidx, gidx = UOp.special(NUM_THREADS, "lidx0"), UOp.special(NUM_GRIDS, "gidx0")
lds = UOp(Ops.DEFINE_LOCAL, dtypes.uint8.ptr(size=BUF_SIZE, addrspace=AddrSpace.LOCAL), (), "lds")
sink = UOp.sink(A.base, B.base, C.base, lds, lidx, gidx, arg=KernelInfo(name="discover_ops"))
return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg="AMD"), UOp(Ops.LINEAR, src=tuple(UOp(Ops.INS, arg=x) for x in k.finalize()))))
A = Tensor.empty(BUF_SIZE, dtype=dtypes.uint8)
B = Tensor.empty(1, dtype=dtypes.uint8)
C = Tensor.empty(1, dtype=dtypes.uint8)
Tensor.custom_kernel(A, B, C, fxn=fxn)[0].realize()
if __name__ == "__main__":
import sys
arch = Device[Device.DEFAULT].renderer.arch
if arch.startswith("gfx12"):
from tinygrad.runtime.autogen.amd.rdna4.ins import *
import tinygrad.runtime.autogen.amd.rdna4.ins as all_insts
elif arch.startswith("gfx11"):
from tinygrad.runtime.autogen.amd.rdna3.ins import *
import tinygrad.runtime.autogen.amd.rdna3.ins as all_insts
# these don"t exist in RDNA3, only RDNA3.5 and above
SKIP.update(["S_FMAAK_F32", "S_FMAMK_F32"])
else:
print(f"{arch} not supported yet")
sys.exit(0)
alu_insts, mem_insts, skipped = collect_instructions()
print(f"collected {len(alu_insts)} ALU + {len(mem_insts)} memory instructions ({len(skipped)} skipped)")
exec_insts(mem_insts+alu_insts)

View File

@@ -9,6 +9,7 @@ EXAMPLES = [
"test/backend/test_custom_kernel.py TestCustomKernel.test_empty",
"test/test_tiny.py TestTiny.test_plus",
"test/test_tiny.py TestTiny.test_gemm",
"extra/sqtt/examples/discover_ops.py"
]
if __name__ == "__main__":

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -134,7 +134,10 @@ class TestSQTTMatchesBinary(unittest.TestCase):
def _test_bit_counts(self, layout: int):
if not (tables := extract_bit_tables()): self.skipTest("rocprof-trace-decoder not installed")
from tinygrad.renderer.amd.sqtt import PACKET_TYPES_RDNA3, PACKET_TYPES_RDNA4
# rocprof's bit table says L4 type 7 (TS_DELTA_S8_W3) is 72 bits, but the actual decoder uses 64 bits
skip = {(4, 7)}
for type_id, pkt_cls in {3: PACKET_TYPES_RDNA3, 4: PACKET_TYPES_RDNA4}[layout].items():
if (layout, type_id) in skip: continue
with self.subTest(packet=pkt_cls.__name__):
self.assertEqual(pkt_cls._size_nibbles * 4, tables[layout - 2][type_id]) # type: ignore[attr-defined]

View File

@@ -2,7 +2,7 @@
import unittest, pickle
from typing import Iterator
from pathlib import Path
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, OSX
from tinygrad.renderer.amd.sqtt import print_packets, map_insts
from tinygrad.runtime.autogen.amd.rdna3.ins import s_endpgm
from test.amd.disasm import disasm
@@ -10,7 +10,7 @@ from test.amd.disasm import disasm
import tinygrad
EXAMPLES_DIR = Path(tinygrad.__file__).parent.parent / "extra/sqtt/examples"
def rocprof_inst_traces_match(sqtt, prg, target):
def rocprof_inst_traces_match(sqtt, prg, target, pass_rocprof_err=False):
from tinygrad.viz.serve import amd_decode
from extra.sqtt.roc import decode as roc_decode, InstExec
addr_table = amd_decode(prg.lib, target)
@@ -30,7 +30,7 @@ def rocprof_inst_traces_match(sqtt, prg, target):
rocprof_inst = next(rwaves_iter[info.wave][0])
ref_pc = rocprof_inst.pc-prg.base
# always check pc matches
assert ref_pc == info.pc, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
assert ref_pc == info.pc or pass_rocprof_err, f"pc mismatch {ref_pc}:{disasm_map[rocprof_inst.pc]} != {info.pc}:{disasm(info.inst)}"
# special handling for s_endpgm, it marks the wave completion.
if info.inst == s_endpgm():
completed_wave = list(rwaves_iter[info.wave].pop(0))
@@ -67,7 +67,9 @@ class TestSQTTMapBase(unittest.TestCase):
if not event.itrace: continue
if event.kern not in kern_events: continue
with self.subTest(example=name, kern=event.kern):
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target)
# rocprof OSX has a bug for sopk decoding, linux rocprof works
pass_rocprof_err = OSX and target == "gfx1200" and name.startswith("profile_py")
passed_insts, n_waves, n_units = rocprof_inst_traces_match(event, kern_events[event.kern], target, pass_rocprof_err)
if n_waves: print(f"{name}: passed for {passed_insts} instructions across {n_waves} waves scheduled on {n_units} wave units")
class TestSQTTMapRDNA3(TestSQTTMapBase): target = "gfx1100"

View File

@@ -46,6 +46,7 @@ class InstOp(Enum):
SMEM = 0x1
JUMP = 0x3 # branch taken
JUMP_NO = 0x4 # branch not taken
CALL = 0x5 # s_call_b64
MESSAGE = 0x9
VALU_TRANS = 0xb # transcendental: exp, log, rcp, sqrt, sin, cos
VALU_64_SHIFT = 0xd # 64-bit shifts: lshl, lshr, ashr
@@ -72,8 +73,10 @@ class InstOp(Enum):
# LDS ops on traced SIMD
LDS_LOAD = 0x29
LDS_ATOMIC = 0x2a # ds_append, ds_consume, ds_store_addtid_b32
LDS_STORE = 0x2b
LDS_STORE_64 = 0x2c
LDS_STORE_96 = 0x2d
LDS_STORE_128 = 0x2e
# Memory ops on other SIMD (0x5x range)
@@ -99,17 +102,27 @@ class InstOp(Enum):
class InstOpRDNA4(Enum):
"""SQTT instruction operation types for RDNA4 (gfx1200). Different encoding from RDNA3."""
# TODO: we need to do discovery of all of these from instructions
SALU = 0x0
JUMP = 0x1
NEXT = 0x2
MESSAGE = 0x4
VALU_TRANS = 0x5
VALU_64 = 0x6
VALU_MAD64 = 0x7
VINTERP = 0x9
VALU_WMMA = 0x46
VMEM = 0x10
VMEM_128 = 0x11
VMEM_STORE = 0x12
VMEM_STORE_128 = 0x14
VMEM_STORE_G96 = 0x13 # global_store_[b96,b128]
LDS_LOAD = 0x14
LDS_STORE = 0x15
LDS_STORE_64 = 0x16
LDS_STORE_128 = 0x17
VALU_F64 = 0x49
SALU_TRANS = 0x4c # transcendental with sgpr src/dst
SALU_MUL = 0x4d # s_[mul,mulhi,mulk]
SALU_MUL64 = 0x4e
OTHER_VMEM = 0x5e
OTHER_VMEM_STORE = 0x60
@@ -147,11 +160,6 @@ class TS_DELTA_S8_W3(PacketType):
delta = bits[10:8]
_padding = bits[63:11]
class TS_DELTA_S8_W3_RDNA4(PacketType): # Layout 4: 64->72 bits
encoding = bits[6:0] == 0b0100001
delta = bits[10:8]
_padding = bits[71:11]
class TS_DELTA_S5_W3(PacketType):
encoding = bits[4:0] == 0b00110
delta = bits[7:5]
@@ -363,7 +371,7 @@ PACKET_TYPES_RDNA3: dict[int, type[PacketType]] = {
}
PACKET_TYPES_RDNA4: dict[int, type[PacketType]] = {
**PACKET_TYPES_RDNA3,
7: TS_DELTA_S8_W3_RDNA4, 9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4,
9: WAVESTART_RDNA4, 10: TS_DELTA_S5_W2_RDNA4, 11: WAVEALLOC_RDNA4,
12: TS_DELTA_S5_W3_RDNA4, 13: PERF_RDNA4, 22: TS_DELTA_OR_MARK_RDNA4, 24: INST_RDNA4,
}